|
117 | 117 | StableDiffusionXLInpaintPipeline, |
118 | 118 | StableDiffusionXLPipeline, |
119 | 119 | ) |
120 | | -from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline |
| 120 | +from .wan import WanAnimatePipeline, WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline |
121 | 121 | from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline |
122 | 122 | from .z_image import ZImageImg2ImgPipeline, ZImagePipeline |
123 | 123 |
|
|
221 | 221 | AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict( |
222 | 222 | [ |
223 | 223 | ("wan", WanPipeline), |
| 224 | + ("wan-animate", WanAnimatePipeline), |
| 225 | + ("wan-image-to-video", WanImageToVideoPipeline), |
| 226 | + ("wan-vace", WanVACEPipeline), |
| 227 | + ("wan-video-to-video", WanVideoToVideoPipeline), |
224 | 228 | ] |
225 | 229 | ) |
226 | 230 |
|
@@ -1206,3 +1210,39 @@ def from_pipe(cls, pipeline, **kwargs): |
1206 | 1210 | model.register_to_config(**unused_original_config) |
1207 | 1211 |
|
1208 | 1212 | return model |
| 1213 | + |
| 1214 | + |
| 1215 | +class AutoPipelineForText2Video(ConfigMixin): |
| 1216 | + config_name = "model_index.json" |
| 1217 | + |
| 1218 | + def __init__(self, *args, **kwargs): |
| 1219 | + raise EnvironmentError( |
| 1220 | + f"{self.__class__.__name__} is designed to be instantiated " |
| 1221 | + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " |
| 1222 | + f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." |
| 1223 | + ) |
| 1224 | + |
| 1225 | + @classmethod |
| 1226 | + @validate_hf_hub_args |
| 1227 | + def from_pretrained(cls, pretrained_model_or_path, **kwargs): |
| 1228 | + cache_dir = kwargs.pop("cache_dir", None) |
| 1229 | + force_download = kwargs.pop("force_download", False) |
| 1230 | + proxies = kwargs.pop("proxies", None) |
| 1231 | + token = kwargs.pop("token", None) |
| 1232 | + local_files_only = kwargs.pop("local_files_only", False) |
| 1233 | + revision = kwargs.pop("revision", None) |
| 1234 | + |
| 1235 | + load_config_kwargs = { |
| 1236 | + "cache_dir": cache_dir, |
| 1237 | + "force_download": force_download, |
| 1238 | + "proxies": proxies, |
| 1239 | + "token": token, |
| 1240 | + "local_files_only": local_files_only, |
| 1241 | + "revision": revision, |
| 1242 | + } |
| 1243 | + |
| 1244 | + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) |
| 1245 | + orig_class_name = config["_class_name"] |
| 1246 | + text_to_video_cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, orig_class_name) |
| 1247 | + kwargs = {**load_config_kwargs, **kwargs} |
| 1248 | + return text_to_video_cls.from_pretrained(pretrained_model_or_path, **kwargs) |
0 commit comments