1414# limitations under the License.
1515
1616from typing import Any , Optional , Callable , Union
17+ from collections import OrderedDict
1718
1819import torch
1920import torch .nn as nn
@@ -169,28 +170,30 @@ def __init__(
169170 super ().__init__ ()
170171 mid_channels = round (in_channels * expand_ratio ) if mid_channels is None else mid_channels
171172
172- self .conv1 = ConvLayer (
173- in_channels ,
174- mid_channels ,
175- kernel_size ,
176- stride ,
177- use_bias = use_bias [0 ],
178- norm = norm [0 ],
179- act_func = act_func [0 ],
180- )
181- self .conv2 = ConvLayer (
182- mid_channels ,
183- out_channels ,
184- kernel_size ,
185- 1 ,
186- use_bias = use_bias [1 ],
187- norm = norm [1 ],
188- act_func = act_func [1 ],
189- )
173+ self .main = nn .Sequential (OrderedDict ([
174+ ("conv1" , ConvLayer (
175+ in_channels ,
176+ mid_channels ,
177+ kernel_size ,
178+ stride ,
179+ use_bias = use_bias [0 ],
180+ norm = norm [0 ],
181+ act_func = act_func [0 ],
182+ )),
183+ ("conv2" , ConvLayer (
184+ mid_channels ,
185+ out_channels ,
186+ kernel_size ,
187+ 1 ,
188+ use_bias = use_bias [1 ],
189+ norm = norm [1 ],
190+ act_func = act_func [1 ],
191+ )),
192+ ]))
193+ self .shortcut = nn .Identity ()
190194
191195 def forward (self , x : torch .Tensor ) -> torch .Tensor :
192- x = self .conv1 (x )
193- x = self .conv2 (x )
196+ x = self .main (x ) + self .shortcut (x )
194197 return x
195198
196199
@@ -448,7 +451,7 @@ def build_block(
448451) -> nn .Module :
449452 if block_type == "ResBlock" :
450453 assert in_channels == out_channels
451- main_block = ResBlock (
454+ block = ResBlock (
452455 in_channels = in_channels ,
453456 out_channels = out_channels ,
454457 kernel_size = 3 ,
@@ -457,7 +460,6 @@ def build_block(
457460 norm = (None , norm ),
458461 act_func = (act , None ),
459462 )
460- block = ResidualBlock (main_block , nn .Identity ())
461463 elif block_type == "EViTGLU" :
462464 assert in_channels == out_channels
463465 block = EfficientViTBlock (in_channels , norm = norm , act_func = act , local_module = "GLUMBConv" , scales = ())
0 commit comments