 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint

+# YiYi TODO: remove this
+from einops import rearrange
+
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
 from ...utils import logging
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from .vae import DecoderOutput, DiagonalGaussianDistribution
-import numpy as np
-
-#YiYi TODO: remove this
-from einops import rearrange


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -50,7 +50,7 @@ def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "si
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.nonlinearity = get_activation(non_linearity)  # YiYi Notes, they have a custom defined swish but should be the same
+        self.nonlinearity = get_activation(non_linearity)

         # layers
         self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@@ -109,9 +109,9 @@ def forward(self, x):
         value = self.to_v(x)

         batch_size, channels, height, width = query.shape
-        query = query.permute(0, 2, 3, 1).reshape(batch_size, height*width, channels).contiguous()
-        key = key.permute(0, 2, 3, 1).reshape(batch_size, height*width, channels).contiguous()
-        value = value.permute(0, 2, 3, 1).reshape(batch_size, height*width, channels).contiguous()
+        query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+        key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+        value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()

         # apply attention
         x = F.scaled_dot_product_attention(query, key, value)
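
For context on the hunk above: the attention block flattens the spatial grid into a token axis so that `F.scaled_dot_product_attention` can treat every pixel position as one token. A minimal standalone sketch of the same reshape pattern (shapes and tensor names are illustrative, not taken from the model):

import torch
import torch.nn.functional as F

# Toy inputs: batch of 2 feature maps with 64 channels on a 32x32 grid.
query = torch.randn(2, 64, 32, 32)
key = torch.randn(2, 64, 32, 32)
value = torch.randn(2, 64, 32, 32)

batch_size, channels, height, width = query.shape
# (B, C, H, W) -> (B, H*W, C): each spatial position becomes one token.
q = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
k = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
v = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

out = F.scaled_dot_product_attention(q, k, v)  # (B, H*W, C)
# Restore the spatial layout for the residual connection that follows.
out = out.permute(0, 2, 1).reshape(batch_size, channels, height, width)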
@@ -182,12 +182,11 @@ class HunyuanImageMidBlock(nn.Module):
         in_channels (int): Number of input channels.
         num_layers (int): Number of layers.
     """
+
     def __init__(self, in_channels: int, num_layers: int = 1):
         super().__init__()

-        resnets = [
-            HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)
-        ]
+        resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]

         attentions = []
         for _ in range(num_layers):
@@ -198,7 +197,6 @@ def __init__(self, in_channels: int, num_layers: int = 1):
         self.attentions = nn.ModuleList(attentions)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
         x = self.resnets[0](x)

         for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -234,8 +232,10 @@ def __init__(
     ):
         super().__init__()
         if block_out_channels[-1] % (2 * z_channels) != 0:
-            raise ValueError(f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {out_channels}")
-
+            raise ValueError(
+                f"block_out_channels[-1] has to be divisible by 2 * z_channels, you have block_out_channels = {block_out_channels[-1]} and z_channels = {z_channels}"
+            )
+
         self.in_channels = in_channels
         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
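
The check above guards the relationship between the final feature width and the latent head: a VAE encoder emits 2 * z_channels moment channels (mean and log-variance), so block_out_channels[-1] must divide evenly onto them. A quick numeric illustration (the values are made up):

# With z_channels = 32 the encoder head produces 2 * 32 = 64 moment channels.
z_channels = 32
block_out_channels = (128, 256, 512, 512)
assert block_out_channels[-1] % (2 * z_channels) == 0  # 512 % 64 == 0, so this config passes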
@@ -256,14 +256,18 @@ def __init__(
             block_out_channel = block_out_channels[i]
             # residual blocks
             for _ in range(num_res_blocks):
-                self.down_blocks.append(HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel))
+                self.down_blocks.append(
+                    HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
+                )
                 block_in_channel = block_out_channel

             # downsample block
             if i < np.log2(ffactor_spatial) and i != len(block_out_channels) - 1:
                 if downsample_match_channel:
                     block_out_channel = block_out_channels[i + 1]
-                self.down_blocks.append(HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel))
+                self.down_blocks.append(
+                    HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
+                )
                 block_in_channel = block_out_channel

         # middle blocks
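
The `i < np.log2(ffactor_spatial)` condition above decides which stages downsample: a total spatial factor of 8 needs log2(8) = 3 halving stages. A small sketch of the resulting schedule (the config values are illustrative, not the model's defaults):

import numpy as np

block_out_channels = (128, 256, 512, 512)  # illustrative
ffactor_spatial = 8                        # overall spatial downsampling factor

for i in range(len(block_out_channels)):
    # Downsample while 2**i is still below the target factor, except at the last stage.
    has_downsample = i < np.log2(ffactor_spatial) and i != len(block_out_channels) - 1
    print(f"stage {i}: downsample={has_downsample}")
# Stages 0-2 halve the resolution (2 * 2 * 2 = 8); stage 3 keeps it.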
@@ -305,7 +309,6 @@ class HunyuanImageDecoder2D(nn.Module):
     Decoder network that reconstructs output from latent representation.

     Args:
-
     z_channels : int
         Number of latent channels.
     out_channels : int
@@ -333,7 +336,9 @@ def __init__(
     ):
         super().__init__()
         if block_out_channels[0] % z_channels != 0:
-            raise ValueError(f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}")
+            raise ValueError(
+                f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
+            )

         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
@@ -353,7 +358,9 @@ def __init__(
         for i in range(len(block_out_channels)):
             block_out_channel = block_out_channels[i]
             for _ in range(self.num_res_blocks + 1):
-                self.up_blocks.append(HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel))
+                self.up_blocks.append(
+                    HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
+                )
                 block_in_channel = block_out_channel

             if i < np.log2(ffactor_spatial) and i != len(block_out_channels) - 1:
@@ -369,9 +376,8 @@ def __init__(
         self.gradient_checkpointing = False

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
         h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
-
+
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             h = self._gradient_checkpointing_func(self.mid_block, h)
         else:
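
The `repeat_interleave` term in the forward pass above gives `conv_in` a residual path even though it widens the channel count: the latent's channels are tiled until they match the convolution's output. A toy sketch of the idea (the sizes are invented; `repeat` mirrors the `self.repeat` attribute, which the earlier divisibility check makes an exact integer):

import torch
import torch.nn as nn

z_channels, out_ch = 16, 64
repeat = out_ch // z_channels  # 4 copies of each latent channel

conv_in = nn.Conv2d(z_channels, out_ch, kernel_size=3, padding=1)
x = torch.randn(1, z_channels, 8, 8)

# (1, 16, 8, 8) -> (1, 64, 8, 8): channel i is repeated 4 times in place.
shortcut = x.repeat_interleave(repeats=repeat, dim=1)
h = conv_in(x) + shortcut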
@@ -388,7 +394,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return h


-
 class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model for 2D images with spatial tiling support.
@@ -425,7 +430,7 @@ def __init__(
             ffactor_spatial=ffactor_spatial,
             downsample_match_channel=downsample_match_channel,
         )
-
+
         self.decoder = HunyuanImageDecoder2D(
             z_channels=latent_channels,
             out_channels=out_channels,
@@ -450,9 +455,9 @@ def enable_tiling(
         tile_overlap_factor: Optional[float] = None,
     ) -> None:
         r"""
-        Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
+        Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
+        to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.

         Args:
             tile_sample_min_size (`int`, *optional*):
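
A usage sketch for the tiling switch documented above, assuming the usual diffusers loading flow (the checkpoint path is a placeholder, and `latent_channels` is assumed to be registered in the model config):

import torch
from diffusers import AutoencoderKLHunyuanImage

vae = AutoencoderKLHunyuanImage.from_pretrained("path/to/vae")  # placeholder path
vae.enable_tiling()  # large inputs are now encoded/decoded tile by tile

z = torch.randn(1, vae.config.latent_channels, 128, 128)
with torch.no_grad():
    image = vae.decode(z).sample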
@@ -528,7 +533,7 @@ def encode(
     def _decode(self, z: torch.Tensor, return_dict: bool = True):

         batch_size, num_channels, height, width = z.shape
-
+
         if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
             return self.tiled_decode(z, return_dict=return_dict)

@@ -587,7 +592,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:

         Args:
             x (`torch.Tensor`): Input tensor of shape (B, C, H, W).
-
+
         Returns:
             `torch.Tensor`:
                 The latent representation of the encoded images.
@@ -618,7 +623,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
             result_rows.append(torch.cat(result_row, dim=-1))

         moments = torch.cat(result_rows, dim=-2)
-
+
         return moments

     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
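
The row/column concatenation at the end of `tiled_encode` above follows the usual tiled-VAE stitching scheme: neighbouring tiles are blended across their overlap, then each row is joined on the width axis and the rows are stacked on the height axis. A simplified sketch of the horizontal blend step, with a helper in the style of diffusers' usual blend functions (the helper and sizes are illustrative, not this class's actual methods):

import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
    # Linearly cross-fade the last `overlap` columns of `a` into the first of `b`.
    weight = torch.linspace(0, 1, overlap, device=a.device)
    b[..., :overlap] = a[..., -overlap:] * (1 - weight) + b[..., :overlap] * weight
    return b

# Two horizontally adjacent tiles with an 8-pixel overlap (toy sizes).
left = torch.randn(1, 4, 32, 32)
right = torch.randn(1, 4, 32, 32)
right = blend_h(left, right, overlap=8)
row = torch.cat([left[..., :-8], right], dim=-1)  # joined along the width axis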