2 changes: 2 additions & 0 deletions diffsynth_engine/__init__.py
@@ -25,6 +25,7 @@
SDXLImagePipeline,
FluxImagePipeline,
WanVideoPipeline,
WanDMDPipeline,
QwenImagePipeline,
Hunyuan3DShapePipeline,
)
@@ -77,6 +78,7 @@
"FluxIPAdapter",
"FluxRedux",
"WanVideoPipeline",
"WanDMDPipeline",
"QwenImagePipeline",
"Hunyuan3DShapePipeline",
"FluxInpaintingTool",
2 changes: 2 additions & 0 deletions diffsynth_engine/pipelines/__init__.py
@@ -4,6 +4,7 @@
from .sd_image import SDImagePipeline
from .wan_video import WanVideoPipeline
from .wan_s2v import WanSpeech2VideoPipeline
from .wan_dmd import WanDMDPipeline
from .qwen_image import QwenImagePipeline
from .hunyuan3d_shape import Hunyuan3DShapePipeline

@@ -15,6 +16,7 @@
"SDImagePipeline",
"WanVideoPipeline",
"WanSpeech2VideoPipeline",
"WanDMDPipeline",
"QwenImagePipeline",
"Hunyuan3DShapePipeline",
]
2 changes: 1 addition & 1 deletion diffsynth_engine/pipelines/base.py
@@ -145,7 +145,7 @@ def load_lora(self, path: str, scale: float, fused: bool = True, save_original_w
self.load_loras([(path, scale)], fused, save_original_weight)

def apply_scheduler_config(self, scheduler_config: Dict):
pass
self.noise_scheduler.update_config(scheduler_config)

def unload_loras(self):
raise NotImplementedError()
8 changes: 2 additions & 6 deletions diffsynth_engine/pipelines/qwen_image.py
@@ -2,7 +2,6 @@
import torch
import torch.distributed as dist
import math
import sys
from typing import Callable, List, Dict, Tuple, Optional, Union
from tqdm import tqdm
from einops import rearrange
@@ -45,7 +44,6 @@
logger = logging.get_logger(__name__)



class QwenImageLoRAConverter(LoRAStateDictConverter):
def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
dit_dict = {}
@@ -205,7 +203,7 @@ def _setup_nunchaku_config(
else:
config.use_nunchaku_attn = False
logger.info("Disable nunchaku attention quantization.")

else:
config.use_nunchaku = False

@@ -318,6 +316,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
elif config.use_nunchaku:
if not NUNCHAKU_AVAILABLE:
from diffsynth_engine.utils.flag import NUNCHAKU_IMPORT_ERROR

raise ImportError(NUNCHAKU_IMPORT_ERROR)

from diffsynth_engine.models.qwen_image import QwenImageDiTNunchaku
@@ -393,9 +392,6 @@ def unload_loras(self):
self.dit.unload_loras()
self.noise_scheduler.restore_config()

def apply_scheduler_config(self, scheduler_config: Dict):
self.noise_scheduler.update_config(scheduler_config)

def prepare_latents(
self,
latents: torch.Tensor,
111 changes: 111 additions & 0 deletions diffsynth_engine/pipelines/wan_dmd.py
@@ -0,0 +1,111 @@
import torch
import torch.distributed as dist
from typing import Callable, List, Optional
from tqdm import tqdm
from PIL import Image

from diffsynth_engine.pipelines.wan_video import WanVideoPipeline


class WanDMDPipeline(WanVideoPipeline):
Review comment (Contributor, medium): The class WanDMDPipeline lacks a docstring. It's important to document what the pipeline does, how it differs from its parent WanVideoPipeline, and what "DMD" stands for (presumably Distribution Matching Distillation). A good docstring improves code readability and maintainability.

For example:

class WanDMDPipeline(WanVideoPipeline):
    """
    A pipeline for Distribution Matching Distillation (DMD) video generation, inheriting from WanVideoPipeline.

    This pipeline uses a specific denoising schedule controlled by `denoising_step_list`
    and does not use classifier-free guidance (negative prompts are ignored).
    """

def prepare_latents(
self,
latents,
denoising_step_list,
):
height, width = latents.shape[-2:]
height, width = height * self.upsampling_factor, width * self.upsampling_factor
sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps=1000)
sigmas = sigmas[[1000 - t for t in denoising_step_list] + [-1]]
timesteps = timesteps[[1000 - t for t in denoising_step_list]]
init_latents = latents.clone()

return init_latents, latents, sigmas, timesteps

@torch.no_grad()
def __call__(
self,
prompt,
input_image: Image.Image | None = None,
seed=None,
height=480,
width=832,
num_frames=81,
denoising_step_list: List[int] = None,
progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status)
):
Review comment on lines +26 to +36 (Contributor, medium): The __call__ method is the main entry point for the pipeline but lacks a docstring. Please add one to explain the parameters, especially denoising_step_list, and what the method returns. It's also crucial to mention that this pipeline variant does not support negative prompts and uses a fixed cfg_scale of 1.0.

Here is a suggestion:

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        input_image: Image.Image | None = None,
        seed=None,
        height=480,
        width=832,
        num_frames=81,
        denoising_step_list: List[int] = None,
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
        """
        Generates a video based on a prompt and an optional input image using the DMD method.

        Args:
            prompt (str): The text prompt to guide video generation.
            input_image (Image.Image | None, optional): An optional input image for image-to-video generation. Defaults to None.
            seed (int, optional): Random seed for noise generation. Defaults to None.
            height (int, optional): Height of the output video. Defaults to 480.
            width (int, optional): Width of the output video. Defaults to 832.
            num_frames (int, optional): Number of frames in the output video. Must be `4*k + 1`. Defaults to 81.
            denoising_step_list (List[int], optional): A list of timesteps for the denoising process, selected from a 1000-step schedule. Defaults to `[1000, 750, 500, 250]`.
            progress_callback (Optional[Callable], optional): A callback function for progress updates. Defaults to None.

        Returns:
            List[Image.Image]: A list of PIL Images representing the generated video frames.
        
        Note:
            This pipeline does not use classifier-free guidance; `cfg_scale` is fixed to 1.0 and negative prompts are ignored.
        """

denoising_step_list = [1000, 750, 500, 250] if denoising_step_list is None else denoising_step_list
divisor = 32 if self.vae.z_dim == 48 else 16 # 32 for wan2.2 vae, 16 for wan2.1 vae
assert height % divisor == 0 and width % divisor == 0, f"height and width must be divisible by {divisor}"
assert (num_frames - 1) % 4 == 0, "num_frames must be 4X+1"

# Initialize noise
if dist.is_initialized() and seed is None:
raise ValueError("must provide a seed when parallelism is enabled")
noise = self.generate_noise(
(
1,
self.vae.z_dim,
(num_frames - 1) // 4 + 1,
height // self.upsampling_factor,
width // self.upsampling_factor,
),
seed=seed,
device="cpu",
dtype=torch.float32,
).to(self.device)
init_latents, latents, sigmas, timesteps = self.prepare_latents(noise, denoising_step_list)
mask = torch.ones((1, 1, *latents.shape[2:]), dtype=latents.dtype, device=latents.device)

# Encode prompts
self.load_models_to_device(["text_encoder"])
prompt_emb_posi = self.encode_prompt(prompt)
prompt_emb_nega = None

# Encode image
image_clip_feature = self.encode_clip_feature(input_image, height, width)
image_y = self.encode_vae_feature(input_image, num_frames, height, width)
image_latents = self.encode_image_latents(input_image, height, width)
if image_latents is not None:
latents[:, :, : image_latents.shape[2], :, :] = image_latents
init_latents = latents.clone()
mask[:, :, : image_latents.shape[2], :, :] = 0

# Initialize sampler
self.sampler.initialize(sigmas=sigmas)

# Denoise
hide_progress = dist.is_initialized() and dist.get_rank() != 0
for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
if timestep.item() / 1000 >= self.config.boundary:
self.load_models_to_device(["dit"])
model = self.dit
else:
self.load_models_to_device(["dit2"])
model = self.dit2

timestep = timestep * mask[:, :, :, ::2, ::2].flatten() # seq_len
Review comment (Contributor, medium): The comment # seq_len is not very descriptive. It doesn't explain why the timestep is being multiplied by the flattened mask. Please provide a more informative comment that clarifies the purpose of this operation, which appears to be for applying a spatially-varying denoising schedule.

Suggested change
timestep = timestep * mask[:, :, :, ::2, ::2].flatten() # seq_len
timestep = timestep * mask[:, :, :, ::2, ::2].flatten() # Apply mask for a spatially-varying denoising schedule

timestep = timestep.to(dtype=self.dtype, device=self.device)
# Classifier-free guidance
noise_pred = self.predict_noise_with_cfg(
model=model,
latents=latents,
timestep=timestep,
positive_prompt_emb=prompt_emb_posi,
negative_prompt_emb=prompt_emb_nega,
image_clip_feature=image_clip_feature,
image_y=image_y,
cfg_scale=1.0,
Review comment (Contributor, medium): The cfg_scale is hardcoded to 1.0. This disables classifier-free guidance and means negative prompts have no effect. While this might be intentional for this specific pipeline, it reduces flexibility. Consider making cfg_scale a parameter of the __call__ method with a default of 1.0, and document this behavior. This would make the pipeline more versatile and its behavior clearer to users. If it must be fixed, a comment explaining why would be beneficial.

batch_cfg=self.config.batch_cfg,
)
# Scheduler
latents = self.sampler.step(latents, noise_pred, i)
latents = latents * mask + init_latents * (1 - mask)
if progress_callback is not None:
progress_callback(i + 1, len(timesteps), "DENOISING")

# Decode
self.load_models_to_device(["vae"])
frames = self.decode_video(latents, progress_callback=progress_callback)
frames = self.vae_output_to_image(frames)
return frames
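
To make the reviewer's cfg_scale suggestion above concrete, here is a minimal sketch of how __call__ could expose cfg_scale and a negative prompt instead of hardcoding 1.0. It is not part of this PR: the negative_prompt parameter and the cfg_scale > 1.0 check are assumptions for illustration, and the rest of the method is assumed to stay exactly as written above.

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt: str = "",
        input_image: Image.Image | None = None,
        seed=None,
        height=480,
        width=832,
        num_frames=81,
        denoising_step_list: List[int] = None,
        cfg_scale: float = 1.0,  # 1.0 keeps guidance disabled, matching current behavior
        progress_callback: Optional[Callable] = None,
    ):
        # ... noise initialization, prepare_latents and image encoding unchanged ...
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt)
        # Only encode the negative prompt when guidance is actually enabled.
        prompt_emb_nega = self.encode_prompt(negative_prompt) if cfg_scale > 1.0 else None
        # ... inside the denoising loop, forward the parameter instead of the constant:
        noise_pred = self.predict_noise_with_cfg(
            model=model,
            latents=latents,
            timestep=timestep,
            positive_prompt_emb=prompt_emb_posi,
            negative_prompt_emb=prompt_emb_nega,
            image_clip_feature=image_clip_feature,
            image_y=image_y,
            cfg_scale=cfg_scale,
            batch_cfg=self.config.batch_cfg,
        )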
42 changes: 35 additions & 7 deletions diffsynth_engine/pipelines/wan_video.py
@@ -43,6 +43,24 @@ def _from_diffsynth(self, state_dict):
dit_dict[key] = lora_args
return {"dit": dit_dict}

def _from_diffusers(self, state_dict):
dit_dict = {}
for key, param in state_dict.items():
if ".lora_down.weight" not in key:
continue

lora_args = {}
lora_args["up"] = state_dict[key.replace(".lora_down.weight", ".lora_up.weight")]
lora_args["down"] = param
lora_args["rank"] = lora_args["up"].shape[1]
if key.replace(".lora_down.weight", ".alpha") in state_dict:
lora_args["alpha"] = state_dict[key.replace(".lora_down.weight", ".alpha")]
else:
lora_args["alpha"] = lora_args["rank"]
key = key.replace("diffusion_model.", "").replace(".lora_down.weight", "")
dit_dict[key] = lora_args
return {"dit": dit_dict}
Review comment on lines +46 to +62 (Contributor, medium): The new method _from_diffusers contains logic for parsing LoRA arguments (lines 52-59) that is identical to the logic in the existing _from_fun method (lines 88-95 in the full file). This code duplication can make maintenance harder. Consider refactoring this common logic into a private helper method to improve code reuse and readability.
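
One possible shape for that refactor, sketched here as an illustration rather than a change in this PR. The helper name _parse_lora_args is hypothetical, and the sketch assumes _from_fun uses the same .lora_down.weight / .lora_up.weight / .alpha key naming, as the comment above states.

    def _parse_lora_args(self, state_dict, key, param):
        # Shared extraction of the up/down weights, rank and alpha for one LoRA entry.
        lora_args = {}
        lora_args["up"] = state_dict[key.replace(".lora_down.weight", ".lora_up.weight")]
        lora_args["down"] = param
        lora_args["rank"] = lora_args["up"].shape[1]
        alpha_key = key.replace(".lora_down.weight", ".alpha")
        lora_args["alpha"] = state_dict[alpha_key] if alpha_key in state_dict else lora_args["rank"]
        return lora_args

    def _from_diffusers(self, state_dict):
        dit_dict = {}
        for key, param in state_dict.items():
            if ".lora_down.weight" not in key:
                continue
            lora_args = self._parse_lora_args(state_dict, key, param)
            key = key.replace("diffusion_model.", "").replace(".lora_down.weight", "")
            dit_dict[key] = lora_args
        return {"dit": dit_dict}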


def _from_civitai(self, state_dict):
dit_dict = {}
for key, param in state_dict.items():
@@ -86,6 +104,9 @@ def convert(self, state_dict):
if "lora_unet_blocks_0_cross_attn_k.lora_down.weight" in state_dict:
state_dict = self._from_fun(state_dict)
logger.info("use fun format state dict")
elif "diffusion_model.blocks.0.cross_attn.k.lora_down.weight" in state_dict:
state_dict = self._from_diffusers(state_dict)
logger.info("use diffusers format state dict")
elif "diffusion_model.blocks.0.cross_attn.k.lora_A.weight" in state_dict:
state_dict = self._from_civitai(state_dict)
logger.info("use civitai format state dict")
@@ -480,8 +501,8 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi

dit_state_dict, dit2_state_dict = None, None
if isinstance(config.model_path, list):
high_noise_model_ckpt = [path for path in config.model_path if "high_noise_model" in path]
low_noise_model_ckpt = [path for path in config.model_path if "low_noise_model" in path]
high_noise_model_ckpt = [path for path in config.model_path if "high_noise" in path]
low_noise_model_ckpt = [path for path in config.model_path if "low_noise" in path]
if high_noise_model_ckpt and low_noise_model_ckpt:
logger.info(f"loading high noise model state dict from {high_noise_model_ckpt} ...")
dit_state_dict = cls.load_model_checkpoint(
@@ -681,8 +702,9 @@ def has_any_key(*xs):
config.attn_params = VideoSparseAttentionParams(sparsity=0.9)

def update_weights(self, state_dicts: WanStateDicts) -> None:
is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
is_dual_model_state_dict = isinstance(state_dicts.model, dict) and (
"high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model
)
is_dual_model_pipeline = self.dit2 is not None

if is_dual_model_state_dict != is_dual_model_pipeline:
@@ -694,15 +716,21 @@ def update_weights(self, state_dicts: WanStateDicts) -> None:

if is_dual_model_state_dict:
if "high_noise_model" in state_dicts.model:
self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
self.update_component(
self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype
)
if "low_noise_model" in state_dicts.model:
self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
self.update_component(
self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype
)
else:
self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)

self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
self.update_component(
self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype
)

def compile(self):
self.dit.compile_repeated_blocks()
34 changes: 34 additions & 0 deletions examples/wan_dmd_image_to_video.py
@@ -0,0 +1,34 @@
from PIL import Image

from diffsynth_engine import WanPipelineConfig
from diffsynth_engine.pipelines import WanDMDPipeline
from diffsynth_engine.utils.download import fetch_model
from diffsynth_engine.utils.video import save_video


if __name__ == "__main__":
config = WanPipelineConfig.basic_config(
model_path=fetch_model(
"lightx2v/Wan2.2-Distill-Models",
path=[
"wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors",
"wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors",
],
),
parallelism=1,
)
pipe = WanDMDPipeline.from_pretrained(config)

image = Image.open("input/wan_i2v_input.jpg").convert("RGB")
video = pipe(
prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
input_image=image,
num_frames=81,
width=480,
height=832,
seed=42,
denoising_step_list=[1000, 750, 500, 250],
)
save_video(video, "wan_dmd_i2v.mp4", fps=pipe.get_default_fps())

del pipe
48 changes: 48 additions & 0 deletions examples/wan_dmd_text_to_video.py
@@ -0,0 +1,48 @@
from diffsynth_engine import WanPipelineConfig
from diffsynth_engine.pipelines import WanDMDPipeline
from diffsynth_engine.utils.download import fetch_model
from diffsynth_engine.utils.video import save_video


if __name__ == "__main__":
config = WanPipelineConfig.basic_config(
model_path=fetch_model(
"Wan-AI/Wan2.2-T2V-A14B-BF16",
path=[
"high_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
"high_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
"high_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
"high_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
"high_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
"high_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors",
"low_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
"low_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
"low_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
"low_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
"low_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
"low_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors",
],
),
parallelism=1,
)
pipe = WanDMDPipeline.from_pretrained(config)
pipe.load_loras_high_noise(
[(fetch_model("lightx2v/Wan2.2-Lightning", path="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors"), 1.0)],
fused=False,
)
pipe.load_loras_low_noise(
[(fetch_model("lightx2v/Wan2.2-Lightning", path="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors"), 1.0)],
fused=False,
)

video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
num_frames=81,
width=480,
height=832,
seed=42,
denoising_step_list=[1000, 750, 500, 250],
)
save_video(video, "wan_dmd_t2v.mp4", fps=pipe.get_default_fps())

del pipe