Skip to content

Commit d15d090

Browse files
committed
initial support
1 parent ceb7af2 commit d15d090

File tree

9 files changed

+1250
-3
lines changed

9 files changed

+1250
-3
lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
"UVit2DModel",
216216
"VQModel",
217217
"WanTransformer3DModel",
218+
"WanVACETransformer3DModel",
218219
]
219220
)
220221
_import_structure["optimization"] = [
@@ -526,6 +527,7 @@
526527
"VQDiffusionPipeline",
527528
"WanImageToVideoPipeline",
528529
"WanPipeline",
530+
"WanVACEPipeline",
529531
"WanVideoToVideoPipeline",
530532
"WuerstchenCombinedPipeline",
531533
"WuerstchenDecoderPipeline",
@@ -819,6 +821,7 @@
819821
UVit2DModel,
820822
VQModel,
821823
WanTransformer3DModel,
824+
WanVACETransformer3DModel,
822825
)
823826
from .optimization import (
824827
get_constant_schedule,
@@ -1109,6 +1112,7 @@
11091112
VQDiffusionPipeline,
11101113
WanImageToVideoPipeline,
11111114
WanPipeline,
1115+
WanVACEPipeline,
11121116
WanVideoToVideoPipeline,
11131117
WuerstchenCombinedPipeline,
11141118
WuerstchenDecoderPipeline,

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
5959
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
6060
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
61+
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
6162
}
6263

6364

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
9090
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
9191
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
92+
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
9293
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
9394
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
9495
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -178,6 +179,7 @@
178179
Transformer2DModel,
179180
TransformerTemporalModel,
180181
WanTransformer3DModel,
182+
WanVACETransformer3DModel,
181183
)
182184
from .unets import (
183185
I2VGenXLUNet,

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@
3232
from .transformer_sd3 import SD3Transformer2DModel
3333
from .transformer_temporal import TransformerTemporalModel
3434
from .transformer_wan import WanTransformer3DModel
35+
from .transformer_wan_vace import WanVACETransformer3DModel

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
340340

341341
_supports_gradient_checkpointing = True
342342
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
343-
_no_split_modules = ["WanTransformerBlock"]
343+
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
344344
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
345345
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
346346

0 commit comments

Comments
 (0)