@@ -153,15 +153,15 @@ def __init__(
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.inverted_conv(x)
-        x = self.depth_conv(x)
+        y = self.inverted_conv(x)
+        y = self.depth_conv(y)
 
-        x, gate = torch.chunk(x, 2, dim=1)
+        y, gate = torch.chunk(y, 2, dim=1)
         gate = self.glu_act(gate)
-        x = x * gate
+        y = y * gate
 
-        x = self.point_conv(x)
-        return x
+        y = self.point_conv(y)
+        return x + y
 
 
 class ResBlock(nn.Module):
@@ -349,7 +349,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = self.relu_quadratic_att(qkv)
         out = self.proj(out)
 
-        return out
+        return x + out
 
 
 class EfficientViTBlock(nn.Module):
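
(Aside, not part of the patch.) Both forward passes above now fold the identity skip into the block itself: the input `x` is left untouched while a copy is transformed, and the two are summed at the end. A minimal sketch of that pattern, with a placeholder body standing in for the real conv/attention stacks:

```python
import torch
import torch.nn as nn


class BlockWithInternalResidual(nn.Module):
    """Sketch of the residual-inside-forward pattern from the diff.

    The body here is a placeholder; GLUMBConv and LiteMLA use their own
    conv / attention stacks. The skip requires the output to keep the
    input's channel count, which the diff guarantees by constructing both
    blocks with out_channels=in_channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.body = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.body(x)  # transform a copy, leaving x untouched
        return x + y      # identity skip added inside the block
```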
@@ -367,30 +367,24 @@ def __init__(
     ):
         super().__init__()
         if context_module == "LiteMLA":
-            self.context_module = ResidualBlock(
-                LiteMLA(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    heads_ratio=heads_ratio,
-                    dim=dim,
-                    norm=(None, norm),
-                    scales=scales,
-                ),
-                nn.Identity(),
+            self.context_module = LiteMLA(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                heads_ratio=heads_ratio,
+                dim=dim,
+                norm=(None, norm),
+                scales=scales,
             )
         else:
             raise ValueError(f"context_module {context_module} is not supported")
         if local_module == "GLUMBConv":
-            self.local_module = ResidualBlock(
-                GLUMBConv(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    expand_ratio=expand_ratio,
-                    use_bias=(True, True, False),
-                    norm=(None, None, norm),
-                    act_func=(act_func, act_func, None),
-                ),
-                nn.Identity(),
+            self.local_module = GLUMBConv(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                expand_ratio=expand_ratio,
+                use_bias=(True, True, False),
+                norm=(None, None, norm),
+                act_func=(act_func, act_func, None),
             )
         else:
             raise NotImplementedError(f"local_module {local_module} is not supported")
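
(Aside, not part of the patch.) Dropping the ResidualBlock wrapper is consistent with the forward changes above: assuming ResidualBlock(main, shortcut) computes main(x) + shortcut(x), wrapping a module with nn.Identity() as the shortcut is equivalent to adding the skip inside the module's own forward. A hedged sketch of that equivalence, using a stand-in ResidualBlock rather than the repository's class:

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Stand-in for illustration; assumes main(x) + shortcut(x) semantics."""

    def __init__(self, main: nn.Module, shortcut: nn.Module):
        super().__init__()
        self.main = main
        self.shortcut = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.main(x) + self.shortcut(x)


body = nn.Conv2d(8, 8, kernel_size=3, padding=1)  # placeholder for LiteMLA / GLUMBConv
x = torch.randn(1, 8, 16, 16)

wrapped = ResidualBlock(body, nn.Identity())(x)  # old composition in EfficientViTBlock
inlined = x + body(x)                            # new composition after this commit
assert torch.allclose(wrapped, inlined)
```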