3333from .vae import DecoderOutput
3434
3535
36- class RMSNorm2d (nn .LayerNorm ):
36+ class RMSNorm2d (nn .Module ):
37+ def __init__ (self , num_features : int , eps : float = 1e-5 , elementwise_affine : bool = True , bias : bool = True , device = None , dtype = None ) -> None :
38+ factory_kwargs = {'device' : device , 'dtype' : dtype }
39+ super ().__init__ ()
40+ self .num_features = num_features
41+ self .eps = eps
42+ self .elementwise_affine = elementwise_affine
43+ if self .elementwise_affine :
44+ self .weight = torch .nn .parameter .Parameter (torch .empty (self .num_features , ** factory_kwargs ))
45+ if bias :
46+ self .bias = torch .nn .parameter .Parameter (torch .empty (self .num_features , ** factory_kwargs ))
47+ else :
48+ self .register_parameter ('bias' , None )
49+ else :
50+ self .register_parameter ('weight' , None )
51+ self .register_parameter ('bias' , None )
52+
53+ self .reset_parameters ()
54+
def reset_parameters(self) -> None:
    """Re-initialize the affine parameters: `weight` to ones, `bias` to zeros.

    Does nothing when `elementwise_affine` is false; leaves `bias` alone when
    it is `None`.
    """
    if not self.elementwise_affine:
        return
    torch.nn.init.ones_(self.weight)
    if self.bias is None:
        return
    torch.nn.init.zeros_(self.bias)
60+
3761 def forward (self , x : torch .Tensor ) -> torch .Tensor :
3862 x = (x / torch .sqrt (torch .square (x .float ()).mean (dim = 1 , keepdim = True ) + self .eps )).to (x .dtype )
3963 if self .elementwise_affine :
@@ -74,7 +98,7 @@ def __init__(
7498 if norm is None :
7599 self .norm = None
76100 elif norm == "rms2d" :
77- self .norm = RMSNorm2d (normalized_shape = out_channels )
101+ self .norm = RMSNorm2d (num_features = out_channels )
78102 elif norm == "bn2d" :
79103 self .norm = BatchNorm2d (num_features = out_channels )
80104 else :
@@ -469,54 +493,6 @@ def build_stage_main(
469493 return stage
470494
471495
def build_downsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module:
    """Construct a downsampling module, optionally wrapped with a shortcut branch.

    Args:
        block_type: `"Conv"` selects a 3x3, stride-2, padding-1 `nn.Conv2d`;
            `"ConvPixelUnshuffle"` selects `ConvPixelUnshuffleDownsample2D`
            with `kernel_size=3` and `factor=2`.
        in_channels: Input channel count passed to the chosen block.
        out_channels: Output channel count passed to the chosen block.
        shortcut: `None` returns the main block as is; `"averaging"` pairs it
            with a `PixelUnshuffleChannelAveragingDownsample2D` inside a
            `ResidualBlock`.

    Returns:
        The constructed module.

    Raises:
        ValueError: If `block_type` or `shortcut` is not one of the values above.
    """
    if block_type == "Conv":
        main_branch = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )
    elif block_type == "ConvPixelUnshuffle":
        main_branch = ConvPixelUnshuffleDownsample2D(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
        )
    else:
        raise ValueError(f"block_type {block_type} is not supported for downsampling")

    # The main branch is built before the shortcut is validated, as before.
    if shortcut is None:
        return main_branch
    if shortcut != "averaging":
        raise ValueError(f"shortcut {shortcut} is not supported for downsample")

    skip_branch = PixelUnshuffleChannelAveragingDownsample2D(
        in_channels=in_channels, out_channels=out_channels, factor=2
    )
    return ResidualBlock(main_branch, skip_branch)
497-
498-
def build_upsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module:
    """Construct an upsampling module, optionally wrapped with a shortcut branch.

    Args:
        block_type: `"ConvPixelShuffle"` selects `ConvPixelShuffleUpsample2D`
            with `kernel_size=3` and `factor=2`; `"InterpolateConv"` selects
            `Upsample2D` with `use_conv=True`.
        in_channels: Input channel count passed to the chosen block.
        out_channels: Output channel count passed to the chosen block.
        shortcut: `None` returns the main block as is; `"duplicating"` pairs it
            with a `ChannelDuplicatingPixelUnshuffleUpsample2D` inside a
            `ResidualBlock`.

    Returns:
        The constructed module.

    Raises:
        ValueError: If `block_type` or `shortcut` is not one of the values above.
    """
    if block_type == "ConvPixelShuffle":
        main_branch = ConvPixelShuffleUpsample2D(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
        )
    elif block_type == "InterpolateConv":
        main_branch = Upsample2D(channels=in_channels, use_conv=True, out_channels=out_channels)
    else:
        raise ValueError(f"block_type {block_type} is not supported for upsampling")

    # The main branch is built before the shortcut is validated, as before.
    if shortcut is None:
        return main_branch
    if shortcut != "duplicating":
        raise ValueError(f"shortcut {shortcut} is not supported for upsample")

    skip_branch = ChannelDuplicatingPixelUnshuffleUpsample2D(
        in_channels=in_channels, out_channels=out_channels, factor=2
    )
    return ResidualBlock(main_branch, skip_branch)
518-
519-
520496class Encoder (nn .Module ):
521497 def __init__ (
522498 self ,
@@ -547,18 +523,30 @@ def __init__(
547523
548524 # project in
549525 if depth_list [0 ] > 0 :
550- self . project_in = nn .Conv2d (
526+ project_in_block = nn .Conv2d (
551527 in_channels = in_channels ,
552528 out_channels = width_list [0 ],
553529 kernel_size = 3 ,
554530 padding = 1 ,
555531 )
556532 elif depth_list [1 ] > 0 :
557- self .project_in = build_downsample_block (
558- block_type = downsample_block_type , in_channels = in_channels , out_channels = width_list [1 ], shortcut = None
559- )
533+ if downsample_block_type == "Conv" :
534+ project_in_block = nn .Conv2d (
535+ in_channels = in_channels ,
536+ out_channels = width_list [1 ],
537+ kernel_size = 3 ,
538+ stride = 2 ,
539+ padding = 1 ,
540+ )
541+ elif downsample_block_type == "ConvPixelUnshuffle" :
542+ project_in_block = ConvPixelUnshuffleDownsample2D (
543+ in_channels = in_channels , out_channels = width_list [1 ], kernel_size = 3 , factor = 2
544+ )
545+ else :
546+ raise ValueError (f"block_type { downsample_block_type } is not supported for downsampling" )
560547 else :
561548 raise ValueError (f"depth list { depth_list } is not supported for encoder project in" )
549+ self .project_in = project_in_block
562550
563551 # stages
564552 self .stages : list [nn .Module ] = []
@@ -568,12 +556,30 @@ def __init__(
568556 width = width , depth = depth , block_type = stage_block_type , norm = norm , act = act , input_width = width
569557 )
570558 if stage_id < num_stages - 1 and depth > 0 :
571- downsample_block = build_downsample_block (
572- block_type = downsample_block_type ,
573- in_channels = width ,
574- out_channels = width_list [stage_id + 1 ] if downsample_match_channel else width ,
575- shortcut = downsample_shortcut ,
576- )
559+ downsample_out_channels = width_list [stage_id + 1 ] if downsample_match_channel else width
560+ if downsample_block_type == "Conv" :
561+ downsample_block = nn .Conv2d (
562+ in_channels = width ,
563+ out_channels = downsample_out_channels ,
564+ kernel_size = 3 ,
565+ stride = 2 ,
566+ padding = 1 ,
567+ )
568+ elif downsample_block_type == "ConvPixelUnshuffle" :
569+ downsample_block = ConvPixelUnshuffleDownsample2D (
570+ in_channels = width , out_channels = downsample_out_channels , kernel_size = 3 , factor = 2
571+ )
572+ else :
573+ raise ValueError (f"downsample_block_type { downsample_block_type } is not supported for downsampling" )
574+ if downsample_shortcut is None :
575+ pass
576+ elif downsample_shortcut == "averaging" :
577+ shortcut_block = PixelUnshuffleChannelAveragingDownsample2D (
578+ in_channels = width , out_channels = downsample_out_channels , factor = 2
579+ )
580+ downsample_block = ResidualBlock (downsample_block , shortcut_block )
581+ else :
582+ raise ValueError (f"shortcut { downsample_shortcut } is not supported for downsample" )
577583 stage .append (downsample_block )
578584 self .stages .append (nn .Sequential (OrderedDict ([("op_list" , nn .Sequential (* stage ))])))
579585 self .stages = nn .ModuleList (self .stages )
@@ -583,7 +589,7 @@ def __init__(
583589 if out_norm is None :
584590 pass
585591 elif out_norm == "rms2d" :
586- project_out_layers .append (RMSNorm2d (normalized_shape = width_list [- 1 ]))
592+ project_out_layers .append (RMSNorm2d (num_features = width_list [- 1 ]))
587593 elif out_norm == "bn2d" :
588594 project_out_layers .append (BatchNorm2d (num_features = width_list [- 1 ]))
589595 else :
@@ -679,12 +685,24 @@ def __init__(
679685 for stage_id , (width , depth ) in reversed (list (enumerate (zip (width_list , depth_list )))):
680686 stage = []
681687 if stage_id < num_stages - 1 and depth > 0 :
682- upsample_block = build_upsample_block (
683- block_type = upsample_block_type ,
684- in_channels = width_list [stage_id + 1 ],
685- out_channels = width if upsample_match_channel else width_list [stage_id + 1 ],
686- shortcut = upsample_shortcut ,
687- )
688+ upsample_out_channels = width if upsample_match_channel else width_list [stage_id + 1 ]
689+ if upsample_block_type == "ConvPixelShuffle" :
690+ upsample_block = ConvPixelShuffleUpsample2D (
691+ in_channels = width_list [stage_id + 1 ], out_channels = upsample_out_channels , kernel_size = 3 , factor = 2
692+ )
693+ elif upsample_block_type == "InterpolateConv" :
694+ upsample_block = Upsample2D (channels = width_list [stage_id + 1 ], use_conv = True , out_channels = upsample_out_channels )
695+ else :
696+ raise ValueError (f"upsample_block_type { upsample_block_type } is not supported" )
697+ if upsample_shortcut is None :
698+ pass
699+ elif upsample_shortcut == "duplicating" :
700+ shortcut_block = ChannelDuplicatingPixelUnshuffleUpsample2D (
701+ in_channels = width_list [stage_id + 1 ], out_channels = upsample_out_channels , factor = 2
702+ )
703+ upsample_block = ResidualBlock (upsample_block , shortcut_block )
704+ else :
705+ raise ValueError (f"shortcut { upsample_shortcut } is not supported for upsample" )
688706 stage .append (upsample_block )
689707
690708 stage_block_type = block_type [stage_id ] if isinstance (block_type , list ) else block_type
@@ -716,7 +734,7 @@ def __init__(
716734 if out_norm is None :
717735 pass
718736 elif out_norm == "rms2d" :
719- project_out_layers .append (RMSNorm2d (normalized_shape = project_out_in_channels ))
737+ project_out_layers .append (RMSNorm2d (num_features = project_out_in_channels ))
720738 elif out_norm == "bn2d" :
721739 project_out_layers .append (BatchNorm2d (num_features = project_out_in_channels ))
722740 else :
@@ -735,11 +753,16 @@ def __init__(
735753 )
736754 )
737755 elif depth_list [1 ] > 0 :
738- project_out_layers . append (
739- build_upsample_block (
740- block_type = upsample_block_type , in_channels = project_out_in_channels , out_channels = in_channels , shortcut = None
756+ if upsample_block_type == "ConvPixelShuffle" :
757+ project_out_conv = ConvPixelShuffleUpsample2D (
758+ in_channels = project_out_in_channels , out_channels = in_channels , kernel_size = 3 , factor = 2
741759 )
742- )
760+ elif upsample_block_type == "InterpolateConv" :
761+ project_out_conv = Upsample2D (channels = project_out_in_channels , use_conv = True , out_channels = in_channels )
762+ else :
763+ raise ValueError (f"upsample_block_type { upsample_block_type } is not supported for upsampling" )
764+
765+ project_out_layers .append (project_out_conv )
743766 else :
744767 raise ValueError (f"depth list { depth_list } is not supported for decoder project out" )
745768 self .project_out = nn .Sequential (OrderedDict ([("op_list" , nn .Sequential (* project_out_layers ))]))
0 commit comments