Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"AuraFlowTransformer2DModel",
"AutoencoderDC",
"AutoencoderKL",
"AutoencoderKLWan",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLHunyuanVideo",
Expand Down Expand Up @@ -148,6 +149,7 @@
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
"WanTransformer3DModel",
]
)
_import_structure["optimization"] = [
Expand Down Expand Up @@ -439,6 +441,8 @@
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
"WanPipeline",
"WanI2VPipeline",
]
)

Expand Down Expand Up @@ -610,6 +614,7 @@
AuraFlowTransformer2DModel,
AutoencoderDC,
AutoencoderKL,
AutoencoderKLWan,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLHunyuanVideo,
Expand Down Expand Up @@ -667,6 +672,7 @@
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
WanTransformer3DModel,
)
from .optimization import (
get_constant_schedule,
Expand Down Expand Up @@ -936,6 +942,8 @@
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
WanPipeline,
WanI2VPipeline
)

try:
Expand Down
3 changes: 3 additions & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
Expand Down Expand Up @@ -89,6 +90,7 @@
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]

if is_flax_available():
_import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
Expand Down Expand Up @@ -158,6 +160,7 @@
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
WanTransformer3DModel,
)
from .unets import (
I2VGenXLUNet,
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ def __init__(
elif qk_norm == "rms_norm":
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "rms_norm_across_heads":
# Wanx applies qk norm across all heads
self.norm_added_q = RMSNorm(dim_head * heads, eps=eps)
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
else:
raise ValueError(
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_dc import AutoencoderDC
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_wan import AutoencoderKLWan
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
Expand Down
Loading