@@ -431,21 +431,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
431431 return res
432432
433433
434- class OpSequential (nn .Module ):
435- def __init__ (self , op_list : list [Optional [nn .Module ]]):
436- super ().__init__ ()
437- valid_op_list = []
438- for op in op_list :
439- if op is not None :
440- valid_op_list .append (op )
441- self .op_list = nn .ModuleList (valid_op_list )
442-
443- def forward (self , x : torch .Tensor ) -> torch .Tensor :
444- for op in self .op_list :
445- x = op (x )
446- return x
447-
448-
449434def build_block (
450435 block_type : str , in_channels : int , out_channels : int , norm : Optional [str ], act : Optional [str ]
451436) -> nn .Module :
@@ -557,21 +542,22 @@ def build_encoder_project_in_block(in_channels: int, out_channels: int, factor:
557542def build_encoder_project_out_block (
558543 in_channels : int , out_channels : int , norm : Optional [str ], act : Optional [str ], shortcut : Optional [str ]
559544):
560- block = OpSequential (
561- [
562- build_norm (norm ),
563- get_activation (act ) if act is not None else None ,
564- ConvLayer (
565- in_channels = in_channels ,
566- out_channels = out_channels ,
567- kernel_size = 3 ,
568- stride = 1 ,
569- use_bias = True ,
570- norm = None ,
571- act_func = None ,
572- ),
573- ]
574- )
545+ layers = []
546+ if norm is not None :
547+ layers .append (build_norm (norm ))
548+ if act is not None :
549+ layers .append (get_activation (act ))
550+ layers .append (ConvLayer (
551+ in_channels = in_channels ,
552+ out_channels = out_channels ,
553+ kernel_size = 3 ,
554+ stride = 1 ,
555+ use_bias = True ,
556+ norm = None ,
557+ act_func = None ,
558+ ))
559+ block = nn .Sequential (OrderedDict ([("op_list" , nn .Sequential (* layers ))]))
560+
575561 if shortcut is None :
576562 pass
577563 elif shortcut == "averaging" :
@@ -609,10 +595,12 @@ def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut
609595def build_decoder_project_out_block (
610596 in_channels : int , out_channels : int , factor : int , upsample_block_type : str , norm : Optional [str ], act : Optional [str ]
611597):
612- layers : list [nn .Module ] = [
613- build_norm (norm , in_channels ),
614- get_activation (act ) if act is not None else None ,
615- ]
598+ layers : list [nn .Module ] = []
599+ if norm is not None :
600+ layers .append (build_norm (norm , in_channels ))
601+ if act is not None :
602+ layers .append (get_activation (act ))
603+
616604 if factor == 1 :
617605 layers .append (
618606 ConvLayer (
@@ -633,7 +621,8 @@ def build_decoder_project_out_block(
633621 )
634622 else :
635623 raise ValueError (f"upsample factor { factor } is not supported for decoder project out" )
636- return OpSequential (layers )
624+ block = nn .Sequential (OrderedDict ([("op_list" , nn .Sequential (* layers ))]))
625+ return block
637626
638627
639628class Encoder (nn .Module ):
@@ -671,7 +660,7 @@ def __init__(
671660 downsample_block_type = downsample_block_type ,
672661 )
673662
674- self .stages : list [OpSequential ] = []
663+ self .stages : list [nn . Module ] = []
675664 for stage_id , (width , depth ) in enumerate (zip (width_list , depth_list )):
676665 stage_block_type = block_type [stage_id ] if isinstance (block_type , list ) else block_type
677666 stage = build_stage_main (
@@ -685,7 +674,7 @@ def __init__(
685674 shortcut = downsample_shortcut ,
686675 )
687676 stage .append (downsample_block )
688- self .stages .append (OpSequential ( stage ))
677+ self .stages .append (nn . Sequential ( OrderedDict ([( "op_list" , nn . Sequential ( * stage ))]) ))
689678 self .stages = nn .ModuleList (self .stages )
690679
691680 self .project_out = build_encoder_project_out_block (
@@ -743,7 +732,7 @@ def __init__(
743732 shortcut = in_shortcut ,
744733 )
745734
746- self .stages : list [OpSequential ] = []
735+ self .stages : list [nn . Module ] = []
747736 for stage_id , (width , depth ) in reversed (list (enumerate (zip (width_list , depth_list )))):
748737 stage = []
749738 if stage_id < num_stages - 1 and depth > 0 :
@@ -770,7 +759,7 @@ def __init__(
770759 ),
771760 )
772761 )
773- self .stages .insert (0 , OpSequential ( stage ))
762+ self .stages .insert (0 , nn . Sequential ( OrderedDict ([( "op_list" , nn . Sequential ( * stage ))]) ))
774763 self .stages = nn .ModuleList (self .stages )
775764
776765 self .project_out = build_decoder_project_out_block (
0 commit comments