@@ -73,7 +73,14 @@ def __init__(
             groups=groups,
             bias=use_bias,
         )
-        self.norm = build_norm(norm, num_features=out_channels)
+        if norm is None:
+            self.norm = None
+        elif norm == "rms2d":
+            self.norm = RMSNorm2d(normalized_shape=out_channels)
+        elif norm == "bn2d":
+            self.norm = BatchNorm2d(num_features=out_channels)
+        else:
+            raise ValueError(f"norm {norm} is not supported")
         self.act = get_activation(act_func) if act_func is not None else None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -532,9 +539,17 @@ def build_encoder_project_in_block(in_channels: int, out_channels: int, factor:
 def build_encoder_project_out_block(
     in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], shortcut: Optional[str]
 ):
-    layers = []
-    if norm is not None:
-        layers.append(build_norm(norm))
+    layers: list[nn.Module] = []
+
+    if norm is None:
+        pass
+    elif norm == "rms2d":
+        layers.append(RMSNorm2d(normalized_shape=in_channels))
+    elif norm == "bn2d":
+        layers.append(BatchNorm2d(num_features=in_channels))
+    else:
+        raise ValueError(f"norm {norm} is not supported")
+
     if act is not None:
         layers.append(get_activation(act))
     layers.append(ConvLayer(
@@ -586,8 +601,16 @@ def build_decoder_project_out_block(
     in_channels: int, out_channels: int, factor: int, upsample_block_type: str, norm: Optional[str], act: Optional[str]
 ):
     layers: list[nn.Module] = []
-    if norm is not None:
-        layers.append(build_norm(norm, in_channels))
+
+    if norm is None:
+        pass
+    elif norm == "rms2d":
+        layers.append(RMSNorm2d(normalized_shape=in_channels))
+    elif norm == "bn2d":
+        layers.append(BatchNorm2d(num_features=in_channels))
+    else:
+        raise ValueError(f"norm {norm} is not supported")
+
     if act is not None:
         layers.append(get_activation(act))
 
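
All three hunks replace the generic build_norm() factory with the same explicit dispatch on the norm string ("rms2d", "bn2d", or None). Below is a minimal, runnable sketch of that dispatch pattern; the RMSNorm2d class here is only a placeholder so the snippet is self-contained (the real module lives elsewhere in the repo), and make_norm is a hypothetical helper, not code from this commit.

from typing import Optional

import torch
import torch.nn as nn


class RMSNorm2d(nn.Module):
    # Placeholder RMS norm over the channel dim of NCHW tensors; the real
    # RMSNorm2d ships with the repo, this stand-in only makes the sketch run.
    def __init__(self, normalized_shape: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each spatial position by the RMS over channels.
        rms = x.pow(2).mean(dim=1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight.view(1, -1, 1, 1)


def make_norm(norm: Optional[str], channels: int) -> Optional[nn.Module]:
    # Hypothetical helper mirroring the explicit dispatch the hunks introduce
    # in place of build_norm().
    if norm is None:
        return None
    if norm == "rms2d":
        return RMSNorm2d(normalized_shape=channels)
    if norm == "bn2d":
        return nn.BatchNorm2d(num_features=channels)
    raise ValueError(f"norm {norm} is not supported")


x = torch.randn(2, 8, 16, 16)
print(make_norm("rms2d", 8)(x).shape)  # torch.Size([2, 8, 16, 16])
print(make_norm("bn2d", 8)(x).shape)   # torch.Size([2, 8, 16, 16])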