@@ -111,59 +111,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class UpsamplePixelShuffle(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        factor: int,
-    ):
-        super().__init__()
-        self.factor = factor
-        out_ratio = factor**2
-        self.conv = DCConv2d(
-            in_channels=in_channels,
-            out_channels=out_channels * out_ratio,
-            kernel_size=kernel_size,
-            use_bias=True,
-            norm=None,
-            act_func=None,
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.conv(x)
-        x = F.pixel_shuffle(x, self.factor)
-        return x
-
-
-class UpsampleInterpolate(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        factor: int,
-        mode: str = "nearest",
-    ) -> None:
-        super().__init__()
-        self.factor = factor
-        self.mode = mode
-        self.conv = DCConv2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            use_bias=True,
-            norm=None,
-            act_func=None,
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
-        x = self.conv(x)
-        return x
-
-
 class UpsampleChannelDuplicatingPixelUnshuffle(nn.Module):
     def __init__(
         self,
@@ -184,11 +131,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class IdentityLayer(nn.Module):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x
-
-
 class GLUMBConv(nn.Module):
     def __init__(
         self,
@@ -210,15 +152,15 @@ def __init__(
         mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels
 
         self.glu_act = get_activation(act_func[1])
-        self.inverted_conv = DCConv2d(
+        self.conv_inverted = DCConv2d(
             in_channels,
             mid_channels * 2,
             1,
             use_bias=use_bias[0],
             norm=norm[0],
             act_func=act_func[0],
         )
-        self.depth_conv = DCConv2d(
+        self.conv_depth = DCConv2d(
             mid_channels * 2,
             mid_channels * 2,
             kernel_size,
@@ -228,7 +170,7 @@ def __init__(
             norm=norm[1],
             act_func=None,
         )
-        self.point_conv = DCConv2d(
+        self.conv_point = DCConv2d(
             mid_channels,
             out_channels,
             1,
@@ -238,15 +180,16 @@ def __init__(
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.inverted_conv(x)
-        x = self.depth_conv(x)
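+        # keep the block input so the skip connection can be applied after the pointwise conv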
+        residual = x
+        x = self.conv_inverted(x)
+        x = self.conv_depth(x)
 
         x, gate = torch.chunk(x, 2, dim=1)
         gate = self.glu_act(gate)
         x = x * gate
 
-        x = self.point_conv(x)
-        return x
+        x = self.conv_point(x)
+        return x + residual
 
 
 class ResBlock(nn.Module):
@@ -289,9 +232,10 @@ def __init__(
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
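+        # skip connection now lives inside the block instead of an outer ResidualBlock wrapper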
+        residual = x
         x = self.conv1(x)
         x = self.conv2(x)
-        return x
+        return x + residual
 
 
 class LiteMLA(nn.Module):
@@ -357,7 +301,6 @@ def __init__(
             act_func=act_func[1],
         )
 
-    @torch.autocast(device_type="cuda", enabled=False)
     def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
         B, _, H, W = list(qkv.size())
 
@@ -429,6 +372,7 @@ def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
         return out
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
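+        # save the input; the projected attention output is added back to it on return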
+        residual = x
         # generate multi-scale q, k, v
         qkv = self.qkv(x)
         multi_scale_qkv = [qkv]
@@ -443,7 +387,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             out = self.relu_quadratic_att(qkv)
         out = self.proj(out)
 
-        return out
+        return out + residual
 
 
 class EfficientViTBlock(nn.Module):
@@ -461,30 +405,24 @@ def __init__(
     ):
         super().__init__()
         if context_module == "LiteMLA":
-            self.context_module = ResidualBlock(
-                LiteMLA(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    heads_ratio=heads_ratio,
-                    dim=dim,
-                    norm=(None, norm),
-                    scales=scales,
-                ),
-                IdentityLayer(),
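+            # LiteMLA applies its own residual, so the ResidualBlock/IdentityLayer wrapper is no longer needed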
+            self.context_module = LiteMLA(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                heads_ratio=heads_ratio,
+                dim=dim,
+                norm=(None, norm),
+                scales=scales,
             )
         else:
             raise ValueError(f"context_module {context_module} is not supported")
         if local_module == "GLUMBConv":
-            self.local_module = ResidualBlock(
-                GLUMBConv(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    expand_ratio=expand_ratio,
-                    use_bias=(True, True, False),
-                    norm=(None, None, norm),
-                    act_func=(act_func, act_func, None),
-                ),
-                IdentityLayer(),
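+            # likewise, GLUMBConv adds its residual internally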
+            self.local_module = GLUMBConv(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                expand_ratio=expand_ratio,
+                use_bias=(True, True, False),
+                norm=(None, None, norm),
+                act_func=(act_func, act_func, None),
             )
         else:
             raise NotImplementedError(f"local_module {local_module} is not supported")
@@ -546,7 +484,7 @@ def build_stage_main(
 
     if current_block_type == "ResBlock":
         assert in_channels == out_channels
-        main_block = ResBlock(
+        block = ResBlock(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=3,
@@ -555,7 +493,6 @@ def build_stage_main(
             norm=(None, norm),
             act_func=(act, None),
         )
-        block = ResidualBlock(main_block, IdentityLayer())
     elif current_block_type == "EViT_GLU":
         assert in_channels == out_channels
         block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=())
@@ -619,7 +556,7 @@ def __init__(
         self.conv = nn.Conv2d(
             in_channels,
             out_channels,
-            kernel_size=(kernel_size, kernel_size),
+            kernel_size=kernel_size,
             stride=self.stride,
             padding=kernel_size // 2,
         )
@@ -654,32 +591,21 @@ def __init__(
         super().__init__()
 
         self.interpolate = interpolate
-        self.interpolation_method = interpolation_mode
+        self.interpolation_mode = interpolation_mode
         self.factor = 2
         self.stride = 1
 
         out_ratio = self.factor**2
         if not interpolate:
             out_channels = out_channels * out_ratio
 
-        if interpolate:
-            nn.conv = DCConv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=kernel_size,
-            )
-        else:
-            self.conv = DCConv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=kernel_size,
-                use_bias=True,
-                norm=None,
-                act_func=None,
-            )
-            self.conv = UpsamplePixelShuffle(
-                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, factor=2
-            )
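+        # one conv serves both paths; when pixel shuffling, out_channels was already scaled by factor**2 above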
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=self.stride,
+            padding=kernel_size // 2,
+        )
 
         self.shortcut = None
         if shortcut:
@@ -689,14 +615,17 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.interpolate:
-            x = torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.interpolation_mode)
+            x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
             x = self.conv(x)
         else:
             x = self.conv(hidden_states)
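+            # move the factor**2 extra channels produced by the conv into spatial resolution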
+            x = F.pixel_shuffle(x, self.factor)
+
         if self.shortcut is not None:
             hidden_states = x + self.shortcut(hidden_states)
         else:
             hidden_states = x
+
         return hidden_states
 
@@ -770,8 +699,6 @@ def __init__(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_in(x)
         for stage in self.stages:
-            if len(stage.op_list) == 0:
-                continue
             x = stage(x)
         x = self.conv_out(x) + self.norm_out(x)
         return x
@@ -858,14 +785,13 @@ def __init__(
         self.conv_out = DCUpBlock2d(
             block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
             in_channels,
+            interpolate=upsample_block_type == "InterpolateConv",
             shortcut=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_in(x) + self.norm_in(x)
         for stage in reversed(self.stages):
-            if len(stage.op_list) == 0:
-                continue
             x = stage(x)
         x = self.norm_out(x)
         x = self.conv_act(x)