 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Callable, Tuple, Optional
+from typing import Any, Callable, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -34,7 +34,7 @@ def val2tuple(x: list | tuple | Any, min_len: int = 1) -> tuple:
     return tuple(x)
 
 
-def build_norm(name: Optional[str]="bn2d", num_features: Optional[int]=None) -> Optional[nn.Module]:
+def build_norm(name: Optional[str] = "bn2d", num_features: Optional[int] = None) -> Optional[nn.Module]:
     if name is None:
         norm = None
     elif name == "rms2d":
@@ -481,7 +481,7 @@ def build_stage_main(
 
         in_channels = width if d > 0 else input_width
         out_channels = width
-
+
         if current_block_type == "ResBlock":
             assert in_channels == out_channels
             block = ResBlock(
@@ -501,7 +501,7 @@ def build_stage_main(
             block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=(5,))
         else:
             raise ValueError(f"block_type {current_block_type} is not supported")
-
+
         stage.append(block)
     return stage
 
@@ -543,7 +543,7 @@ def __init__(
         shortcut: bool = True,
     ) -> None:
         super().__init__()
-
+
         self.downsample = downsample
         self.factor = 2
         self.stride = 1 if downsample else 2
@@ -552,21 +552,21 @@ def __init__(
         if downsample:
             assert out_channels % out_ratio == 0
             out_channels = out_channels // out_ratio
-
+
         self.conv = nn.Conv2d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=self.stride,
             padding=kernel_size // 2,
         )
-
+
         self.shortcut = None
         if shortcut:
             self.shortcut = DownsamplePixelUnshuffleChannelAveraging(
                 in_channels=in_channels, out_channels=out_channels, factor=2
             )
-
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         x = self.conv(hidden_states)
         if self.downsample:
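Aside: the shortcut module used above, DownsamplePixelUnshuffleChannelAveraging, is defined elsewhere in the file and not shown in this diff. Below is a minimal sketch of what its name and call signature (in_channels, out_channels, factor=2) suggest it computes; the function name and the channel-grouping strategy are assumptions for illustration, not the file's actual code.

import torch
import torch.nn.functional as F


def downsample_shortcut_sketch(hidden_states: torch.Tensor, out_channels: int, factor: int = 2) -> torch.Tensor:
    # Hypothetical reconstruction, not the library's implementation.
    # pixel_unshuffle trades space for channels: (B, C, H, W) -> (B, C * factor**2, H/factor, W/factor)
    x = F.pixel_unshuffle(hidden_states, factor)
    # then average groups of adjacent channels down to out_channels
    batch, channels, height, width = x.shape
    assert channels % out_channels == 0
    return x.reshape(batch, out_channels, channels // out_channels, height, width).mean(dim=2)


x = torch.randn(1, 64, 32, 32)
y = downsample_shortcut_sketch(x, out_channels=128, factor=2)  # 64 * 4 = 256 channels, averaged in pairs
assert y.shape == (1, 128, 16, 16)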
@@ -594,8 +594,8 @@ def __init__(
         self.interpolation_mode = interpolation_mode
         self.factor = 2
         self.stride = 1
-
-        out_ratio = self.factor**2
+
+        out_ratio = self.factor**2
         if not interpolate:
             out_channels = out_channels * out_ratio
 
@@ -612,20 +612,20 @@ def __init__(
             self.shortcut = UpsampleChannelDuplicatingPixelUnshuffle(
                 in_channels=in_channels, out_channels=out_channels, factor=2
             )
-
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.interpolate:
             x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
             x = self.conv(x)
         else:
             x = self.conv(hidden_states)
             x = F.pixel_shuffle(x, self.factor)
-
+
         if self.shortcut is not None:
             hidden_states = x + self.shortcut(hidden_states)
         else:
             hidden_states = x
-
+
         return hidden_states
 
 
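Aside: the else branch in the forward above is the standard sub-pixel upsampling pattern. When interpolate is False, __init__ widens the conv output by out_ratio = factor**2 so that F.pixel_shuffle can fold those channels back into a factor-times-larger feature map. A quick shape check grounded in that logic (the standalone conv here is illustrative, not the module's code):

import torch
import torch.nn as nn
import torch.nn.functional as F

factor, in_channels, out_channels = 2, 64, 32
# conv emits out_channels * factor**2 channels, mirroring the __init__ above
conv = nn.Conv2d(in_channels, out_channels * factor**2, kernel_size=3, padding=1)

x = torch.randn(1, in_channels, 8, 8)
y = F.pixel_shuffle(conv(x), factor)
assert y.shape == (1, out_channels, 16, 16)  # channels / factor**2, spatial * factor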
@@ -644,9 +644,7 @@ def __init__(
         self.num_stages = num_stages
         assert len(layers_per_block) == num_stages
         assert len(block_out_channels) == num_stages
-        assert isinstance(block_type, str) or (
-            isinstance(block_type, list) and len(block_type) == num_stages
-        )
+        assert isinstance(block_type, str) or (isinstance(block_type, list) and len(block_type) == num_stages)
 
         factor = 1 if layers_per_block[0] > 0 else 2
 
@@ -722,19 +720,11 @@ def __init__(
         self.num_stages = num_stages
         assert len(layers_per_block) == num_stages
         assert len(block_out_channels) == num_stages
-        assert isinstance(block_type, str) or (
-            isinstance(block_type, list) and len(block_type) == num_stages
-        )
+        assert isinstance(block_type, str) or (isinstance(block_type, list) and len(block_type) == num_stages)
         assert isinstance(norm, str) or (isinstance(norm, list) and len(norm) == num_stages)
         assert isinstance(act, str) or (isinstance(act, list) and len(act) == num_stages)
 
-        self.conv_in = nn.Conv2d(
-            latent_channels,
-            block_out_channels[-1],
-            kernel_size=3,
-            stride=1,
-            padding=1
-        )
+        self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
         self.norm_in = UpsampleChannelDuplicatingPixelUnshuffle(
             in_channels=latent_channels, out_channels=block_out_channels[-1], factor=1
         )
@@ -767,9 +757,15 @@ def __init__(
             stages.insert(0, nn.Sequential(*current_stage))
         self.stages = nn.ModuleList(stages)
 
-        factor = 1 if layers_per_block[0] > 0 else 2
+        factor = 1 if layers_per_block[0] > 0 else 2
 
-        self.norm_out = RMSNormNd(block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1], eps=1e-5, elementwise_affine=True, bias=True, channel_dim=1)
+        self.norm_out = RMSNormNd(
+            block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
+            eps=1e-5,
+            elementwise_affine=True,
+            bias=True,
+            channel_dim=1,
+        )
         self.conv_act = nn.ReLU()
         self.conv_out = None
 
@@ -884,7 +880,9 @@ def dc_ae_f32c32(name: str) -> dict:
     return cfg
 
 
-def dc_ae_f64c128(name: str,) -> dict:
+def dc_ae_f64c128(
+    name: str,
+) -> dict:
     if name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
         cfg = {
             "latent_channels": 128,
@@ -901,14 +899,34 @@ def dc_ae_f64c128(name: str,) -> dict:
     return cfg
 
 
-def dc_ae_f128c512(name: str,) -> dict:
+def dc_ae_f128c512(
+    name: str,
+) -> dict:
     if name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
         cfg = {
             "latent_channels": 512,
-            "encoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
+            "encoder_block_type": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+            ],
             "block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
             "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
-            "decoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
+            "decoder_block_type": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+            ],
             "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
             "decoder_norm": ["bn2d", "bn2d", "bn2d", "rms2d", "rms2d", "rms2d", "rms2d", "rms2d"],
             "decoder_act": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],