@@ -60,30 +60,18 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_facto
6060
# Modified from diffusers.models.autoencoders.ecae.GLUMBConv
6262@maybe_allow_in_graph
63- class SanaGLUMBConv (GLUMBConv ):
64- def __init__ (
65- self ,
66- in_channels : int ,
67- out_channels : int ,
68- kernel_size = 3 ,
69- stride = 1 ,
70- mid_channels = None ,
71- expand_ratio = 2.5 ,
72- use_bias = False ,
73- norm = (None , None , None ),
74- act_func = ("silu" , "silu" , None ),
75- ):
76- super ().__init__ (
77- in_channels = in_channels ,
78- out_channels = out_channels ,
79- kernel_size = kernel_size ,
80- stride = stride ,
81- mid_channels = mid_channels ,
82- expand_ratio = expand_ratio ,
83- use_bias = use_bias ,
84- norm = norm ,
85- act_func = act_func ,
86- )
63+ class SanaGLUMBConv (nn .Module ):
64+ def __init__ (self , in_channels : int , out_channels : int ) -> None :
65+ super ().__init__ ()
66+
67+ hidden_channels = int (2.5 * in_channels )
68+
69+ self .nonlinearity = nn .SiLU ()
70+
71+ self .conv_inverted = nn .Conv2d (in_channels , hidden_channels * 2 , 1 , 1 , 0 )
72+ self .conv_depth = nn .Conv2d (hidden_channels * 2 , hidden_channels * 2 , 3 , 1 , 1 , groups = hidden_channels * 2 )
73+ self .conv_point = nn .Conv2d (hidden_channels , out_channels , 1 , 1 , 0 , bias = False )
74+ self .norm = RMSNorm (out_channels , eps = 1e-5 , elementwise_affine = True , bias = True )
8775
8876 def forward (self , x : torch .Tensor , HW = None ) -> torch .Tensor :
8977 B , N , C = x .shape
0 commit comments