
Commit 790aeff

make style
1 parent 02864b5 commit 790aeff

7 files changed: +81 -44 lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -184,9 +184,9 @@
         "AutoencoderKLAllegro",
         "AutoencoderKLCogVideoX",
         "AutoencoderKLCosmos",
-        "AutoencoderKLHunyuanVideo",
-        "AutoencoderKLHunyuanImageRefiner",
         "AutoencoderKLHunyuanImage",
+        "AutoencoderKLHunyuanImageRefiner",
+        "AutoencoderKLHunyuanVideo",
         "AutoencoderKLLTXVideo",
         "AutoencoderKLMagvit",
         "AutoencoderKLMochi",
@@ -872,9 +872,9 @@
         AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
         AutoencoderKLCosmos,
-        AutoencoderKLHunyuanVideo,
         AutoencoderKLHunyuanImage,
         AutoencoderKLHunyuanImageRefiner,
+        AutoencoderKLHunyuanVideo,
         AutoencoderKLLTXVideo,
         AutoencoderKLMagvit,
         AutoencoderKLMochi,
@@ -905,9 +905,9 @@
         HunyuanDiT2DControlNetModel,
         HunyuanDiT2DModel,
         HunyuanDiT2DMultiControlNetModel,
+        HunyuanImageTransformer2DModel,
         HunyuanVideoFramepackTransformer3DModel,
         HunyuanVideoTransformer3DModel,
-        HunyuanImageTransformer2DModel,
         I2VGenXLUNet,
         Kandinsky3UNet,
         LatteTransformer3DModel,
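The three hunks above only reorder existing entries alphabetically; no public name is added or removed. As a quick illustration of what these export lists back, a minimal sketch (assuming an environment with a diffusers source build that already contains this commit):

# Assumes a diffusers install that includes this commit; all of the names
# below appear in the export lists shown above.
from diffusers import (
    AutoencoderKLHunyuanImage,
    AutoencoderKLHunyuanImageRefiner,
    AutoencoderKLHunyuanVideo,
    HunyuanImageTransformer2DModel,
)

# The lazy _import_structure entries route each name to its defining module,
# e.g. diffusers.models.transformers.transformer_hunyuanimage for the transformer.
print(HunyuanImageTransformer2DModel.__module__)
print(AutoencoderKLHunyuanImageRefiner.__module__)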

src/diffusers/models/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,6 @@
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
     _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
-    _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
     _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
     _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
@@ -93,6 +92,7 @@
     _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
     _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
     _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
+    _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
     _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
@@ -132,9 +132,9 @@
         AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
         AutoencoderKLCosmos,
-        AutoencoderKLHunyuanVideo,
         AutoencoderKLHunyuanImage,
         AutoencoderKLHunyuanImageRefiner,
+        AutoencoderKLHunyuanVideo,
         AutoencoderKLLTXVideo,
         AutoencoderKLMagvit,
         AutoencoderKLMochi,

src/diffusers/models/autoencoders/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -5,12 +5,12 @@
 from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
 from .autoencoder_kl_cosmos import AutoencoderKLCosmos
 from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
+from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
+from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
 from .autoencoder_kl_magvit import AutoencoderKLMagvit
 from .autoencoder_kl_mochi import AutoencoderKLMochi
 from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
-from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
-from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_kl_wan import AutoencoderKLWan
 from .autoencoder_oobleck import AutoencoderOobleck

src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py

Lines changed: 11 additions & 23 deletions
@@ -24,7 +24,6 @@
 from ...utils import logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
-from ..attention_processor import Attention
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from .vae import DecoderOutput, DiagonalGaussianDistribution
@@ -126,8 +125,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
         x = self.proj_out(x)
-
-        return x + identity
+
+        return x + identity


 class HunyuanImageRefinerUpsampleDCAE(nn.Module):
@@ -143,11 +142,11 @@ def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: b
     def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
         """
         Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
-
+
         Args:
             tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
             r1: temporal upsampling factor
-            r2: height upsampling factor
+            r2: height upsampling factor
             r3: width upsampling factor
         """
         b, packed_c, f, h, w = tensor.shape
@@ -187,12 +186,11 @@ def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample:
         self.add_temporal_downsample = add_temporal_downsample
         self.group_size = factor * in_channels // out_channels

-
     @staticmethod
     def _dcae_downsample_rearrange(self, tensor, r1=1, r2=2, r3=2):
         """
         Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
-
+
         This packs spatial/temporal dimensions into channels (opposite of upsample)
         """
         b, c, packed_f, packed_h, packed_w = tensor.shape
@@ -202,7 +200,6 @@ def _dcae_downsample_rearrange(self, tensor, r1=1, r2=2, r3=2):
         tensor = tensor.permute(0, 2, 4, 6, 1, 3, 5, 7)
         return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)

-
     def forward(self, x: torch.Tensor):
         r1 = 2 if self.add_temporal_downsample else 1
         h = self.conv(x)
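The `_dcae_upsample_rearrange` / `_dcae_downsample_rearrange` helpers touched in the hunks above are pure shape rearranges: the downsample packs spatial/temporal positions into channels and the upsample unpacks them again, as the docstrings state. Below is a minimal standalone sketch of that channel/space roundtrip with plain torch ops; it is illustrative only and does not claim to reproduce the exact permutation order used in the file.

import torch

def space_to_channel(x: torch.Tensor, r1: int = 1, r2: int = 2, r3: int = 2) -> torch.Tensor:
    # (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
    b, c, F, H, W = x.shape
    f, h, w = F // r1, H // r2, W // r3
    x = x.reshape(b, c, f, r1, h, r2, w, r3)
    x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)  # (b, r1, r2, r3, c, f, h, w)
    return x.reshape(b, r1 * r2 * r3 * c, f, h, w)

def channel_to_space(x: torch.Tensor, r1: int = 1, r2: int = 2, r3: int = 2) -> torch.Tensor:
    # (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
    b, packed_c, f, h, w = x.shape
    c = packed_c // (r1 * r2 * r3)
    x = x.reshape(b, r1, r2, r3, c, f, h, w)
    x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)  # (b, c, f, r1, h, r2, w, r3)
    return x.reshape(b, c, f * r1, h * r2, w * r3)

# Roundtrip check: unpacking after packing recovers the input exactly.
x = torch.randn(1, 8, 1, 4, 4)
assert torch.equal(channel_to_space(space_to_channel(x)), x)

With the defaults used here (r1=1, r2=r3=2), this is simply the 3D analogue of pixel-unshuffle and pixel-shuffle.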
@@ -304,16 +301,13 @@ def __init__(
         self.gradient_checkpointing = False

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
         hidden_states = self.resnets[0](hidden_states)

-
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 hidden_states = attn(hidden_states)
             hidden_states = resnet(hidden_states)

-
         return hidden_states


@@ -356,7 +350,6 @@ def __init__(
         self.gradient_checkpointing = False

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
         for resnet in self.resnets:
             hidden_states = resnet(hidden_states)

@@ -461,7 +454,6 @@ def __init__(
                 )
                 input_channel = output_channel
             else:
-
                 add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
                 downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
                 down_block = HunyuanImageRefinerDownBlock3D(
@@ -518,7 +510,7 @@ class HunyuanImageRefinerDecoder3D(nn.Module):
     def __init__(
         self,
         in_channels: int = 32,
-        out_channels: int = 3,
+        out_channels: int = 3,
         block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
         layers_per_block: int = 2,
         spatial_compression_ratio: int = 16,
@@ -574,10 +566,8 @@ def __init__(
         self.gradient_checkpointing = False

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
         hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)

-
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)

@@ -598,8 +588,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

 class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
     r"""
-    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
-    Used for HunyuanImage-2.1 Refiner..
+    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
+    HunyuanImage-2.1 Refiner..

     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
@@ -621,7 +611,7 @@ def __init__(
         upsample_match_channel: bool = True,
         scaling_factor: float = 1.03682,
     ) -> None:
-        super().__init__()
+        super().__init__()

         self.encoder = HunyuanImageRefinerEncoder3D(
             in_channels=in_channels,
@@ -655,7 +645,6 @@ def __init__(
         # intermediate tiles together, the memory requirement can be lowered.
         self.use_tiling = False

-
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
         self.tile_sample_min_width = 256
@@ -763,7 +752,7 @@ def _decode(self, z: torch.Tensor) -> torch.Tensor:

         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z)
-
+
         dec = self.decoder(z)

         return dec
@@ -829,7 +818,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
                 The latent representation of the encoded videos.
         """
         _, _, _, height, width = x.shape
-
+
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
         overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor))  # 256 * (1 - 0.25) = 192
@@ -922,7 +911,6 @@ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:

         return dec

-
     def forward(
         self,
         sample: torch.Tensor,

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,6 @@
 from .dit_transformer_2d import DiTTransformer2DModel
 from .dual_transformer_2d import DualTransformer2DModel
 from .hunyuan_transformer_2d import HunyuanDiT2DModel
-from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
 from .latte_transformer_3d import LatteTransformer3DModel
 from .lumina_nextdit2d import LuminaNextDiT2DModel
 from .pixart_transformer_2d import PixArtTransformer2DModel
@@ -28,6 +27,7 @@
 from .transformer_hidream_image import HiDreamImageTransformer2DModel
 from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
 from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
+from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
 from .transformer_ltx import LTXVideoTransformer3DModel
 from .transformer_lumina2 import Lumina2Transformer2DModel
 from .transformer_mochi import MochiTransformer3DModel

src/diffusers/models/transformers/transformer_hunyuanimage.py

Lines changed: 16 additions & 12 deletions
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, List, Optional, Tuple, Union
-
 import math
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -217,7 +216,9 @@ def __init__(
         self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

     def forward(
-        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor] = None,
+        self,
+        timestep: torch.Tensor,
+        guidance: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))  # (N, D)
@@ -381,13 +382,15 @@ def forward(


 class HunyuanImageRotaryPosEmbed(nn.Module):
-    def __init__(self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, List[int]], theta: float = 256.0) -> None:
+    def __init__(
+        self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, List[int]], theta: float = 256.0
+    ) -> None:
         super().__init__()

-        if not isinstance(patch_size, (tuple, list)) or not len(patch_size) in [2, 3]:
+        if not isinstance(patch_size, (tuple, list)) or len(patch_size) not in [2, 3]:
             raise ValueError(f"patch_size must be a tuple or list of length 2 or 3, got {patch_size}")
-
-        if not isinstance(rope_dim, (tuple, list)) or not len(rope_dim) in [2, 3]:
+
+        if not isinstance(rope_dim, (tuple, list)) or len(rope_dim) not in [2, 3]:
             raise ValueError(f"rope_dim must be a tuple or list of length 2 or 3, got {rope_dim}")

         if not len(patch_size) == len(rope_dim):
@@ -398,7 +401,6 @@ def __init__(self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, L
         self.theta = theta

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
         if hidden_states.ndim == 5:
             _, _, frame, height, width = hidden_states.shape
             patch_size_frame, patch_size_height, patch_size_width = self.patch_size
@@ -805,7 +807,7 @@ def forward(
             sizes = (frame, height, width)
         else:
             raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
-
+
         post_patch_sizes = tuple(d // p for d, p in zip(sizes, self.config.patch_size))

         # 1. RoPE
@@ -816,7 +818,7 @@
         temb = self.time_guidance_embed(timestep, guidance)
         hidden_states = self.x_embedder(hidden_states)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
-
+
         if self.context_embedder_2 is not None and encoder_hidden_states_2 is not None:
             encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)

@@ -912,7 +914,7 @@ def forward(
         hidden_states = hidden_states.reshape(*reshape_dims)

         # create permutation pattern: batch, channels, then interleave post_patch and patch dims
-        # For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
+        # For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
         # For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
         ndim = len(post_patch_sizes)
         permute_pattern = [0, ndim + 1]  # batch, channels
@@ -922,7 +924,9 @@ def forward(

         # flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
         # batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
-        final_dims = [batch_size, out_channels] + [post_patch * patch for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)]
+        final_dims = [batch_size, out_channels] + [
+            post_patch * patch for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
+        ]
         hidden_states = hidden_states.reshape(*final_dims)

         if USE_PEFT_BACKEND:
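The comments in the last two hunks fully describe the unpatchify step for the 4D (image) case: the [0, 3, 1, 4, 2, 5] permutation implies a pre-permute layout of (batch, post_patch_h, post_patch_w, channels, patch_h, patch_w), after which each (post_patch, patch) pair is merged back into a spatial axis. A minimal standalone sketch of that 4D path with a roundtrip check; the per-token (channels, patch_h, patch_w) layout is an assumption implied by the permutation, not something shown verbatim in the diff:

import torch

def unpatchify_2d(tokens, post_h, post_w, p_h, p_w, c):
    # (b, post_h*post_w, c*p_h*p_w) -> (b, c, post_h*p_h, post_w*p_w),
    # following the [0, 3, 1, 4, 2, 5] pattern from the comment above.
    b = tokens.shape[0]
    x = tokens.reshape(b, post_h, post_w, c, p_h, p_w)
    x = x.permute(0, 3, 1, 4, 2, 5)  # batch, channels, post_h, p_h, post_w, p_w
    return x.reshape(b, c, post_h * p_h, post_w * p_w)

# Roundtrip check against the inverse (patchify) operation.
b, c, p_h, p_w, post_h, post_w = 2, 4, 2, 2, 3, 5
image = torch.randn(b, c, post_h * p_h, post_w * p_w)
tokens = (
    image.reshape(b, c, post_h, p_h, post_w, p_w)
    .permute(0, 2, 4, 1, 3, 5)  # batch, post_h, post_w, channels, p_h, p_w
    .reshape(b, post_h * post_w, c * p_h * p_w)
)
assert torch.equal(unpatchify_2d(tokens, post_h, post_w, p_h, p_w, c), image)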
