@@ -41,8 +41,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-
-
 class ConvLayer(nn.Module):
     def __init__(
         self,
@@ -519,125 +517,6 @@ def build_upsample_block(block_type: str, in_channels: int, out_channels: int, s
     return block
 
 
-def build_encoder_project_in_block(in_channels: int, out_channels: int, factor: int, downsample_block_type: str):
-    if factor == 1:
-        block = nn.Conv2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            padding=1,
-        )
-    elif factor == 2:
-        block = build_downsample_block(
-            block_type=downsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None
-        )
-    else:
-        raise ValueError(f"downsample factor {factor} is not supported for encoder project in")
-    return block
-
-
-def build_encoder_project_out_block(
-    in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], shortcut: Optional[str]
-):
-    layers: list[nn.Module] = []
-
-    if norm is None:
-        pass
-    elif norm == "rms2d":
-        layers.append(RMSNorm2d(normalized_shape=in_channels))
-    elif norm == "bn2d":
-        layers.append(BatchNorm2d(num_features=in_channels))
-    else:
-        raise ValueError(f"norm {norm} is not supported")
-
-    if act is not None:
-        layers.append(get_activation(act))
-    layers.append(ConvLayer(
-        in_channels=in_channels,
-        out_channels=out_channels,
-        kernel_size=3,
-        stride=1,
-        use_bias=True,
-        norm=None,
-        act_func=None,
-    ))
-    block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*layers))]))
-
-    if shortcut is None:
-        pass
-    elif shortcut == "averaging":
-        shortcut_block = PixelUnshuffleChannelAveragingDownsample2D(
-            in_channels=in_channels, out_channels=out_channels, factor=1
-        )
-        block = ResidualBlock(block, shortcut_block)
-    else:
-        raise ValueError(f"shortcut {shortcut} is not supported for encoder project out")
-    return block
-
-
-def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut: Optional[str]):
-    block = ConvLayer(
-        in_channels=in_channels,
-        out_channels=out_channels,
-        kernel_size=3,
-        stride=1,
-        use_bias=True,
-        norm=None,
-        act_func=None,
-    )
-    if shortcut is None:
-        pass
-    elif shortcut == "duplicating":
-        shortcut_block = ChannelDuplicatingPixelUnshuffleUpsample2D(
-            in_channels=in_channels, out_channels=out_channels, factor=1
-        )
-        block = ResidualBlock(block, shortcut_block)
-    else:
-        raise ValueError(f"shortcut {shortcut} is not supported for decoder project in")
-    return block
-
-
-def build_decoder_project_out_block(
-    in_channels: int, out_channels: int, factor: int, upsample_block_type: str, norm: Optional[str], act: Optional[str]
-):
-    layers: list[nn.Module] = []
-
-    if norm is None:
-        pass
-    elif norm == "rms2d":
-        layers.append(RMSNorm2d(normalized_shape=in_channels))
-    elif norm == "bn2d":
-        layers.append(BatchNorm2d(num_features=in_channels))
-    else:
-        raise ValueError(f"norm {norm} is not supported")
-
-    if act is not None:
-        layers.append(get_activation(act))
-
-    if factor == 1:
-        layers.append(
-            ConvLayer(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=3,
-                stride=1,
-                use_bias=True,
-                norm=None,
-                act_func=None,
-            )
-        )
-    elif factor == 2:
-        layers.append(
-            build_upsample_block(
-                block_type=upsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None
-            )
-        )
-    else:
-        raise ValueError(f"upsample factor {factor} is not supported for decoder project out")
-    block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*layers))]))
-    return block
-
-
 class Encoder(nn.Module):
     def __init__(
         self,
@@ -665,14 +544,23 @@ def __init__(
             raise ValueError(f"len(depth_list) {len(depth_list)} and len(width_list) {len(width_list)} should be equal to num_stages {num_stages}")
         if not isinstance(block_type, (str, list)) or (isinstance(block_type, list) and len(block_type) != num_stages):
             raise ValueError(f"block_type should be either a str or a list of str with length {num_stages}, but got {block_type}")
+
+        # project in
+        if depth_list[0] > 0:
+            self.project_in = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=width_list[0],
+                kernel_size=3,
+                padding=1,
+            )
+        elif depth_list[1] > 0:
+            self.project_in = build_downsample_block(
+                block_type=downsample_block_type, in_channels=in_channels, out_channels=width_list[1], shortcut=None
+            )
+        else:
+            raise ValueError(f"depth list {depth_list} is not supported for encoder project in")
 
-        self.project_in = build_encoder_project_in_block(
-            in_channels=in_channels,
-            out_channels=width_list[0] if depth_list[0] > 0 else width_list[1],
-            factor=1 if depth_list[0] > 0 else 2,
-            downsample_block_type=downsample_block_type,
-        )
-
+        # stages
         self.stages: list[nn.Module] = []
         for stage_id, (width, depth) in enumerate(zip(width_list, depth_list)):
             stage_block_type = block_type[stage_id] if isinstance(block_type, list) else block_type
@@ -690,13 +578,39 @@ def __init__(
             self.stages.append(nn.Sequential(OrderedDict([("op_list", nn.Sequential(*stage))])))
         self.stages = nn.ModuleList(self.stages)
 
-        self.project_out = build_encoder_project_out_block(
+        # project out
+        project_out_layers: list[nn.Module] = []
+        if out_norm is None:
+            pass
+        elif out_norm == "rms2d":
+            project_out_layers.append(RMSNorm2d(normalized_shape=width_list[-1]))
+        elif out_norm == "bn2d":
+            project_out_layers.append(BatchNorm2d(num_features=width_list[-1]))
+        else:
+            raise ValueError(f"norm {out_norm} is not supported for encoder project out")
+        if out_act is not None:
+            project_out_layers.append(get_activation(out_act))
+        project_out_out_channels = 2 * latent_channels if double_latent else latent_channels
+        project_out_layers.append(ConvLayer(
             in_channels=width_list[-1],
-            out_channels=2 * latent_channels if double_latent else latent_channels,
-            norm=out_norm,
-            act=out_act,
-            shortcut=out_shortcut,
-        )
+            out_channels=project_out_out_channels,
+            kernel_size=3,
+            stride=1,
+            use_bias=True,
+            norm=None,
+            act_func=None,
+        ))
+        project_out_block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*project_out_layers))]))
+        if out_shortcut is None:
+            pass
+        elif out_shortcut == "averaging":
+            shortcut_block = PixelUnshuffleChannelAveragingDownsample2D(
+                in_channels=width_list[-1], out_channels=project_out_out_channels, factor=1
+            )
+            project_out_block = ResidualBlock(project_out_block, shortcut_block)
+        else:
+            raise ValueError(f"shortcut {out_shortcut} is not supported for encoder project out")
+        self.project_out = project_out_block
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.project_in(x)
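
Aside: the inlined encoder project-in above picks its branch solely from `depth_list`. A minimal, standalone Python sketch of that selection (the `depth_list`/`width_list` values are invented for illustration; only the branch logic is taken from the diff):

    depth_list = [0, 4, 8, 2]  # empty first stage -> the encoder downsamples on entry
    width_list = [128, 256, 512, 512]

    if depth_list[0] > 0:
        out_channels, factor = width_list[0], 1  # plain 3x3 conv, resolution preserved
    elif depth_list[1] > 0:
        out_channels, factor = width_list[1], 2  # stride-2 downsample block
    else:
        raise ValueError(f"depth list {depth_list} is not supported for encoder project in")

    print(out_channels, factor)  # -> 256 2

This matches the deleted call `build_encoder_project_in_block(..., factor=1 if depth_list[0] > 0 else 2)`, so the refactor should be behavior-preserving.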
@@ -739,12 +653,28 @@ def __init__(
         if not isinstance(act, (str, list)) or (isinstance(act, list) and len(act) != num_stages):
             raise ValueError(f"act should be either a str or a list of str with length {num_stages}, but got {act}")
 
-        self.project_in = build_decoder_project_in_block(
+        # project in
+        project_in_block = ConvLayer(
             in_channels=latent_channels,
             out_channels=width_list[-1],
-            shortcut=in_shortcut,
+            kernel_size=3,
+            stride=1,
+            use_bias=True,
+            norm=None,
+            act_func=None,
         )
+        if in_shortcut is None:
+            pass
+        elif in_shortcut == "duplicating":
+            shortcut_block = ChannelDuplicatingPixelUnshuffleUpsample2D(
+                in_channels=latent_channels, out_channels=width_list[-1], factor=1
+            )
+            project_in_block = ResidualBlock(project_in_block, shortcut_block)
+        else:
+            raise ValueError(f"shortcut {in_shortcut} is not supported for decoder project in")
+        self.project_in = project_in_block
 
+        # stages
         self.stages: list[nn.Module] = []
         for stage_id, (width, depth) in reversed(list(enumerate(zip(width_list, depth_list)))):
             stage = []
@@ -775,14 +705,44 @@ def __init__(
             self.stages.insert(0, nn.Sequential(OrderedDict([("op_list", nn.Sequential(*stage))])))
         self.stages = nn.ModuleList(self.stages)
 
-        self.project_out = build_decoder_project_out_block(
-            in_channels=width_list[0] if depth_list[0] > 0 else width_list[1],
-            out_channels=in_channels,
-            factor=1 if depth_list[0] > 0 else 2,
-            upsample_block_type=upsample_block_type,
-            norm=out_norm,
-            act=out_act,
-        )
+        # project out
+        project_out_layers: list[nn.Module] = []
+        if depth_list[0] > 0:
+            project_out_in_channels = width_list[0]
+        elif depth_list[1] > 0:
+            project_out_in_channels = width_list[1]
+        else:
+            raise ValueError(f"depth list {depth_list} is not supported for decoder project out")
+        if out_norm is None:
+            pass
+        elif out_norm == "rms2d":
+            project_out_layers.append(RMSNorm2d(normalized_shape=project_out_in_channels))
+        elif out_norm == "bn2d":
+            project_out_layers.append(BatchNorm2d(num_features=project_out_in_channels))
+        else:
+            raise ValueError(f"norm {out_norm} is not supported for decoder project out")
+        project_out_layers.append(get_activation(out_act))
+        if depth_list[0] > 0:
+            project_out_layers.append(
+                ConvLayer(
+                    in_channels=project_out_in_channels,
+                    out_channels=in_channels,
+                    kernel_size=3,
+                    stride=1,
+                    use_bias=True,
+                    norm=None,
+                    act_func=None,
+                )
+            )
+        elif depth_list[1] > 0:
+            project_out_layers.append(
+                build_upsample_block(
+                    block_type=upsample_block_type, in_channels=project_out_in_channels, out_channels=in_channels, shortcut=None
+                )
+            )
+        else:
+            raise ValueError(f"depth list {depth_list} is not supported for decoder project out")
+        self.project_out = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*project_out_layers))]))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.project_in(x)
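
One detail worth calling out: every project-in/out path still wraps its layers as `nn.Sequential(OrderedDict([("op_list", nn.Sequential(*layers))]))`, exactly as the deleted helpers did, presumably so checkpoint keys keep their `op_list` prefix. A self-contained sketch of the resulting state-dict layout, using only stock PyTorch modules and illustrative channel counts:

    from collections import OrderedDict

    import torch.nn as nn

    # Same wrapper pattern as the diff: a named inner Sequential called "op_list".
    layers = [nn.BatchNorm2d(8), nn.SiLU(), nn.Conv2d(8, 4, kernel_size=3, padding=1)]
    block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*layers))]))

    # Parameters are addressed as "op_list.<idx>.<param>"; assigned to
    # self.project_out, they become e.g. "project_out.op_list.2.weight".
    print(sorted(block.state_dict().keys()))
    # ['op_list.0.bias', 'op_list.0.num_batches_tracked', 'op_list.0.running_mean',
    #  'op_list.0.running_var', 'op_list.0.weight', 'op_list.2.bias', 'op_list.2.weight']

Since the refactor preserves both the wrapper name and the layer order, existing checkpoints should load without any key remapping.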