
Commit d9f615d

conversion script
1 parent a9768d2 commit d9f615d

File tree

8 files changed (+283 -73 lines)


scripts/convert_wan_to_diffusers.py

Lines changed: 265 additions & 55 deletions
Large diffs are not rendered by default.
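
Since the full conversion script diff is collapsed, here is a minimal sketch of the kind of state-dict key remapping such a script typically performs. The key prefixes in the mapping are hypothetical placeholders, not names taken from the actual diff:

```python
# Hedged sketch only: the real scripts/convert_wan_to_diffusers.py diff is
# collapsed above, so the key prefixes below are hypothetical placeholders.
KEY_RENAME_MAP = {
    "img_emb.": "condition_embedder.image_embedder.",        # assumed original prefix
    "text_embedding.": "condition_embedder.text_embedder.",  # assumed original prefix
}


def convert_state_dict(state_dict: dict) -> dict:
    """Rename original checkpoint keys to an (assumed) diffusers-style layout."""
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in KEY_RENAME_MAP.items():
            if new_key.startswith(old_prefix):
                new_key = new_prefix + new_key[len(old_prefix):]
        converted[new_key] = value
    return converted
```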

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -438,7 +438,7 @@
             "VersatileDiffusionTextToImagePipeline",
             "VideoToVideoSDPipeline",
             "VQDiffusionPipeline",
-            "WanI2VPipeline",
+            "WanImageToVideoPipeline",
             "WanPipeline",
             "WuerstchenCombinedPipeline",
             "WuerstchenDecoderPipeline",

@@ -939,7 +939,7 @@
         VersatileDiffusionTextToImagePipeline,
         VideoToVideoSDPipeline,
         VQDiffusionPipeline,
-        WanI2VPipeline,
+        WanImageToVideoPipeline,
         WanPipeline,
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 5 additions & 5 deletions
@@ -127,7 +127,7 @@ def __init__(
         time_freq_dim: int,
         time_proj_dim: int,
         text_embed_dim: int,
-        image_embedding_dim: Optional[int] = None,
+        image_embed_dim: Optional[int] = None,
     ):
         super().__init__()

@@ -138,8 +138,8 @@ def __init__(
         self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

         self.image_embedder = None
-        if image_embedding_dim is not None:
-            self.image_embedder = WanImageEmbedding(image_embedding_dim, dim)
+        if image_embed_dim is not None:
+            self.image_embedder = WanImageEmbedding(image_embed_dim, dim)

     def forward(
         self,

@@ -348,7 +348,7 @@ def __init__(
         cross_attn_norm: bool = True,
         qk_norm: Optional[str] = "rms_norm_across_heads",
         eps: float = 1e-6,
-        image_embedding_dim: Optional[int] = None,
+        image_dim: Optional[int] = None,
         added_kv_proj_dim: Optional[int] = None,
         rope_max_seq_len: int = 1024,
     ) -> None:

@@ -368,7 +368,7 @@ def __init__(
             time_freq_dim=freq_dim,
             time_proj_dim=inner_dim * 6,
             text_embed_dim=text_dim,
-            image_embedding_dim=image_embedding_dim,
+            image_embed_dim=image_dim,
         )

         # 3. Transformer blocks
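
The rename spans two constructors: the top-level model now takes `image_dim`, which it forwards as `image_embed_dim` to the condition embedder. A minimal sketch of constructing the model with the new keyword; only the argument names visible in this diff are exact, the values are illustrative, and the remaining (size-related) constructor arguments are assumed to fall back to their defaults, so this is a sketch rather than something to run as-is:

```python
# Sketch of the renamed keyword; values are illustrative, only the argument
# names come from this commit. Passing image_dim=None skips the image embedder.
from diffusers import WanTransformer3DModel

transformer = WanTransformer3DModel(
    cross_attn_norm=True,
    qk_norm="rms_norm_across_heads",
    rope_max_seq_len=32,
    image_dim=4,  # was `image_embedding_dim` before this commit
)
# Assuming the embedder is exposed as `condition_embedder`, as in diffusers:
assert transformer.condition_embedder.image_embedder is not None
```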

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -346,7 +346,7 @@
         "WuerstchenDecoderPipeline",
         "WuerstchenPriorPipeline",
     ]
-    _import_structure["wan"] = ["WanPipeline", "WanI2VPipeline"]
+    _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline"]
     try:
         if not is_onnx_available():
             raise OptionalDependencyNotAvailable()

@@ -689,7 +689,7 @@
             UniDiffuserPipeline,
             UniDiffuserTextDecoder,
         )
-        from .wan import WanI2VPipeline, WanPipeline
+        from .wan import WanImageToVideoPipeline, WanPipeline
         from .wuerstchen import (
             WuerstchenCombinedPipeline,
             WuerstchenDecoderPipeline,

src/diffusers/pipelines/wan/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_wan"] = ["WanPipeline"]
-    _import_structure["pipeline_wan_i2v"] = ["WanI2VPipeline"]
+    _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:

@@ -34,7 +34,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_wan import WanPipeline
-        from .pipeline_wan_i2v import WanI2VPipeline
+        from .pipeline_wan_i2v import WanImageToVideoPipeline

 else:
     import sys
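
Both `__init__.py` diffs above keep the lazy-import registry (`_import_structure`) and the eager `TYPE_CHECKING` branch in sync, so the renamed class resolves identically through either path. A quick sketch of what this guarantees, assuming diffusers and its torch/transformers backends are installed:

```python
# Sketch: after this commit the renamed pipeline resolves through both the
# lazy module-level lookup and a direct subpackage import.
import diffusers
from diffusers.pipelines.wan import WanImageToVideoPipeline

# The lazy lookup consults _import_structure and returns the same class object.
assert getattr(diffusers, "WanImageToVideoPipeline") is WanImageToVideoPipeline
```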

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 3 additions & 3 deletions
@@ -46,7 +46,7 @@
     Examples:
         ```python
         >>> import torch
-        >>> from diffusers import WanI2VPipeline, WanTransformer3DModel
+        >>> from diffusers import WanImageToVideoPipeline, WanTransformer3DModel
         >>> from transformers import CLIPVisionModel, CLIPImageProcessor, UMT5EncoderModel
         >>> from diffusers.utils import load_image, export_to_video

@@ -56,7 +56,7 @@
         >>> text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
         >>> transformer_i2v = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer_i2v_720p")
         >>> image_processor = CLIPImageProcessor.from_pretrained(model_id, subfolder="image_processor")
-        >>> pipe = WanI2VPipeline.from_pretrained(
+        >>> pipe = WanImageToVideoPipeline.from_pretrained(
         ...     model_id,
         ...     transformer=transformer_i2v,
         ...     text_encoder=text_encoder,

@@ -125,7 +125,7 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")


-class WanI2VPipeline(DiffusionPipeline):
+class WanImageToVideoPipeline(DiffusionPipeline):
     r"""
     Pipeline for image-to-video generation using Wan.
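
The docstring example is cut off by the diff context before the generation call. A hedged sketch of the remaining usage, where the checkpoint id and call arguments are illustrative rather than taken from this diff:

```python
# Illustrative continuation of the truncated docstring example; the model id
# and call arguments below are assumptions, not part of this diff.
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import load_image, export_to_video

pipe = WanImageToVideoPipeline.from_pretrained(
    "path/to/wan-i2v-checkpoint",  # hypothetical model id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = load_image("input.png")
frames = pipe(image=image, prompt="a cat walking on grass").frames[0]
export_to_video(frames, "output.mp4")
```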
src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 1 addition & 1 deletion
@@ -2552,7 +2552,7 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])


-class WanI2VPipeline(metaclass=DummyObject):
+class WanImageToVideoPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]

     def __init__(self, *args, **kwargs):
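
The dummy class exists so that `from diffusers import WanImageToVideoPipeline` succeeds even when torch or transformers is missing, failing only on use. A simplified sketch of the pattern; the real `DummyObject` and `requires_backends` helpers live in `diffusers.utils` and are more elaborate:

```python
# Simplified sketch of the DummyObject pattern: the name is always importable,
# but instantiating it without the required backends raises immediately.
class DummyObject(type):
    def __call__(cls, *args, **kwargs):
        raise ImportError(
            f"{cls.__name__} requires the backends {cls._backends}, which are not installed."
        )


class WanImageToVideoPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]


# WanImageToVideoPipeline()  # would raise ImportError here
```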

tests/pipelines/wan/test_wan_image_to_video.py

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel, CLIPVisionConfig, CLIPVisionModel, CLIPImageProcessor

-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanI2VPipeline, WanTransformer3DModel
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanImageToVideoPipeline, WanTransformer3DModel
 from diffusers.utils.testing_utils import enable_full_determinism

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS

@@ -30,7 +30,7 @@


 class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
-    pipeline_class = WanI2VPipeline
+    pipeline_class = WanImageToVideoPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

@@ -78,7 +78,7 @@ def get_dummy_components(self):
             cross_attn_norm=True,
             qk_norm="rms_norm_across_heads",
             rope_max_seq_len=32,
-            image_embedding_dim=4,
+            image_dim=4,
         )

         torch.manual_seed(0)
