@@ -439,41 +439,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
439439 return res
440440
441441
442- def build_stage_main (
443- width : int , depth : int , block_type : str | list [str ], norm : str , act : str , input_width : int
444- ) -> list [nn .Module ]:
445- assert isinstance (block_type , str ) or (isinstance (block_type , list ) and depth == len (block_type ))
446- stage = []
447- for d in range (depth ):
448- current_block_type = block_type [d ] if isinstance (block_type , list ) else block_type
449-
450- in_channels = width if d > 0 else input_width
451- out_channels = width
452-
453- if current_block_type == "ResBlock" :
454- assert in_channels == out_channels
455- block = ResBlock (
456- in_channels = in_channels ,
457- out_channels = out_channels ,
458- kernel_size = 3 ,
459- stride = 1 ,
460- use_bias = (True , False ),
461- norm = (None , norm ),
462- act_func = (act , None ),
463- )
464- elif current_block_type == "EViTGLU" :
465- assert in_channels == out_channels
466- block = EfficientViTBlock (in_channels , norm = norm , act_func = act , local_module = "GLUMBConv" , scales = ())
467- elif current_block_type == "EViTS5GLU" :
468- assert in_channels == out_channels
469- block = EfficientViTBlock (in_channels , norm = norm , act_func = act , local_module = "GLUMBConv" , scales = (5 ,))
470- else :
471- raise ValueError (f"block_type { current_block_type } is not supported" )
472-
473- stage .append (block )
474- return stage
475-
476-
477442class Encoder (nn .Module ):
478443 def __init__ (
479444 self ,
@@ -485,7 +450,6 @@ def __init__(
485450 norm : str = "rms2d" ,
486451 act : str = "silu" ,
487452 downsample_block_type : str = "ConvPixelUnshuffle" ,
488- downsample_match_channel : bool = True ,
489453 downsample_shortcut : Optional [str ] = "averaging" ,
490454 out_norm : Optional [str ] = None ,
491455 out_act : Optional [str ] = None ,
@@ -533,11 +497,32 @@ def __init__(
533497 self .stages : list [nn .Module ] = []
534498 for stage_id , (width , depth ) in enumerate (zip (width_list , depth_list )):
535499 stage_block_type = block_type [stage_id ] if isinstance (block_type , list ) else block_type
536- stage = build_stage_main (
537- width = width , depth = depth , block_type = stage_block_type , norm = norm , act = act , input_width = width
538- )
500+ if not (isinstance (stage_block_type , str ) or (isinstance (stage_block_type , list ) and depth == len (stage_block_type ))):
501+ raise ValueError (f"block type { stage_block_type } is not supported for encoder stage { stage_id } with depth { depth } " )
502+ stage = []
503+ # stage main
504+ for d in range (depth ):
505+ current_block_type = stage_block_type [d ] if isinstance (stage_block_type , list ) else stage_block_type
506+ if current_block_type == "ResBlock" :
507+ block = ResBlock (
508+ in_channels = width ,
509+ out_channels = width ,
510+ kernel_size = 3 ,
511+ stride = 1 ,
512+ use_bias = (True , False ),
513+ norm = (None , norm ),
514+ act_func = (act , None ),
515+ )
516+ elif current_block_type == "EViTGLU" :
517+ block = EfficientViTBlock (width , norm = norm , act_func = act , local_module = "GLUMBConv" , scales = ())
518+ elif current_block_type == "EViTS5GLU" :
519+ block = EfficientViTBlock (width , norm = norm , act_func = act , local_module = "GLUMBConv" , scales = (5 ,))
520+ else :
521+ raise ValueError (f"block type { current_block_type } is not supported" )
522+ stage .append (block )
523+ # downsample
539524 if stage_id < num_stages - 1 and depth > 0 :
540- downsample_out_channels = width_list [stage_id + 1 ] if downsample_match_channel else width
525+ downsample_out_channels = width_list [stage_id + 1 ]
541526 if downsample_block_type == "Conv" :
542527 downsample_block = nn .Conv2d (
543528 in_channels = width ,
@@ -621,7 +606,6 @@ def __init__(
621606 norm : str | list [str ] = "rms2d" ,
622607 act : str | list [str ] = "silu" ,
623608 upsample_block_type : str = "ConvPixelShuffle" ,
624- upsample_match_channel : bool = True ,
625609 upsample_shortcut : str = "duplicating" ,
626610 out_norm : str = "rms2d" ,
627611 out_act : str = "relu" ,
@@ -665,8 +649,9 @@ def __init__(
665649 self .stages : list [nn .Module ] = []
666650 for stage_id , (width , depth ) in reversed (list (enumerate (zip (width_list , depth_list )))):
667651 stage = []
652+ # upsample
668653 if stage_id < num_stages - 1 and depth > 0 :
669- upsample_out_channels = width if upsample_match_channel else width_list [ stage_id + 1 ]
654+ upsample_out_channels = width
670655 if upsample_block_type == "ConvPixelShuffle" :
671656 upsample_block = ConvPixelShuffleUpsample2D (
672657 in_channels = width_list [stage_id + 1 ], out_channels = upsample_out_channels , kernel_size = 3 , factor = 2
@@ -685,22 +670,30 @@ def __init__(
685670 else :
686671 raise ValueError (f"shortcut { upsample_shortcut } is not supported for upsample" )
687672 stage .append (upsample_block )
688-
673+ # stage main
689674 stage_block_type = block_type [stage_id ] if isinstance (block_type , list ) else block_type
690675 stage_norm = norm [stage_id ] if isinstance (norm , list ) else norm
691676 stage_act = act [stage_id ] if isinstance (act , list ) else act
692- stage .extend (
693- build_stage_main (
694- width = width ,
695- depth = depth ,
696- block_type = stage_block_type ,
697- norm = stage_norm ,
698- act = stage_act ,
699- input_width = (
700- width if upsample_match_channel else width_list [min (stage_id + 1 , num_stages - 1 )]
701- ),
702- )
703- )
677+ for d in range (depth ):
678+ current_block_type = stage_block_type [d ] if isinstance (stage_block_type , list ) else stage_block_type
679+ if current_block_type == "ResBlock" :
680+ block = ResBlock (
681+ in_channels = width ,
682+ out_channels = width ,
683+ kernel_size = 3 ,
684+ stride = 1 ,
685+ use_bias = (True , False ),
686+ norm = (None , stage_norm ),
687+ act_func = (stage_act , None ),
688+ )
689+ elif current_block_type == "EViTGLU" :
690+ block = EfficientViTBlock (width , norm = stage_norm , act_func = stage_act , local_module = "GLUMBConv" , scales = ())
691+ elif current_block_type == "EViTS5GLU" :
692+ block = EfficientViTBlock (width , norm = stage_norm , act_func = stage_act , local_module = "GLUMBConv" , scales = (5 ,))
693+ else :
694+ raise ValueError (f"block type { current_block_type } is not supported" )
695+ stage .append (block )
696+
704697 self .stages .insert (0 , nn .Sequential (* stage ))
705698 self .stages = nn .ModuleList (self .stages )
706699
0 commit comments