Skip to content

Commit eda4a31

Browse files
committed
Introduce AutoPipelineForText2Video (simple)
1 parent 54fa074 commit eda4a31

File tree

4 files changed

+70
-1
lines changed

4 files changed

+70
-1
lines changed

auto_pipeline_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Smoke-test for the new ``AutoPipelineForText2Video`` entry point.

Loads one Wan-family checkpoint through the auto-pipeline dispatcher and
prints the resolved text-encoder class name, confirming that the
``_class_name`` -> pipeline mapping picked a concrete Wan pipeline.
"""

import torch  # noqa: F401  (used by the optional torch_dtype / VAE checks below)

from diffusers import AutoPipelineForText2Video
from diffusers.utils import export_to_video  # noqa: F401  (handy when extending this smoke test)

# Checkpoints covering the Wan pipeline variants the auto-mapping must resolve.
wan_list = [
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers",
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
    "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
]

# Index of the checkpoint under test: 3 == VACE 1.3B, the smallest of the
# five per its repo name, so the download/dispatch check stays cheap.
MODEL_INDEX = 3

pipe = AutoPipelineForText2Video.from_pretrained(
    wan_list[MODEL_INDEX],
    # torch_dtype=torch.float16,  # enable to halve memory when running on GPU
)

# The concrete pipeline chosen by the auto-mapping determines this class name.
print(pipe.text_encoder.__class__.__name__)

# Optional VAE round-trip sanity check (left disabled: too slow for a smoke test):
# img = torch.randn(1, 3, 10, 512, 512)  # batch 1, 3 channels, 10 frames, 512x512
# latent = pipe.vae.encode(img).latent_dist.mode()
# print("Latent shape:", latent.shape)  # previously observed: torch.Size([1, 16, 3, 64, 64])
# recon = pipe.vae.decode(latent).sample
# print("Reconstructed image shape:", recon.shape)

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@
305305
"AutoPipelineForImage2Image",
306306
"AutoPipelineForInpainting",
307307
"AutoPipelineForText2Image",
308+
"AutoPipelineForText2Video",
308309
"ConsistencyModelPipeline",
309310
"DanceDiffusionPipeline",
310311
"DDIMPipeline",

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"AutoPipelineForImage2Image",
4747
"AutoPipelineForInpainting",
4848
"AutoPipelineForText2Image",
49+
"AutoPipelineForText2Video",
4950
]
5051
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
5152
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117
StableDiffusionXLInpaintPipeline,
118118
StableDiffusionXLPipeline,
119119
)
120-
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
120+
from .wan import WanAnimatePipeline, WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
121121
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
122122
from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
123123

@@ -221,6 +221,10 @@
221221
# Dispatch table for AutoPipelineForText2Video: model-family key -> concrete
# Wan pipeline class. Keys follow the naming scheme of the sibling
# AUTO_*_PIPELINES_MAPPING tables in this module.
AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
    {
        "wan": WanPipeline,
        "wan-animate": WanAnimatePipeline,
        "wan-image-to-video": WanImageToVideoPipeline,
        "wan-vace": WanVACEPipeline,
        "wan-video-to-video": WanVideoToVideoPipeline,
    }
)
226230

@@ -1206,3 +1210,39 @@ def from_pipe(cls, pipeline, **kwargs):
12061210
model.register_to_config(**unused_original_config)
12071211

12081212
return model
1213+
1214+
1215+
class AutoPipelineForText2Video(ConfigMixin):
    r"""
    [`AutoPipelineForText2Video`] is a generic pipeline class that instantiates a concrete
    text-to-video pipeline. The underlying pipeline class is selected automatically by reading the
    `_class_name` entry of the checkpoint's `model_index.json` and looking it up in
    `AUTO_TEXT2VIDEO_PIPELINES_MAPPING`.

    This class cannot be instantiated with `__init__()` (it raises an error); use
    [`~AutoPipelineForText2Video.from_pretrained`] instead.
    """

    # ConfigMixin reads this file to discover the saved pipeline's class name.
    config_name = "model_index.json"

    def __init__(self, *args, **kwargs):
        # Mirror the other AutoPipeline* classes: direct construction is a usage error.
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
        )

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_or_path, **kwargs):
        r"""
        Instantiate a text-to-video pipeline from a pretrained checkpoint.

        Fetches the checkpoint's `model_index.json`, resolves its `_class_name` through
        `AUTO_TEXT2VIDEO_PIPELINES_MAPPING`, and delegates loading to the resolved pipeline's own
        `from_pretrained`.

        Args:
            pretrained_model_or_path (`str` or `os.PathLike`):
                Hub repo id or local directory containing the pipeline.
            **kwargs:
                Hub-download options (`cache_dir`, `force_download`, `proxies`, `token`,
                `local_files_only`, `revision`) plus anything the concrete pipeline's
                `from_pretrained` accepts (e.g. `torch_dtype`).

        Returns:
            The instantiated concrete text-to-video pipeline.

        Raises:
            ValueError: If the loaded config does not declare a `_class_name` to dispatch on.
        """
        # Pop the hub-download arguments so they can be used for the config fetch,
        # then re-merged for the weight download below.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        orig_class_name = config.get("_class_name")
        if orig_class_name is None:
            # Fail with a clear message instead of an opaque KeyError.
            raise ValueError(
                f"The config of {pretrained_model_or_path} does not contain a `_class_name` key, "
                "so the text-to-video pipeline class cannot be determined."
            )
        text_to_video_cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, orig_class_name)

        # Re-merge the hub kwargs so the resolved pipeline downloads with the same settings.
        kwargs = {**load_config_kwargs, **kwargs}
        return text_to_video_cls.from_pretrained(pretrained_model_or_path, **kwargs)

0 commit comments

Comments
 (0)