Commit bb8f753

style

1 parent feb29c3 commit bb8f753

2 files changed: +71 -57 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py

Lines changed: 35 additions & 30 deletions
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 
+# YiYi TODO: remove this
+from einops import rearrange
+
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
 from ...utils import logging
@@ -27,10 +31,6 @@
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from .vae import DecoderOutput, DiagonalGaussianDistribution
-import numpy as np
-
-#YiYi TODO: remove this
-from einops import rearrange
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -50,7 +50,7 @@ def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "si
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.nonlinearity = get_activation(non_linearity)  # YiYi Notes, they have a custom defined swish but should be the same
+        self.nonlinearity = get_activation(non_linearity)
 
         # layers
         self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@@ -109,9 +109,9 @@ def forward(self, x):
         value = self.to_v(x)
 
         batch_size, channels, height, width = query.shape
-        query = query.permute(0, 2, 3, 1).reshape(batch_size, height*width, channels).contiguous()
-        key = key.permute(0, 2, 3, 1).reshape(batch_size, height*width, channels).contiguous()
-        value = value.permute(0, 2, 3, 1).reshape(batch_size, height*width, channels).contiguous()
+        query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+        key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+        value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
 
         # apply attention
         x = F.scaled_dot_product_attention(query, key, value)
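This hunk only adds spaces around `*`, but the surrounding pattern is the usual one for attention over 2D feature maps: each spatial position becomes a token before `F.scaled_dot_product_attention`. A minimal standalone sketch of that reshape round-trip (tensor names and sizes are illustrative, not taken from the model):

```python
import torch
import torch.nn.functional as F

# Toy single-head attention over a (B, C, H, W) feature map.
batch_size, channels, height, width = 2, 64, 16, 16
query = torch.randn(batch_size, channels, height, width)
key = torch.randn(batch_size, channels, height, width)
value = torch.randn(batch_size, channels, height, width)

# (B, C, H, W) -> (B, H*W, C): one token per spatial position.
query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()

out = F.scaled_dot_product_attention(query, key, value)  # (B, H*W, C)

# Back to a feature map: (B, H*W, C) -> (B, C, H, W).
out = out.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
print(out.shape)  # torch.Size([2, 64, 16, 16])
```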
@@ -182,12 +182,11 @@ class HunyuanImageMidBlock(nn.Module):
         in_channels (int): Number of input channels.
         num_layers (int): Number of layers.
     """
+
     def __init__(self, in_channels: int, num_layers: int = 1):
         super().__init__()
 
-        resnets = [
-            HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)
-        ]
+        resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
 
         attentions = []
         for _ in range(num_layers):
@@ -198,7 +197,6 @@ def __init__(self, in_channels: int, num_layers: int = 1):
         self.attentions = nn.ModuleList(attentions)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
         x = self.resnets[0](x)
 
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -234,8 +232,10 @@ def __init__(
     ):
         super().__init__()
         if block_out_channels[-1] % (2 * z_channels) != 0:
-            raise ValueError(f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {out_channels}")
-
+            raise ValueError(
+                f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
+            )
+
         self.in_channels = in_channels
         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
@@ -256,14 +256,18 @@ def __init__(
             block_out_channel = block_out_channels[i]
             # residual blocks
             for _ in range(num_res_blocks):
-                self.down_blocks.append(HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel))
+                self.down_blocks.append(
+                    HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
+                )
                 block_in_channel = block_out_channel
 
             # downsample block
             if i < np.log2(ffactor_spatial) and i != len(block_out_channels) - 1:
                 if downsample_match_channel:
                     block_out_channel = block_out_channels[i + 1]
-                self.down_blocks.append(HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel))
+                self.down_blocks.append(
+                    HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
+                )
                 block_in_channel = block_out_channel
 
         # middle blocks
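In this hunk the number of downsampling stages is gated by `i < np.log2(ffactor_spatial)`: a total spatial reduction of `ffactor_spatial` needs `log2(ffactor_spatial)` stride-2 downsamples. A small sketch of that bookkeeping, with an illustrative config (not the model's actual defaults):

```python
import numpy as np

block_out_channels = (128, 256, 512, 512)  # illustrative widths
ffactor_spatial = 8                        # desired total spatial reduction factor

num_downsamples = 0
for i in range(len(block_out_channels)):
    # same condition as the hunk: keep downsampling while more reduction is needed,
    # but never after the last block
    if i < np.log2(ffactor_spatial) and i != len(block_out_channels) - 1:
        num_downsamples += 1

print(num_downsamples, 2**num_downsamples)  # 3 stages -> 8x spatial reduction
```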
@@ -305,7 +309,6 @@ class HunyuanImageDecoder2D(nn.Module):
     Decoder network that reconstructs output from latent representation.
 
     Args:
-
         z_channels : int
             Number of latent channels.
         out_channels : int
@@ -333,7 +336,9 @@ def __init__(
     ):
         super().__init__()
         if block_out_channels[0] % z_channels != 0:
-            raise ValueError(f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}")
+            raise ValueError(
+                f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
+            )
 
         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
@@ -353,7 +358,9 @@ def __init__(
         for i in range(len(block_out_channels)):
             block_out_channel = block_out_channels[i]
             for _ in range(self.num_res_blocks + 1):
-                self.up_blocks.append(HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel))
+                self.up_blocks.append(
+                    HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
+                )
                 block_in_channel = block_out_channel
 
             if i < np.log2(ffactor_spatial) and i != len(block_out_channels) - 1:
@@ -369,9 +376,8 @@ def __init__(
         self.gradient_checkpointing = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
         h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
-
+
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             h = self._gradient_checkpointing_func(self.mid_block, h)
         else:
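The decoder's `forward` starts with `self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)`, i.e. the latent is widened by channel repetition so it can serve as a residual for the first convolution. A toy shape sketch; the sizes are made up, and `repeat = block_out_channels[0] // z_channels` is an assumption suggested by the divisibility check above, not quoted from the diff:

```python
import torch
import torch.nn as nn

z_channels, width0 = 16, 128            # illustrative; width0 % z_channels == 0
repeat = width0 // z_channels           # 8

conv_in = nn.Conv2d(z_channels, width0, kernel_size=3, padding=1)
z = torch.randn(1, z_channels, 32, 32)

# Conv projection plus a channel-repeated skip connection of the latent itself.
h = conv_in(z) + z.repeat_interleave(repeats=repeat, dim=1)
print(h.shape)  # torch.Size([1, 128, 32, 32])
```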
@@ -388,7 +394,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return h
 
 
-
 class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model for 2D images with spatial tiling support.
@@ -425,7 +430,7 @@ def __init__(
             ffactor_spatial=ffactor_spatial,
             downsample_match_channel=downsample_match_channel,
         )
-
+
         self.decoder = HunyuanImageDecoder2D(
             z_channels=latent_channels,
             out_channels=out_channels,
@@ -450,9 +455,9 @@ def enable_tiling(
         tile_overlap_factor: Optional[float] = None,
     ) -> None:
         r"""
-        Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
+        Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
+        to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.
 
         Args:
             tile_sample_min_size (`int`, *optional*):
@@ -528,7 +533,7 @@ def encode(
     def _decode(self, z: torch.Tensor, return_dict: bool = True):
 
         batch_size, num_channels, height, width = z.shape
-
+
         if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
             return self.tiled_decode(z, return_dict=return_dict)
 
@@ -587,7 +592,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
 
         Args:
             x (`torch.Tensor`): Input tensor of shape (B, C, T, H, W).
-
+
         Returns:
             `torch.Tensor`:
                 The latent representation of the encoded images.
@@ -618,7 +623,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
             result_rows.append(torch.cat(result_row, dim=-1))
 
         moments = torch.cat(result_rows, dim=-2)
-
+
         return moments
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
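The tail of `tiled_encode` visible above stitches per-tile results back together: tiles in a row are concatenated along width (`dim=-1`) and the rows along height (`dim=-2`). A stripped-down sketch of that stitching, without the per-tile encoding and overlap blending the real method performs (tile size and the identity "encode" are simplifying assumptions):

```python
import torch

x = torch.randn(1, 3, 8, 8)
tile = 4

result_rows = []
for i in range(0, x.shape[-2], tile):
    result_row = []
    for j in range(0, x.shape[-1], tile):
        # a real implementation would encode the tile here; we just crop it
        result_row.append(x[..., i : i + tile, j : j + tile])
    result_rows.append(torch.cat(result_row, dim=-1))  # stitch along width

moments = torch.cat(result_rows, dim=-2)               # stitch along height
assert torch.equal(moments, x)
```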

src/diffusers/models/transformers/transformer_hunyuanimage.py

Lines changed: 36 additions & 27 deletions
@@ -28,14 +28,13 @@
 from ..cache_utils import CacheMixin
 from ..embeddings import (
     CombinedTimestepTextProjEmbeddings,
-    PixArtAlphaTextProjection,
     TimestepEmbedding,
     Timesteps,
     get_1d_rotary_pos_embed,
 )
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -64,7 +63,7 @@ def __call__(
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) # batch_size, heads, seq_len, head_dim
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)  # batch_size, heads, seq_len, head_dim
         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

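Only the comment spacing changes here, but the line itself is the standard head-splitting step: `unflatten(2, (heads, -1))` turns `(B, L, heads * head_dim)` into `(B, L, heads, head_dim)`, and `transpose(1, 2)` moves the head axis in front of the sequence. A quick shape check with illustrative dimensions:

```python
import torch

batch_size, seq_len, heads, head_dim = 2, 77, 8, 64
hidden_states = torch.randn(batch_size, seq_len, heads * head_dim)

query = hidden_states.unflatten(2, (heads, -1)).transpose(1, 2)
print(query.shape)  # torch.Size([2, 8, 77, 64]) -> batch_size, heads, seq_len, head_dim
```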
@@ -228,8 +227,8 @@ def forward(
 class HunyuanImageIndividualTokenRefinerBlock(nn.Module):
     def __init__(
         self,
-        num_attention_heads: int, # 28
-        attention_head_dim: int, # 128
+        num_attention_heads: int,  # 28
+        attention_head_dim: int,  # 128
         mlp_width_ratio: str = 4.0,
         mlp_drop_rate: float = 0.0,
         attention_bias: bool = True,
@@ -321,6 +320,7 @@ def forward(
 
         return hidden_states
 
+
 # txt_in
 class HunyuanImageTokenRefiner(nn.Module):
     def __init__(
@@ -381,7 +381,7 @@ def __init__(self, patch_size: int, rope_dim: List[int], theta: float = 256.0) -
         self.theta = theta
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        _, _ , height, width = hidden_states.shape
+        _, _, height, width = hidden_states.shape
         rope_sizes = [height // self.patch_size, width // self.patch_size]
 
         axes_grids = []
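`forward` derives the rotary-embedding grid from the patchified latent size (`rope_sizes = [height // patch_size, width // patch_size]`) and then fills `axes_grids` with one coordinate axis per dimension. A rough sketch of that grid construction in plain torch; how the real module splits `rope_dim` across axes and computes the frequencies is omitted here and assumed, not taken from the diff:

```python
import torch

patch_size = 2
height, width = 32, 48                       # latent height/width, illustrative
rope_sizes = [height // patch_size, width // patch_size]

# one coordinate grid per axis, flattened to one entry per patch
axes_grids = [torch.arange(size, dtype=torch.float32) for size in rope_sizes]
grid = torch.meshgrid(*axes_grids, indexing="ij")   # (h_grid, w_grid)
positions = torch.stack(grid, dim=0).flatten(1)     # (2, num_patches)
print(positions.shape)  # torch.Size([2, 384])
```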
@@ -635,7 +635,6 @@
     ) -> None:
         super().__init__()
 
-
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
 
@@ -644,8 +643,7 @@
         self.context_embedder = HunyuanImageTokenRefiner(
             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
         )
-        self.context_embedder_2 = HunyuanImageByT5TextProjection(
-            text_embed_2_dim, 2048, inner_dim)
+        self.context_embedder_2 = HunyuanImageByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
 
         self.time_guidance_embed = HunyuanImageCombinedTimeGuidanceEmbedding(inner_dim, guidance_embeds)
 
@@ -739,7 +737,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -785,24 +782,36 @@ def forward(
             new_encoder_hidden_states = []
             new_encoder_attention_mask = []
 
-            for text, text_mask, text_2, text_mask_2 in zip(encoder_hidden_states, encoder_attention_mask, encoder_hidden_states_2, encoder_attention_mask_2):
+            for text, text_mask, text_2, text_mask_2 in zip(
+                encoder_hidden_states, encoder_attention_mask, encoder_hidden_states_2, encoder_attention_mask_2
+            ):
                 text_mask = text_mask.bool()
                 text_mask_2 = text_mask_2.bool()
                 # Concatenate: [valid_mllm, valid_byt5, invalid_mllm, invalid_byt5]
-                new_encoder_hidden_states.append(torch.cat([
-                    text_2[text_mask_2],  # valid byt5
-                    text[text_mask],  # valid mllm
-                    text_2[~text_mask_2],  # invalid byt5
-                    text[~text_mask],  # invalid mllm
-                ], dim=0))
+                new_encoder_hidden_states.append(
+                    torch.cat(
+                        [
+                            text_2[text_mask_2],  # valid byt5
+                            text[text_mask],  # valid mllm
+                            text_2[~text_mask_2],  # invalid byt5
+                            text[~text_mask],  # invalid mllm
+                        ],
+                        dim=0,
+                    )
+                )
 
                 # Apply same reordering to attention masks
-                new_encoder_attention_mask.append(torch.cat([
-                    text_mask_2[text_mask_2],
-                    text_mask[text_mask],
-                    text_mask_2[~text_mask_2],
-                    text_mask[~text_mask],
-                ], dim=0))
+                new_encoder_attention_mask.append(
+                    torch.cat(
+                        [
+                            text_mask_2[text_mask_2],
+                            text_mask[text_mask],
+                            text_mask_2[~text_mask_2],
+                            text_mask[~text_mask],
+                        ],
+                        dim=0,
+                    )
+                )
 
             encoder_hidden_states = torch.stack(new_encoder_hidden_states)
             encoder_attention_mask = torch.stack(new_encoder_attention_mask)
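This reordering groups the valid tokens from both text encoders ahead of the padding, so a single contiguous mask boundary covers both streams. A toy demonstration of the boolean-mask concatenation; the tensors are made up stand-ins for the MLLM and byT5 embeddings:

```python
import torch

# two "token streams" with validity masks
text = torch.tensor([[10.0], [11.0], [0.0], [0.0]])   # e.g. MLLM tokens, 2 valid
text_mask = torch.tensor([True, True, False, False])
text_2 = torch.tensor([[20.0], [21.0], [22.0]])       # e.g. byT5 tokens, all valid
text_mask_2 = torch.tensor([True, True, True])

# valid byt5, valid mllm, invalid byt5, invalid mllm (same order as the diff)
merged = torch.cat(
    [text_2[text_mask_2], text[text_mask], text_2[~text_mask_2], text[~text_mask]],
    dim=0,
)
merged_mask = torch.cat(
    [text_mask_2[text_mask_2], text_mask[text_mask], text_mask_2[~text_mask_2], text_mask[~text_mask]],
    dim=0,
)
print(merged.squeeze(-1))  # tensor([20., 21., 22., 10., 11.,  0.,  0.])
print(merged_mask)         # tensor([ True,  True,  True,  True,  True, False, False])
```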
@@ -854,10 +863,10 @@ def forward(
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = hidden_states.reshape(
-            batch_size, post_patch_height, post_patch_width, -1, p, p
-        )
-        hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) # batch_size, channels, height, patch_size, width, patch_size
+        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
+        hidden_states = hidden_states.permute(
+            0, 3, 1, 4, 2, 5
+        )  # batch_size, channels, height, patch_size, width, patch_size
         hidden_states = hidden_states.flatten(4, 5).flatten(2, 3)
 
         if USE_PEFT_BACKEND:
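The reformatted lines above are the unpatchify step: tokens of shape `(B, H/p * W/p, C*p*p)` are reshaped to `(B, H/p, W/p, C, p, p)`, permuted to `(B, C, H/p, p, W/p, p)`, and flattened back to `(B, C, H, W)`. A shape-only sketch with illustrative dimensions:

```python
import torch

batch_size, post_patch_height, post_patch_width, channels, p = 1, 4, 6, 16, 2
hidden_states = torch.randn(batch_size, post_patch_height * post_patch_width, channels * p * p)

hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)  # batch, channels, height, patch, width, patch
hidden_states = hidden_states.flatten(4, 5).flatten(2, 3)
print(hidden_states.shape)  # torch.Size([1, 16, 8, 12])
```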
