Add Wan2.2-S2V: Audio-Driven Cinematic Video Generation #12258
base: main
Conversation
…date example imports. Add unit tests for WanSpeechToVideoPipeline, WanS2VTransformer3DModel, and GGUF.
The previous audio encoding logic was a placeholder. It is now replaced with a `Wav2Vec2ForCTC` model and processor, including the full implementation for processing audio inputs. This involves resampling and aligning audio features with video frames to ensure proper synchronization. Additionally, utility functions for loading audio from files or URLs are added, and the `audio_processor` module is refactored to correctly handle audio data types instead of image types.
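Roughly, the alignment amounts to something like the following sketch (the `facebook/wav2vec2-base-960h` checkpoint, the 16 kHz target rate, and the nearest-index alignment are assumptions for illustration, not the pipeline's actual implementation):

```python
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Illustrative checkpoint; the pipeline loads its audio encoder from the model repo instead.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
encoder = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

def align_audio_to_frames(waveform: torch.Tensor, sampling_rate: int, num_frames: int) -> torch.Tensor:
    # Resample to the 16 kHz rate that wav2vec2 expects.
    waveform = torchaudio.functional.resample(waveform, sampling_rate, 16_000)
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16_000, return_tensors="pt")
    with torch.no_grad():
        hidden = encoder(inputs.input_values, output_hidden_states=True).hidden_states[-1]  # (1, T, C)
    # Pick the nearest audio feature step for each video frame so audio and video stay in sync.
    idx = torch.linspace(0, hidden.shape[1] - 1, num_frames).long()
    return hidden[:, idx]  # (1, num_frames, C)
```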
Introduces support for audio and pose conditioning, replacing the previous image conditioning mechanism. The model now accepts audio embeddings and pose latents as input. This change also adds two new, mutually exclusive motion processing modules:

- `MotionerTransformers`: A transformer-based module for encoding motion.
- `FramePackMotioner`: A module that packs frames from different temporal buckets for motion representation.

Additionally, an `AudioInjector` module is implemented to fuse audio features into specific transformer blocks using cross-attention.
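A minimal sketch of the cross-attention fusion idea behind `AudioInjector` (the module name, dimensions, and structure below are assumptions for illustration, not the actual implementation):

```python
import torch
import torch.nn as nn

class AudioCrossAttention(nn.Module):
    """Hypothetical: fuse audio features into video tokens via cross-attention."""

    def __init__(self, dim: int, audio_dim: int, num_heads: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.audio_proj = nn.Linear(audio_dim, dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, hidden_states: torch.Tensor, audio_embeds: torch.Tensor) -> torch.Tensor:
        # Queries come from the video tokens; keys/values from the projected audio features.
        audio = self.audio_proj(audio_embeds)
        attn_out, _ = self.attn(self.norm(hidden_states), audio, audio)
        # Residual connection: the block degrades gracefully when audio contributes little.
        return hidden_states + attn_out

x = torch.randn(1, 1024, 128)  # video tokens
a = torch.randn(1, 80, 32)     # audio features
print(AudioCrossAttention(128, 32)(x, a).shape)  # torch.Size([1, 1024, 128])
```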
The `MotionerTransformers` module is removed and its functionality is replaced by a `FramePackMotioner` module and a simplified standard motion processing pipeline. The codebase is refactored to remove the `einops` dependency, replacing `rearrange` operations with standard PyTorch tensor manipulations for better code consistency. Additionally, `AdaLayerNorm` is introduced for improved conditioning, and helper functions for Rotary Positional Embeddings (RoPE) are added (probably temporarily) and refactored for clarity and flexibility. The audio injection mechanism is also updated to align with the new model structure.
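As a small illustration of the `einops` removal, a pattern like `rearrange(x, "b f h w c -> b (f h w) c")` becomes a plain `flatten`/`reshape` (shapes below are made up for the example):

```python
import torch

x = torch.randn(2, 16, 8, 8, 64)  # (batch, frames, height, width, channels)

# einops: tokens = rearrange(x, "b f h w c -> b (f h w) c")
tokens = x.flatten(1, 3)                  # merge frame/height/width into one token axis

# einops: x_back = rearrange(tokens, "b (f h w) c -> b f h w c", f=16, h=8, w=8)
x_back = tokens.reshape(2, 16, 8, 8, 64)  # restore the original layout

print(tokens.shape, x_back.shape)  # torch.Size([2, 1024, 64]) torch.Size([2, 16, 8, 8, 64])
```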
Removes the calculation of several unused variables and an unnecessary `deepcopy` operation on the latents tensor. This change also removes the now-unused `deepcopy` import, simplifying the overall logic.
Refactors the `WanS2VTransformer3DModel` for clarity and better handling of various conditioning inputs like audio, pose, and motion. Key changes:

- Simplifies the `WanS2VTransformerBlock` by removing projection layers and streamlining the forward pass.
- Introduces `after_transformer_block` to cleanly inject audio information after each transformer block, improving code organization.
- Enhances the main `forward` method to better process and combine multiple conditioning signals (image, audio, motion) before the transformer blocks.
- Adds support for a zero-value timestep to differentiate between image and video latents.
- Generalizes temporal embedding logic to support multiple model variations.
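A toy sketch of the post-block injection pattern (all names and the simple additive fusion are assumptions for illustration; the real model fuses audio into selected blocks via cross-attention):

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # Stand-in for a WanS2V transformer block.
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return hidden_states + self.ff(hidden_states)

class ToyBackbone(nn.Module):
    def __init__(self, dim=64, audio_dim=32, num_layers=4):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(num_layers))
        self.audio_proj = nn.Linear(audio_dim, dim)

    def after_transformer_block(self, block_idx, hidden_states, audio_embeds):
        # Audio is injected *after* each block, so the block itself stays simple
        # and the conditioning logic lives in one place.
        return hidden_states + self.audio_proj(audio_embeds)

    def forward(self, hidden_states, audio_embeds):
        for i, block in enumerate(self.blocks):
            hidden_states = block(hidden_states)
            hidden_states = self.after_transformer_block(i, hidden_states, audio_embeds)
        return hidden_states

out = ToyBackbone()(torch.randn(1, 16, 64), torch.randn(1, 16, 32))
print(out.shape)  # torch.Size([1, 16, 64])
```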
Introduces the necessary configurations and state dictionary key mappings to enable the conversion of S2V model checkpoints to the Diffusers format. This includes:

- A new transformer configuration for the S2V model architecture, including parameters for audio and pose conditioning.
- A comprehensive rename dictionary to map the original S2V layer names to their Diffusers equivalents.
…heads in transformer configuration
@tolgacangoz I truly appreciate your great work, and apologize for the late response.

(error stack trace collapsed)

I slightly modified the keys of the transformer with the following key name remap:

```python
import os
import json
from typing import Dict

from safetensors import safe_open
from safetensors.torch import save_file


def remap_safetensors_keys(transformer_dir_path: str, key_remap_dict: Dict[str, str]):
    """
    Remap keys in safetensors shards and save the updated tensors back in place.

    Args:
        transformer_dir_path (str): Path to the transformer directory
        key_remap_dict (Dict[str, str]): Dictionary for key mapping {old_key: new_key}
    """
    # Load index file
    index_json_path = os.path.join(transformer_dir_path, 'diffusion_pytorch_model.safetensors.index.json')
    with open(index_json_path, 'r') as f:
        index_data = json.load(f)

    # Get weight_map (this is where the actual key->file mapping is)
    weight_map = index_data.get("weight_map", {})

    # Build a map from files to keys that need remapping
    file_path_key_remap_dict = {}
    new_weight_map = {}
    remapped_count = 0

    # Process each key in weight_map
    for key, file_name in weight_map.items():
        if key in key_remap_dict:
            new_key = key_remap_dict[key]
            new_weight_map[new_key] = file_name
            # Add to file-specific remap dict
            if file_name not in file_path_key_remap_dict:
                file_path_key_remap_dict[file_name] = {}
            file_path_key_remap_dict[file_name][key] = new_key
        else:
            # Keep original key
            new_weight_map[key] = file_name

    # Update index file with new weight_map
    index_data["weight_map"] = new_weight_map
    with open(index_json_path, 'w') as f:
        json.dump(index_data, f, indent=2)

    # Process each safetensors file that contains keys to remap
    for file_name, remap_dict in file_path_key_remap_dict.items():
        file_path = os.path.join(transformer_dir_path, file_name)
        print(f"Loading tensors from {file_name}...")

        # Read all tensors, renaming the ones listed in remap_dict
        tensors_dict = {}
        with safe_open(file_path, framework='pt', device='cpu') as f:
            keys = list(f.keys())
            print(f"Total keys found: {len(keys)}")
            for key in keys:
                if key in remap_dict:
                    new_key = remap_dict[key]
                    print(f"Remapping: {key} -> {new_key}")
                    tensors_dict[new_key] = f.get_tensor(key)
                    remapped_count += 1
                else:
                    tensors_dict[key] = f.get_tensor(key)

        print(f"Saving remapped tensors to {file_path}...")
        save_file(tensors_dict, file_path)

    print("Done!")
    return remapped_count


# key mapping dictionary
key_remap_dict = {
    "condition_embedder.causal_audio_encoder.encoder.conv2.conv.weight": "condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.weight",
    "condition_embedder.causal_audio_encoder.encoder.conv3.conv.bias": "condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.bias",
    "condition_embedder.causal_audio_encoder.weights": "condition_embedder.causal_audio_encoder.weighted_avg.weights",
    "condition_embedder.causal_audio_encoder.encoder.conv3.conv.weight": "condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.weight",
    "condition_embedder.causal_audio_encoder.encoder.conv2.conv.bias": "condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.bias",
}

# path to transformer directory
transformer_dir_path = 'models/tolgacangoz/Wan2.2-S2V-14B-Diffusers/transformer'

# execute remapping
remapped_count = remap_safetensors_keys(transformer_dir_path, key_remap_dict)
print(f"Successfully remapped {remapped_count} keys!")
```
"""
Loading tensors from diffusion_pytorch_model-00001-of-00007.safetensors...
Total keys found: 199
Remapping: condition_embedder.causal_audio_encoder.encoder.conv2.conv.bias -> condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.bias
Remapping: condition_embedder.causal_audio_encoder.encoder.conv2.conv.weight -> condition_embedder.causal_audio_encoder.encoder.conv2.conv.conv.weight
Remapping: condition_embedder.causal_audio_encoder.encoder.conv3.conv.bias -> condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.bias
Remapping: condition_embedder.causal_audio_encoder.encoder.conv3.conv.weight -> condition_embedder.causal_audio_encoder.encoder.conv3.conv.conv.weight
Remapping: condition_embedder.causal_audio_encoder.weights -> condition_embedder.causal_audio_encoder.weighted_avg.weights
Saving remapped tensors to models/tolgacangoz/Wan2.2-S2V-14B-Diffusers/transformer/diffusion_pytorch_model-00001-of-00007.safetensors...
Done!
Successfully remapped 5 keys!
"""Once again, I sincerely appreciate your wonderful dedication, and I wish you a blessed and peaceful day😄 ps. code# !pip install bitsandbytes -qU
# ... same before ...
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers" # will be official
# download with `hf download tolgacangoz/Wan2.2-S2V-14B-Diffusers --local-dir models/tolgacangoz/Wan2.2-S2V-14B-Diffusers`
model_id = "models/tolgacangoz/Wan2.2-S2V-14B-Diffusers"
audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
text_encoder_quant_config = TransformersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)
text_encoder = UMT5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder", quantization_config=text_encoder_quant_config, torch_dtype=torch.bfloat16
)
transformer_quant_config = DiffusersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)
transformer = WanS2VTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16, quantization_config=transformer_quant_config
)
pipe = WanSpeechToVideoPipeline.from_pretrained(
model_id, vae=vae, audio_encoder=audio_encoder, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
first_frame = load_image("https://raw.githubusercontent.com/Wan-Video/Wan2.2/refs/heads/main/examples/i2v_input.JPG")
audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/talk.wav")
height, width = get_size_less_than_area(first_frame.height, first_frame.width, target_area=480*832)
prompt = "A Cat is talking."
output = pipe(
image=first_frame, audio=audio, sampling_rate=sampling_rate,
prompt=prompt, height=height, width=width, num_frames_per_chunk=80,
).frames[0]
export_to_video(output, "video.mp4", fps=16)
# ... same after ...
test.mp4
This branch is constantly changing. I put a functionally identical branch in the script attached to the first message.
```python
        pose_video = None
        if pose_video_path_or_url is not None:
            pose_video = load_video(
                pose_video_path_or_url,
                n_frames=num_frames_per_chunk * num_chunks,
                target_fps=sampling_fps,
                reverse=True,
            )
            pose_video = self.video_processor.preprocess_video(
                pose_video, height=height, width=width, resize_mode="resize_min_center_crop"
            ).to(device, dtype=torch.float32)
```
Giving pose info as `pose_video_path_or_url` doesn't seem diffusers-friendly, right? `load_video` is usually run before the pipeline is called. But in this case, we need `num_chunks` after it may have been updated at lines 881-882. Is there a better way to do this?
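For reference, the usual diffusers-style flow would look roughly like this (a sketch that assumes a hypothetical preloaded `pose_video` argument and reuses `pipe`, `first_frame`, etc. from the example above):

```python
from diffusers.utils import load_video

# Sketch only: decode the pose video up front, outside the pipeline call.
# `pose_video=` is a hypothetical argument; the current draft takes a path/URL
# and loads it internally because `num_chunks` is only known inside the pipeline.
pose_frames = load_video("pose.mp4")

output = pipe(
    image=first_frame, audio=audio, sampling_rate=sampling_rate,
    prompt=prompt, height=height, width=width, num_frames_per_chunk=80,
    pose_video=pose_frames,
).frames[0]
```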
```python
        audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
        return audio_embed_bucket, num_repeat

    # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
```
Dear @tolgacangoz, while trying to use it I ran into an error (stack trace collapsed). After some investigation, I found a workaround that resolved the issue on my end, so I wanted to share the changes I made in case they're helpful.

In `encode_audio`:

```diff
 def encode_audio(
     self,
     audio: PipelineAudioInput,
     sampling_rate: int,
     num_frames: int,
     fps: int = 16,
     device: Optional[torch.device] = None,
 ):
     device = device or self._execution_device
     video_rate = 30
     audio_sample_m = 0

     input_values = self.audio_processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values
     # retrieve logits & take argmax
-    res = self.audio_encoder(input_values.to(self.audio_encoder.device), output_hidden_states=True)
+    res = self.audio_encoder(input_values.to(device), output_hidden_states=True)
     feat = torch.cat(res.hidden_states)
```
...and in `load_pose_condition`:

```diff
 def load_pose_condition(
     self, pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std
 ):
+    device = self._execution_device
+    dtype = self.vae.dtype
     if pose_video is not None:
         padding_frame_num = num_chunks * num_frames_per_chunk - pose_video.shape[2]
-        pose_video = pose_video.to(dtype=self.vae.dtype, device=self.vae.device)
+        pose_video = pose_video.to(dtype=dtype, device=device)
         pose_video = torch.cat(
             [
                 pose_video,
                 -torch.ones(
-                    [1, 3, padding_frame_num, height, width], dtype=self.vae.dtype, device=self.vae.device
+                    [1, 3, padding_frame_num, height, width], dtype=dtype, device=device
                 ),
             ],
             dim=2,
         )
         pose_video = torch.chunk(pose_video, num_chunks, dim=2)
     else:
         pose_video = [
-            -torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=self.vae.dtype, device=self.vae.device)
+            -torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=dtype, device=device)
         ]
```
I hope this helps a little!
- Updated device references in audio encoding and pose video loading to use a unified device variable.
- Enhanced image preprocessing to include a resize mode option for better handling of input dimensions.

Co-authored-by: Ju Hoon Park <[email protected]>
Thanks @J4BEZ, fixed it.

@tolgacangoz Thanks! I am delighted to help. Have a peaceful day!
Added contributor information and enhanced model description.
Added project page link for Wan-S2V model and improved context.
The project page: https://humanaigc.github.io/wan-s2v-webpage/

This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
This PR fixes #12257.
Comparison with the original repo
When I put `with torch.amp.autocast('cuda', dtype=torch.bfloat16):` onto the transformer only and converted the initial noise's `dtype` into `torch.float32` from `torch.bfloat16` in the original repo, the videos seem almost the same. As far as I can see, the original repo's video has an extra blink.

wan.mp4
diffusers.mp4
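For clarity, the comparison setup amounts to something like this sketch (the tiny `nn.Linear` stands in for the transformer; shapes are made up):

```python
import torch
import torch.nn as nn

# Stand-in module; in the actual comparison this is the Wan2.2-S2V transformer.
model = nn.Linear(64, 64).cuda()

# Keep the initial noise in float32 (the dtype change mentioned above)...
latents = torch.randn(1, 16, 64, device="cuda", dtype=torch.float32)

# ...and run only the model forward under bfloat16 autocast.
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    out = model(latents)

print(out.dtype)  # torch.bfloat16 inside autocast; everything else stays float32
```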
Try `WanSpeechToVideoPipeline`!

@yiyixuxu @sayakpaul @asomoza @dg845 @stevhliu
@WanX-Video-1 @Steven-SWZhang @kelseyee
@SHYuanBest @J4BEZ @okaris @xziayro-ai @teith @luke14free @lopho @arnold408