1313# See the License for the specific language governing permissions and 
1414# limitations under the License. 
1515
16- from  typing  import  Any , Optional , Callable , Union 
17- from  collections  import  OrderedDict 
16+ from  typing  import  Optional , Union 
1817
1918import  torch 
2019import  torch .nn  as  nn 
21- import  torch .nn .functional  as  F 
2220from  torch .nn  import  BatchNorm2d 
23- from  huggingface_hub  import  PyTorchModelHubMixin 
24- import  ipdb 
2521
2622from  ...configuration_utils  import  ConfigMixin , register_to_config 
27- from  ..modeling_utils  import  ModelMixin 
28- 
2923from  ..activations  import  get_activation 
30- from  ..normalization  import  RMSNorm2d 
31- from  ..downsampling  import  ConvPixelUnshuffleDownsample2D , PixelUnshuffleChannelAveragingDownsample2D 
32- from  ..upsampling  import  ConvPixelShuffleUpsample2D , ChannelDuplicatingPixelUnshuffleUpsample2D , Upsample2D 
3324from  ..attention  import  DCAELiteMLA 
34- 
25+ from  ..downsampling  import  ConvPixelUnshuffleDownsample2D , PixelUnshuffleChannelAveragingDownsample2D 
26+ from  ..modeling_utils  import  ModelMixin 
27+ from  ..normalization  import  RMSNorm2d 
28+ from  ..upsampling  import  ChannelDuplicatingPixelUnshuffleUpsample2D , ConvPixelShuffleUpsample2D , Upsample2D 
3529from  .vae  import  DecoderOutput 
3630
3731
@@ -267,7 +261,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
267261
268262class  Encoder (nn .Module ):
269263    def  __init__ (
270-         self ,  
264+         self ,
271265        in_channels : int ,
272266        latent_channels : int ,
273267        width_list : list [int ] =  [128 , 256 , 512 , 512 , 1024 , 1024 ],
@@ -291,7 +285,7 @@ def __init__(
291285            raise  ValueError (f"len(depth_list) { len (depth_list )} { len (width_list )} { num_stages }  )
292286        if  not  isinstance (block_type , (str , list )) or  (isinstance (block_type , list ) and  len (block_type ) !=  num_stages ):
293287            raise  ValueError (f"block_type should be either a str or a list of str with length { num_stages } { block_type }  )
294-          
288+ 
295289        # project in 
296290        if  depth_list [0 ] >  0 :
297291            project_in_block  =  nn .Conv2d (
@@ -422,7 +416,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
422416
423417class  Decoder (nn .Module ):
424418    def  __init__ (
425-         self ,  
419+         self ,
426420        in_channels : int ,
427421        latent_channels : int ,
428422        in_shortcut : Optional [str ] =  "duplicating" ,
0 commit comments