diff --git a/diffsynth_engine/__init__.py b/diffsynth_engine/__init__.py
index c3b1e96..a024e74 100644
--- a/diffsynth_engine/__init__.py
+++ b/diffsynth_engine/__init__.py
@@ -25,6 +25,7 @@
     SDXLImagePipeline,
     FluxImagePipeline,
     WanVideoPipeline,
+    WanDMDPipeline,
     QwenImagePipeline,
     Hunyuan3DShapePipeline,
 )
@@ -77,6 +78,7 @@
     "FluxIPAdapter",
     "FluxRedux",
     "WanVideoPipeline",
+    "WanDMDPipeline",
     "QwenImagePipeline",
     "Hunyuan3DShapePipeline",
     "FluxInpaintingTool",
diff --git a/diffsynth_engine/pipelines/__init__.py b/diffsynth_engine/pipelines/__init__.py
index dd3d705..5a8b35e 100644
--- a/diffsynth_engine/pipelines/__init__.py
+++ b/diffsynth_engine/pipelines/__init__.py
@@ -4,6 +4,7 @@
 from .sd_image import SDImagePipeline
 from .wan_video import WanVideoPipeline
 from .wan_s2v import WanSpeech2VideoPipeline
+from .wan_dmd import WanDMDPipeline
 from .qwen_image import QwenImagePipeline
 from .hunyuan3d_shape import Hunyuan3DShapePipeline
 
@@ -15,6 +16,7 @@
     "SDImagePipeline",
     "WanVideoPipeline",
     "WanSpeech2VideoPipeline",
+    "WanDMDPipeline",
     "QwenImagePipeline",
     "Hunyuan3DShapePipeline",
 ]
diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py
index a836efc..c63abbf 100644
--- a/diffsynth_engine/pipelines/base.py
+++ b/diffsynth_engine/pipelines/base.py
@@ -145,7 +145,7 @@ def load_lora(self, path: str, scale: float, fused: bool = True, save_original_w
         self.load_loras([(path, scale)], fused, save_original_weight)
 
     def apply_scheduler_config(self, scheduler_config: Dict):
-        pass
+        self.noise_scheduler.update_config(scheduler_config)
 
     def unload_loras(self):
         raise NotImplementedError()
diff --git a/diffsynth_engine/pipelines/qwen_image.py b/diffsynth_engine/pipelines/qwen_image.py
index 5cd90a3..afae999 100644
--- a/diffsynth_engine/pipelines/qwen_image.py
+++ b/diffsynth_engine/pipelines/qwen_image.py
@@ -2,7 +2,6 @@
 import torch
 import torch.distributed as dist
 import math
-import sys
 from typing import Callable, List, Dict, Tuple, Optional, Union
 from tqdm import tqdm
 from einops import rearrange
@@ -45,7 +44,6 @@
 logger = logging.get_logger(__name__)
 
 
-
 class QwenImageLoRAConverter(LoRAStateDictConverter):
     def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
         dit_dict = {}
@@ -205,7 +203,7 @@ def _setup_nunchaku_config(
             else:
                 config.use_nunchaku_attn = False
                 logger.info("Disable nunchaku attention quantization.")
-    
+
         else:
             config.use_nunchaku = False
 
@@ -318,6 +316,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
         elif config.use_nunchaku:
             if not NUNCHAKU_AVAILABLE:
                 from diffsynth_engine.utils.flag import NUNCHAKU_IMPORT_ERROR
+
                 raise ImportError(NUNCHAKU_IMPORT_ERROR)
 
             from diffsynth_engine.models.qwen_image import QwenImageDiTNunchaku
@@ -393,9 +392,6 @@ def unload_loras(self):
         self.dit.unload_loras()
         self.noise_scheduler.restore_config()
 
-    def apply_scheduler_config(self, scheduler_config: Dict):
-        self.noise_scheduler.update_config(scheduler_config)
-
     def prepare_latents(
         self,
         latents: torch.Tensor,
diff --git a/diffsynth_engine/pipelines/wan_dmd.py b/diffsynth_engine/pipelines/wan_dmd.py
new file mode 100644
index 0000000..ecf077f
--- /dev/null
+++ b/diffsynth_engine/pipelines/wan_dmd.py
@@ -0,0 +1,111 @@
+import torch
+import torch.distributed as dist
+from typing import Callable, List, Optional
+from tqdm import tqdm
+from PIL import Image
+
+from diffsynth_engine.pipelines.wan_video import WanVideoPipeline
+
+
+class WanDMDPipeline(WanVideoPipeline):
+    def prepare_latents(
+        self,
+        latents,
+        denoising_step_list,
+    ):
+        height, width = latents.shape[-2:]
+        height, width = height * self.upsampling_factor, width * self.upsampling_factor
+        sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps=1000)
+        sigmas = sigmas[[1000 - t for t in denoising_step_list] + [-1]]
+        timesteps = timesteps[[1000 - t for t in denoising_step_list]]
+        init_latents = latents.clone()
+
+        return init_latents, latents, sigmas, timesteps
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        input_image: Image.Image | None = None,
+        seed=None,
+        height=480,
+        width=832,
+        num_frames=81,
+        denoising_step_list: List[int] = None,
+        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
+    ):
+        denoising_step_list = [1000, 750, 500, 250] if denoising_step_list is None else denoising_step_list
+        divisor = 32 if self.vae.z_dim == 48 else 16  # 32 for wan2.2 vae, 16 for wan2.1 vae
+        assert height % divisor == 0 and width % divisor == 0, f"height and width must be divisible by {divisor}"
+        assert (num_frames - 1) % 4 == 0, "num_frames must be 4X+1"
+
+        # Initialize noise
+        if dist.is_initialized() and seed is None:
+            raise ValueError("must provide a seed when parallelism is enabled")
+        noise = self.generate_noise(
+            (
+                1,
+                self.vae.z_dim,
+                (num_frames - 1) // 4 + 1,
+                height // self.upsampling_factor,
+                width // self.upsampling_factor,
+            ),
+            seed=seed,
+            device="cpu",
+            dtype=torch.float32,
+        ).to(self.device)
+        init_latents, latents, sigmas, timesteps = self.prepare_latents(noise, denoising_step_list)
+        mask = torch.ones((1, 1, *latents.shape[2:]), dtype=latents.dtype, device=latents.device)
+
+        # Encode prompts
+        self.load_models_to_device(["text_encoder"])
+        prompt_emb_posi = self.encode_prompt(prompt)
+        prompt_emb_nega = None
+
+        # Encode image
+        image_clip_feature = self.encode_clip_feature(input_image, height, width)
+        image_y = self.encode_vae_feature(input_image, num_frames, height, width)
+        image_latents = self.encode_image_latents(input_image, height, width)
+        if image_latents is not None:
+            latents[:, :, : image_latents.shape[2], :, :] = image_latents
+            init_latents = latents.clone()
+            mask[:, :, : image_latents.shape[2], :, :] = 0
+
+        # Initialize sampler
+        self.sampler.initialize(sigmas=sigmas)
+
+        # Denoise
+        hide_progress = dist.is_initialized() and dist.get_rank() != 0
+        for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
+            if timestep.item() / 1000 >= self.config.boundary:
+                self.load_models_to_device(["dit"])
+                model = self.dit
+            else:
+                self.load_models_to_device(["dit2"])
+                model = self.dit2
+
+            timestep = timestep * mask[:, :, :, ::2, ::2].flatten()  # seq_len
+            timestep = timestep.to(dtype=self.dtype, device=self.device)
+            # Classifier-free guidance
+            noise_pred = self.predict_noise_with_cfg(
+                model=model,
+                latents=latents,
+                timestep=timestep,
+                positive_prompt_emb=prompt_emb_posi,
+                negative_prompt_emb=prompt_emb_nega,
+                image_clip_feature=image_clip_feature,
+                image_y=image_y,
+                cfg_scale=1.0,
+                batch_cfg=self.config.batch_cfg,
+            )
+            # Scheduler
+            latents = self.sampler.step(latents, noise_pred, i)
+            latents = latents * mask + init_latents * (1 - mask)
+            if progress_callback is not None:
+                progress_callback(i + 1, len(timesteps), "DENOISING")
+
+        # Decode
+        self.load_models_to_device(["vae"])
+        frames = self.decode_video(latents, progress_callback=progress_callback)
+        frames = self.vae_output_to_image(frames)
+        return frames
diff --git a/diffsynth_engine/pipelines/wan_video.py b/diffsynth_engine/pipelines/wan_video.py
index 116effb..6053141 100644
--- a/diffsynth_engine/pipelines/wan_video.py
+++ b/diffsynth_engine/pipelines/wan_video.py
@@ -43,6 +43,24 @@ def _from_diffsynth(self, state_dict):
             dit_dict[key] = lora_args
         return {"dit": dit_dict}
 
+    def _from_diffusers(self, state_dict):
+        dit_dict = {}
+        for key, param in state_dict.items():
+            if ".lora_down.weight" not in key:
+                continue
+
+            lora_args = {}
+            lora_args["up"] = state_dict[key.replace(".lora_down.weight", ".lora_up.weight")]
+            lora_args["down"] = param
+            lora_args["rank"] = lora_args["up"].shape[1]
+            if key.replace(".lora_down.weight", ".alpha") in state_dict:
+                lora_args["alpha"] = state_dict[key.replace(".lora_down.weight", ".alpha")]
+            else:
+                lora_args["alpha"] = lora_args["rank"]
+            key = key.replace("diffusion_model.", "").replace(".lora_down.weight", "")
+            dit_dict[key] = lora_args
+        return {"dit": dit_dict}
+
     def _from_civitai(self, state_dict):
         dit_dict = {}
         for key, param in state_dict.items():
@@ -86,6 +104,9 @@ def convert(self, state_dict):
         if "lora_unet_blocks_0_cross_attn_k.lora_down.weight" in state_dict:
             state_dict = self._from_fun(state_dict)
             logger.info("use fun format state dict")
+        elif "diffusion_model.blocks.0.cross_attn.k.lora_down.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
         elif "diffusion_model.blocks.0.cross_attn.k.lora_A.weight" in state_dict:
             state_dict = self._from_civitai(state_dict)
             logger.info("use civitai format state dict")
@@ -480,8 +501,8 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi
 
         dit_state_dict, dit2_state_dict = None, None
         if isinstance(config.model_path, list):
-            high_noise_model_ckpt = [path for path in config.model_path if "high_noise_model" in path]
-            low_noise_model_ckpt = [path for path in config.model_path if "low_noise_model" in path]
+            high_noise_model_ckpt = [path for path in config.model_path if "high_noise" in path]
+            low_noise_model_ckpt = [path for path in config.model_path if "low_noise" in path]
             if high_noise_model_ckpt and low_noise_model_ckpt:
                 logger.info(f"loading high noise model state dict from {high_noise_model_ckpt} ...")
                 dit_state_dict = cls.load_model_checkpoint(
@@ -681,8 +702,9 @@ def has_any_key(*xs):
             config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
 
     def update_weights(self, state_dicts: WanStateDicts) -> None:
-        is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
-            ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
+        is_dual_model_state_dict = isinstance(state_dicts.model, dict) and (
+            "high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model
+        )
         is_dual_model_pipeline = self.dit2 is not None
 
         if is_dual_model_state_dict != is_dual_model_pipeline:
@@ -694,15 +716,21 @@ def update_weights(self, state_dicts: WanStateDicts) -> None:
 
         if is_dual_model_state_dict:
             if "high_noise_model" in state_dicts.model:
-                self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
+                self.update_component(
+                    self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype
+                )
             if "low_noise_model" in state_dicts.model:
-                self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
+                self.update_component(
+                    self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype
+                )
         else:
             self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
 
         self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
         self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
-        self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
+        self.update_component(
+            self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype
+        )
 
     def compile(self):
         self.dit.compile_repeated_blocks()
diff --git a/examples/wan_dmd_image_to_video.py b/examples/wan_dmd_image_to_video.py
new file mode 100644
index 0000000..bf2ede0
--- /dev/null
+++ b/examples/wan_dmd_image_to_video.py
@@ -0,0 +1,34 @@
+from PIL import Image
+
+from diffsynth_engine import WanPipelineConfig
+from diffsynth_engine.pipelines import WanDMDPipeline
+from diffsynth_engine.utils.download import fetch_model
+from diffsynth_engine.utils.video import save_video
+
+
+if __name__ == "__main__":
+    config = WanPipelineConfig.basic_config(
+        model_path=fetch_model(
+            "lightx2v/Wan2.2-Distill-Models",
+            path=[
+                "wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors",
+                "wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors",
+            ],
+        ),
+        parallelism=1,
+    )
+    pipe = WanDMDPipeline.from_pretrained(config)
+
+    image = Image.open("input/wan_i2v_input.jpg").convert("RGB")
+    video = pipe(
+        prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
+        input_image=image,
+        num_frames=81,
+        width=480,
+        height=832,
+        seed=42,
+        denoising_step_list=[1000, 750, 500, 250],
+    )
+    save_video(video, "wan_dmd_i2v.mp4", fps=pipe.get_default_fps())
+
+    del pipe
diff --git a/examples/wan_dmd_text_to_video.py b/examples/wan_dmd_text_to_video.py
new file mode 100644
index 0000000..66c0bc7
--- /dev/null
+++ b/examples/wan_dmd_text_to_video.py
@@ -0,0 +1,48 @@
+from diffsynth_engine import WanPipelineConfig
+from diffsynth_engine.pipelines import WanDMDPipeline
+from diffsynth_engine.utils.download import fetch_model
+from diffsynth_engine.utils.video import save_video
+
+
+if __name__ == "__main__":
+    config = WanPipelineConfig.basic_config(
+        model_path=fetch_model(
+            "Wan-AI/Wan2.2-T2V-A14B-BF16",
+            path=[
+                "high_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors",
+            ],
+        ),
+        parallelism=1,
+    )
+    pipe = WanDMDPipeline.from_pretrained(config)
+    pipe.load_loras_high_noise(
+        [(fetch_model("lightx2v/Wan2.2-Lightning", path="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors"), 1.0)],
+        fused=False,
+    )
+    pipe.load_loras_low_noise(
+        [(fetch_model("lightx2v/Wan2.2-Lightning", path="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors"), 1.0)],
+        fused=False,
+    )
+
+    video = pipe(
+        prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
+        num_frames=81,
+        width=480,
+        height=832,
+        seed=42,
+        denoising_step_list=[1000, 750, 500, 250],
+    )
+    save_video(video, "wan_dmd_t2v.mp4", fps=pipe.get_default_fps())
+
+    del pipe
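Note on the schedule indexing in WanDMDPipeline.prepare_latents above: the following is a minimal standalone sketch, assuming a stand-in linear schedule in place of self.noise_scheduler.schedule(num_inference_steps=1000), of how denoising_step_list = [1000, 750, 500, 250] selects four sigma/timestep pairs plus the terminal sigma for the 4-step DMD sampler.

import torch

# Stand-in schedule (assumption): 1001 sigmas from 1.0 down to 0.0 and 1000
# timesteps from 1000 down to 1; the real values come from the pipeline's
# noise_scheduler.schedule(num_inference_steps=1000).
sigmas = torch.linspace(1.0, 0.0, 1001)
timesteps = torch.linspace(1000, 1, 1000)

denoising_step_list = [1000, 750, 500, 250]

# Same indexing as prepare_latents: one entry per requested step, plus the
# terminal sigma at index -1 for the final sampler update.
step_sigmas = sigmas[[1000 - t for t in denoising_step_list] + [-1]]
step_timesteps = timesteps[[1000 - t for t in denoising_step_list]]

print(step_sigmas)      # 1.0, 0.75, 0.5, 0.25, 0.0
print(step_timesteps)   # 1000, 750, 500, 250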