add WanDMDPipeline #219
New file:

@@ -0,0 +1,111 @@
```python
import torch
import torch.distributed as dist
from typing import Callable, List, Optional
from tqdm import tqdm
from PIL import Image

from diffsynth_engine.pipelines.wan_video import WanVideoPipeline


class WanDMDPipeline(WanVideoPipeline):
    def prepare_latents(
        self,
        latents,
        denoising_step_list,
    ):
        height, width = latents.shape[-2:]
        height, width = height * self.upsampling_factor, width * self.upsampling_factor
        sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps=1000)
        sigmas = sigmas[[1000 - t for t in denoising_step_list] + [-1]]
        timesteps = timesteps[[1000 - t for t in denoising_step_list]]
        init_latents = latents.clone()
        return init_latents, latents, sigmas, timesteps

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        input_image: Image.Image | None = None,
        seed=None,
        height=480,
        width=832,
        num_frames=81,
        denoising_step_list: List[int] = None,
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
```
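For reference, the index arithmetic in `prepare_latents` works out as follows with the default step list (an editor's illustration, not part of the diff; it assumes `schedule(num_inference_steps=1000)` returns entries ordered from t=1000 down to t=1, which is what the `1000 - t` mapping implies):

```python
# Default four-step schedule used by WanDMDPipeline.
denoising_step_list = [1000, 750, 500, 250]

step_indices = [1000 - t for t in denoising_step_list]  # [0, 250, 500, 750]
sigma_indices = step_indices + [-1]                     # [0, 250, 500, 750, -1]
# timesteps[step_indices] -> the four timesteps actually denoised: 1000, 750, 500, 250
# sigmas[sigma_indices]   -> five sigmas: one per step plus the terminal sigma,
#                            so sampler.step() has a target to move to after
#                            the last denoising step
```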
Comment on lines +26 to +36

Contributor: The `__call__` method lacks a docstring. Here is a suggestion:

```python
@torch.no_grad()
def __call__(
    self,
    prompt,
    input_image: Image.Image | None = None,
    seed=None,
    height=480,
    width=832,
    num_frames=81,
    denoising_step_list: List[int] = None,
    progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
):
    """
    Generates a video based on a prompt and an optional input image using the DMD method.

    Args:
        prompt (str): The text prompt to guide video generation.
        input_image (Image.Image | None, optional): An optional input image for image-to-video generation. Defaults to None.
        seed (int, optional): Random seed for noise generation. Defaults to None.
        height (int, optional): Height of the output video. Defaults to 480.
        width (int, optional): Width of the output video. Defaults to 832.
        num_frames (int, optional): Number of frames in the output video. Must be `4*k + 1`. Defaults to 81.
        denoising_step_list (List[int], optional): A list of timesteps for the denoising process, selected from a 1000-step schedule. Defaults to `[1000, 750, 500, 250]`.
        progress_callback (Optional[Callable], optional): A callback function for progress updates. Defaults to None.

    Returns:
        List[Image.Image]: A list of PIL Images representing the generated video frames.

    Note:
        This pipeline does not use classifier-free guidance; `cfg_scale` is fixed to 1.0 and negative prompts are ignored.
    """
```
```python
        denoising_step_list = [1000, 750, 500, 250] if denoising_step_list is None else denoising_step_list
        divisor = 32 if self.vae.z_dim == 48 else 16  # 32 for wan2.2 vae, 16 for wan2.1 vae
        assert height % divisor == 0 and width % divisor == 0, f"height and width must be divisible by {divisor}"
        assert (num_frames - 1) % 4 == 0, "num_frames must be 4X+1"

        # Initialize noise
        if dist.is_initialized() and seed is None:
            raise ValueError("must provide a seed when parallelism is enabled")
        noise = self.generate_noise(
            (
                1,
                self.vae.z_dim,
                (num_frames - 1) // 4 + 1,
                height // self.upsampling_factor,
                width // self.upsampling_factor,
            ),
            seed=seed,
            device="cpu",
            dtype=torch.float32,
        ).to(self.device)
        init_latents, latents, sigmas, timesteps = self.prepare_latents(noise, denoising_step_list)
        mask = torch.ones((1, 1, *latents.shape[2:]), dtype=latents.dtype, device=latents.device)

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt)
        prompt_emb_nega = None

        # Encode image
        image_clip_feature = self.encode_clip_feature(input_image, height, width)
        image_y = self.encode_vae_feature(input_image, num_frames, height, width)
        image_latents = self.encode_image_latents(input_image, height, width)
        if image_latents is not None:
            latents[:, :, : image_latents.shape[2], :, :] = image_latents
            init_latents = latents.clone()
            mask[:, :, : image_latents.shape[2], :, :] = 0

        # Initialize sampler
        self.sampler.initialize(sigmas=sigmas)

        # Denoise
        hide_progress = dist.is_initialized() and dist.get_rank() != 0
        for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
            if timestep.item() / 1000 >= self.config.boundary:
                self.load_models_to_device(["dit"])
                model = self.dit
            else:
                self.load_models_to_device(["dit2"])
                model = self.dit2

            timestep = timestep * mask[:, :, :, ::2, ::2].flatten()  # seq_len: per-token timesteps, 0 at positions pinned by the mask
            timestep = timestep.to(dtype=self.dtype, device=self.device)
            # Classifier-free guidance entry point; cfg_scale=1.0 means guidance is effectively disabled
            noise_pred = self.predict_noise_with_cfg(
                model=model,
                latents=latents,
                timestep=timestep,
                positive_prompt_emb=prompt_emb_posi,
                negative_prompt_emb=prompt_emb_nega,
                image_clip_feature=image_clip_feature,
                image_y=image_y,
                cfg_scale=1.0,
                batch_cfg=self.config.batch_cfg,
            )
            # Scheduler step
            latents = self.sampler.step(latents, noise_pred, i)
            # Re-pin conditioning frames: where mask == 0, restore the initial latents
            latents = latents * mask + init_latents * (1 - mask)
            if progress_callback is not None:
                progress_callback(i + 1, len(timesteps), "DENOISING")

        # Decode
        self.load_models_to_device(["vae"])
        frames = self.decode_video(latents, progress_callback=progress_callback)
        frames = self.vae_output_to_image(frames)
        return frames
```
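A usage sketch for `progress_callback`; the `(current, total, status)` signature comes from the inline comment in the signature above, while the body is illustrative:

```python
def progress_callback(current: int, total: int, status: str) -> None:
    # With the default four-step list this prints "[DENOISING] 1/4" .. "[DENOISING] 4/4",
    # then whatever decode_video reports, since the callback is forwarded to it.
    print(f"[{status}] {current}/{total}")
```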
Changes to an existing file:

```diff
@@ -43,6 +43,24 @@ def _from_diffsynth(self, state_dict):
             dit_dict[key] = lora_args
         return {"dit": dit_dict}
 
+    def _from_diffusers(self, state_dict):
+        dit_dict = {}
+        for key, param in state_dict.items():
+            if ".lora_down.weight" not in key:
+                continue
+
+            lora_args = {}
+            lora_args["up"] = state_dict[key.replace(".lora_down.weight", ".lora_up.weight")]
+            lora_args["down"] = param
+            lora_args["rank"] = lora_args["up"].shape[1]
+            if key.replace(".lora_down.weight", ".alpha") in state_dict:
+                lora_args["alpha"] = state_dict[key.replace(".lora_down.weight", ".alpha")]
+            else:
+                lora_args["alpha"] = lora_args["rank"]
+            key = key.replace("diffusion_model.", "").replace(".lora_down.weight", "")
+            dit_dict[key] = lora_args
+        return {"dit": dit_dict}
+
     def _from_civitai(self, state_dict):
         dit_dict = {}
         for key, param in state_dict.items():
```

```diff
@@ -86,6 +104,9 @@ def convert(self, state_dict):
         if "lora_unet_blocks_0_cross_attn_k.lora_down.weight" in state_dict:
             state_dict = self._from_fun(state_dict)
             logger.info("use fun format state dict")
+        elif "diffusion_model.blocks.0.cross_attn.k.lora_down.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
         elif "diffusion_model.blocks.0.cross_attn.k.lora_A.weight" in state_dict:
             state_dict = self._from_civitai(state_dict)
             logger.info("use civitai format state dict")
```
```diff
@@ -480,8 +501,8 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPipeline":
         dit_state_dict, dit2_state_dict = None, None
         if isinstance(config.model_path, list):
-            high_noise_model_ckpt = [path for path in config.model_path if "high_noise_model" in path]
-            low_noise_model_ckpt = [path for path in config.model_path if "low_noise_model" in path]
+            high_noise_model_ckpt = [path for path in config.model_path if "high_noise" in path]
+            low_noise_model_ckpt = [path for path in config.model_path if "low_noise" in path]
             if high_noise_model_ckpt and low_noise_model_ckpt:
                 logger.info(f"loading high noise model state dict from {high_noise_model_ckpt} ...")
                 dit_state_dict = cls.load_model_checkpoint(
```
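The relaxed substring match matters because the distilled checkpoints used in the i2v example below carry "high_noise"/"low_noise" in their file names without the "_model" suffix, so both naming styles must be detected (file names taken from this PR's example scripts):

```python
paths = [
    "wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors",           # lightx2v distill naming
    "high_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",  # Wan-AI repo layout
]
assert all("high_noise" in p for p in paths)                        # new check matches both
assert [p for p in paths if "high_noise_model" in p] == paths[1:]   # old check missed the first
```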
```diff
@@ -681,8 +702,9 @@ def has_any_key(*xs):
             config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
 
     def update_weights(self, state_dicts: WanStateDicts) -> None:
-        is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
-                                    ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
+        is_dual_model_state_dict = isinstance(state_dicts.model, dict) and (
+            "high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model
+        )
         is_dual_model_pipeline = self.dit2 is not None
 
         if is_dual_model_state_dict != is_dual_model_pipeline:
```
```diff
@@ -694,15 +716,21 @@ def update_weights(self, state_dicts: WanStateDicts) -> None:
         if is_dual_model_state_dict:
             if "high_noise_model" in state_dicts.model:
-                self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
+                self.update_component(
+                    self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype
+                )
             if "low_noise_model" in state_dicts.model:
-                self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
+                self.update_component(
+                    self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype
+                )
         else:
             self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
 
         self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
         self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
-        self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
+        self.update_component(
+            self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype
+        )
 
     def compile(self):
         self.dit.compile_repeated_blocks()
```
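A hedged sketch of how a dual-model weight update might be invoked; only the `WanStateDicts` field names visible in this hunk are taken from the source, while the construction and the `None` handling are assumptions:

```python
# Hypothetical: high_noise_sd and low_noise_sd are state dicts loaded elsewhere.
state_dicts = WanStateDicts(
    model={
        "high_noise_model": high_noise_sd,  # routed to self.dit
        "low_noise_model": low_noise_sd,    # routed to self.dit2
    },
    t5=None,            # assumption: unchanged components pass None
    vae=None,
    image_encoder=None,
)
pipe.update_weights(state_dicts)
```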
New file (image-to-video example):

@@ -0,0 +1,34 @@
```python
from PIL import Image

from diffsynth_engine import WanPipelineConfig
from diffsynth_engine.pipelines import WanDMDPipeline
from diffsynth_engine.utils.download import fetch_model
from diffsynth_engine.utils.video import save_video


if __name__ == "__main__":
    config = WanPipelineConfig.basic_config(
        model_path=fetch_model(
            "lightx2v/Wan2.2-Distill-Models",
            path=[
                "wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors",
                "wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors",
            ],
        ),
        parallelism=1,
    )
    pipe = WanDMDPipeline.from_pretrained(config)

    image = Image.open("input/wan_i2v_input.jpg").convert("RGB")
    video = pipe(
        prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        input_image=image,
        num_frames=81,
        width=480,
        height=832,
        seed=42,
        denoising_step_list=[1000, 750, 500, 250],
    )
    save_video(video, "wan_dmd_i2v.mp4", fps=pipe.get_default_fps())

    del pipe
```
New file (text-to-video example):

@@ -0,0 +1,48 @@
```python
from diffsynth_engine import WanPipelineConfig
from diffsynth_engine.pipelines import WanDMDPipeline
from diffsynth_engine.utils.download import fetch_model
from diffsynth_engine.utils.video import save_video


if __name__ == "__main__":
    config = WanPipelineConfig.basic_config(
        model_path=fetch_model(
            "Wan-AI/Wan2.2-T2V-A14B-BF16",
            path=[
                "high_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
                "high_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
                "high_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
                "high_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
                "high_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
                "high_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors",
                "low_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
                "low_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
                "low_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
                "low_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
                "low_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
                "low_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors",
            ],
        ),
        parallelism=1,
    )
    pipe = WanDMDPipeline.from_pretrained(config)
    pipe.load_loras_high_noise(
        [(fetch_model("lightx2v/Wan2.2-Lightning", path="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors"), 1.0)],
        fused=False,
    )
    pipe.load_loras_low_noise(
        [(fetch_model("lightx2v/Wan2.2-Lightning", path="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors"), 1.0)],
        fused=False,
    )

    video = pipe(
        # Prompt (translated from the Chinese below): documentary-photography style, a lively
        # puppy running swiftly across a lush green lawn; brownish-yellow fur, ears perked up,
        # focused and joyful expression; sunlight makes the fur look soft and shiny; an open
        # meadow dotted with wildflowers, blue sky and white clouds faintly visible in the
        # distance; strong sense of perspective capturing the motion of the running puppy and
        # the vitality of the grass; medium shot, side-tracking camera angle.
        prompt="纪实摄影风格画面，一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄，两只耳朵立起，神情专注而欢快。阳光洒在它身上，使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地，偶尔点缀着几朵野花，远处隐约可见蓝天和几片白云。透视感鲜明，捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
        num_frames=81,
        width=480,
        height=832,
        seed=42,
        denoising_step_list=[1000, 750, 500, 250],
    )
    save_video(video, "wan_dmd_t2v.mp4", fps=pipe.get_default_fps())

    del pipe
```
Contributor: The class `WanDMDPipeline` lacks a docstring. It's important to document what the pipeline does, how it differs from its parent `WanVideoPipeline`, and what "DMD" stands for (presumably Distribution Matching Distillation). A good docstring improves code readability and maintainability. For example:
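The example that followed was truncated in this capture; a minimal sketch of such a class docstring, based only on behavior visible in this diff:

```python
class WanDMDPipeline(WanVideoPipeline):
    """
    Few-step Wan video pipeline for DMD-distilled checkpoints.

    Differences from the parent WanVideoPipeline:
    - denoising runs over a short, fixed timestep list (default
      [1000, 750, 500, 250]) selected from the 1000-step schedule;
    - classifier-free guidance is disabled (cfg_scale is fixed to 1.0 and
      no negative prompt is encoded);
    - for image-to-video, conditioning frames are pinned to their initial
      latents via a mask at every denoising step.
    """
```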