diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ffd7385d81a5..a8eb2d28ea1c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -372,6 +372,8 @@ title: Lumina2Transformer2DModel - local: api/models/lumina_nextdit2d title: LuminaNextDiT2DModel + - local: api/models/magi_transformer_3d + title: Magi1Transformer3DModel - local: api/models/mochi_transformer3d title: MochiTransformer3DModel - local: api/models/omnigen_transformer @@ -430,6 +432,8 @@ title: AutoencoderKLHunyuanVideo - local: api/models/autoencoderkl_ltx_video title: AutoencoderKLLTXVideo + - local: api/models/autoencoder_kl_magi1 + title: AutoencoderKLMagi1 - local: api/models/autoencoderkl_magvit title: AutoencoderKLMagvit - local: api/models/autoencoderkl_mochi diff --git a/docs/source/en/api/models/autoencoder_kl_magi1.md b/docs/source/en/api/models/autoencoder_kl_magi1.md new file mode 100644 index 000000000000..2301d08b72da --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_magi1.md @@ -0,0 +1,34 @@ + + +# AutoencoderKLMagi1 + +The 3D variational autoencoder (VAE) model with KL loss used in [MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai. + +MAGI-1 uses a transformer-based VAE with 8x spatial and 4x temporal compression, providing fast average decoding time and highly competitive reconstruction quality. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLMagi1 + +vae = AutoencoderKLMagi1.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32) +``` + +## AutoencoderKLMagi1 + +[[autodoc]] AutoencoderKLMagi1 + - decode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/magi1_transformer_3d.md b/docs/source/en/api/models/magi1_transformer_3d.md new file mode 100644 index 000000000000..8fb369f16253 --- /dev/null +++ b/docs/source/en/api/models/magi1_transformer_3d.md @@ -0,0 +1,32 @@ + + +# Magi1Transformer3DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai. + +MAGI-1 is an autoregressive denoising video generation model that generates videos chunk-by-chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and the generation of the next chunk begins as soon as the current one reaches a certain level of denoising. + +The model can be loaded with the following code snippet. + +```python +from diffusers import Magi1Transformer3DModel + +transformer = Magi1Transformer3DModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## Magi1Transformer3DModel + +[[autodoc]] Magi1Transformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput \ No newline at end of file diff --git a/docs/source/en/api/pipelines/magi1.md b/docs/source/en/api/pipelines/magi1.md new file mode 100644 index 000000000000..b64c6f57c9b1 --- /dev/null +++ b/docs/source/en/api/pipelines/magi1.md @@ -0,0 +1,261 @@ + + +
+<div class="flex flex-wrap space-x-1">
+  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
+</div>
+ +# MAGI-1 + +[MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai. + +*MAGI-1 is an autoregressive video generation model that generates videos chunk-by-chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and the generation of the next chunk begins as soon as the current one reaches a certain level of denoising. This pipeline design enables concurrent processing of up to four chunks for efficient video generation. The model leverages a specialized architecture with a transformer-based VAE with 8x spatial and 4x temporal compression, and a diffusion transformer with several key innovations including Block-Causal Attention, Parallel Attention Block, QK-Norm and GQA, Sandwich Normalization in FFN, SwiGLU, and Softcap Modulation.* + +The original repo: https://github.com/SandAI-org/MAGI-1 + +This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz). + +You can find the MAGI-1 checkpoints under the [sand-ai](https://huggingface.co/sand-ai) organization. + +The following MAGI-1 models are supported in Diffusers: + +**Base Models:** +- [MAGI-1 24B](https://huggingface.co/sand-ai/MAGI-1) +- [MAGI-1 4.5B](https://huggingface.co/sand-ai/MAGI-1-4.5B) + +**Distilled Models (faster inference):** +- [MAGI-1 24B Distill](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_distill) +- [MAGI-1 24B Distill+Quant (FP8)](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_distill_quant) +- [MAGI-1 4.5B Distill](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/4.5B_distill) +- [MAGI-1 4.5B Distill+Quant (FP8)](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/4.5B_distill_quant) + +> [!TIP] +> Click on the MAGI-1 models in the right sidebar for more examples of video generation. + +### Text-to-Video Generation + +The example below demonstrates how to generate a video from text optimized for memory or inference speed. + + + + +Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. + +The MAGI-1 text-to-video model below requires ~13GB of VRAM. + +```py +import torch +import numpy as np +from diffusers import AutoModel, Magi1Pipeline +from diffusers.hooks.group_offloading import apply_group_offloading +from diffusers.utils import export_to_video +from transformers import T5EncoderModel + +text_encoder = T5EncoderModel.from_pretrained("sand-ai/MAGI-1", subfolder="text_encoder", torch_dtype=torch.bfloat16) +vae = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32) +transformer = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16) + +# group-offloading +onload_device = torch.device("cuda") +offload_device = torch.device("cpu") +apply_group_offloading(text_encoder, + onload_device=onload_device, + offload_device=offload_device, + offload_type="block_level", + num_blocks_per_group=4 +) +transformer.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type="leaf_level", + use_stream=True +) + +pipeline = Magi1Pipeline.from_pretrained( + "sand-ai/MAGI-1", + vae=vae, + transformer=transformer, + text_encoder=text_encoder, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +prompt = """ +A majestic eagle soaring over a mountain landscape. The eagle's wings are spread wide, +catching the golden sunlight as it glides through the clear blue sky. 
Below, snow-capped +mountains stretch to the horizon, with pine forests and a winding river visible in the valley. +""" +negative_prompt = """ +Poor quality, blurry, pixelated, low resolution, distorted proportions, unnatural colors, +watermark, text overlay, incomplete rendering, glitches, artifacts, unrealistic lighting +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=24, + guidance_scale=7.0, +).frames[0] +export_to_video(output, "output.mp4", fps=8) +``` + + + + +[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. + +```py +import torch +import numpy as np +from diffusers import AutoModel, Magi1Pipeline +from diffusers.utils import export_to_video +from transformers import T5EncoderModel + +text_encoder = T5EncoderModel.from_pretrained("sand-ai/MAGI-1", subfolder="text_encoder", torch_dtype=torch.bfloat16) +vae = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32) +transformer = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16) + +pipeline = Magi1Pipeline.from_pretrained( + "sand-ai/MAGI-1", + vae=vae, + transformer=transformer, + text_encoder=text_encoder, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +# torch.compile +pipeline.transformer.to(memory_format=torch.channels_last) +pipeline.transformer = torch.compile( + pipeline.transformer, mode="max-autotune", fullgraph=True +) + +prompt = """ +A majestic eagle soaring over a mountain landscape. The eagle's wings are spread wide, +catching the golden sunlight as it glides through the clear blue sky. Below, snow-capped +mountains stretch to the horizon, with pine forests and a winding river visible in the valley. +""" +negative_prompt = """ +Poor quality, blurry, pixelated, low resolution, distorted proportions, unnatural colors, +watermark, text overlay, incomplete rendering, glitches, artifacts, unrealistic lighting +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=24, + guidance_scale=7.0, +).frames[0] +export_to_video(output, "output.mp4", fps=8) +``` + + + + +### Image-to-Video Generation + +The example below demonstrates how to use the image-to-video pipeline to generate a video animation from a single image using text prompts for guidance. + + + + +```python +import torch +from diffusers import Magi1ImageToVideoPipeline, AutoencoderKLMagi1 +from diffusers.utils import export_to_video, load_image + +model_id = "sand-ai/MAGI-1-I2V" +vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = Magi1ImageToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Load input image +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg") + +prompt = ( + "An astronaut walking on the moon's surface, with the Earth visible in the background. " + "The astronaut moves slowly in a low-gravity environment, kicking up lunar dust with each step." 
+) +negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality" + +output = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=480, + width=832, + num_frames=81, # Generate 81 frames (~5 seconds at 16fps) + guidance_scale=5.0, + num_inference_steps=50, +).frames[0] +export_to_video(output, "astronaut_animation.mp4", fps=16) +``` + + + + +### Video-to-Video Generation + +The example below demonstrates how to use the video-to-video pipeline to extend or continue an existing video using text prompts. + + + + +```python +import torch +from diffusers import Magi1VideoToVideoPipeline, AutoencoderKLMagi1 +from diffusers.utils import export_to_video, load_video + +model_id = "sand-ai/MAGI-1-V2V" +vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = Magi1VideoToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Load prefix video (e.g., first 24 frames of a video) +video = load_video("path/to/input_video.mp4", num_frames=24) + +prompt = ( + "Continue this video with smooth camera motion and consistent style. " + "The scene evolves naturally with coherent motion and lighting." +) +negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality, jumpy motion" + +output = pipe( + video=video, + prompt=prompt, + negative_prompt=negative_prompt, + height=480, + width=832, + num_frames=81, # Total frames including prefix (24 prefix + 57 generated) + guidance_scale=5.0, + num_inference_steps=50, +).frames[0] +export_to_video(output, "video_continuation.mp4", fps=16) +``` + + + + +## Notes + +- MAGI-1 uses autoregressive chunked generation with `chunk_width=6` and `window_size=4`, enabling efficient long video generation. +- The model supports special tokens for quality control (HQ_TOKEN), style (THREE_D_MODEL_TOKEN, TWO_D_ANIME_TOKEN), and motion guidance (STATIC_FIRST_FRAMES_TOKEN, DYNAMIC_FIRST_FRAMES_TOKEN). +- For I2V, the input image is encoded as a clean prefix chunk to condition the video generation. +- For V2V, input video frames (typically 24 frames or ~1.5 seconds) are encoded as clean prefix chunks, and the model generates a continuation. +- MAGI-1 supports LoRAs with [`~loaders.Magi1LoraLoaderMixin.load_lora_weights`]. +- Distillation mode can be enabled for faster inference with `enable_distillation=True` (requires distilled model checkpoint). 
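
A minimal sketch of the LoRA support mentioned above, using [`~loaders.Magi1LoraLoaderMixin.load_lora_weights`]. The adapter path and `adapter_name` below are placeholders rather than a published checkpoint.

```py
import torch
from diffusers import Magi1Pipeline
from diffusers.utils import export_to_video

pipeline = Magi1Pipeline.from_pretrained("sand-ai/MAGI-1", torch_dtype=torch.bfloat16)
pipeline.to("cuda")

# Placeholder adapter repo/path and adapter name; substitute a real MAGI-1 LoRA.
pipeline.load_lora_weights("path/to/magi1-lora", adapter_name="style")
pipeline.set_adapters("style", 0.8)

output = pipeline(
    prompt="A watercolor fox running through falling snow",
    num_frames=24,
    guidance_scale=7.0,
).frames[0]
export_to_video(output, "lora_output.mp4", fps=8)
```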
\ No newline at end of file diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md index 8be2c0603009..82547fedceec 100644 --- a/docs/source/en/optimization/attention_backends.md +++ b/docs/source/en/optimization/attention_backends.md @@ -149,5 +149,6 @@ Refer to the table below for a complete list of available attention backends and | `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) | | `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) | | `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention | +| `magi` | [MagiAttention](https://github.com/SandAI-org/MagiAttention) | A CP-based Attention Towards Linear Scalability, Heterogeneous Mask Training | - \ No newline at end of file + diff --git a/scripts/convert_magi1_to_diffusers.py b/scripts/convert_magi1_to_diffusers.py new file mode 100644 index 000000000000..57c5d88a0f42 --- /dev/null +++ b/scripts/convert_magi1_to_diffusers.py @@ -0,0 +1,548 @@ +import argparse +import json +import os +import shutil +import tempfile + +import torch +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file + +from diffusers import Magi1Pipeline, Magi1Transformer3DModel +from diffusers.models.autoencoders import AutoencoderKLMagi1 + + +def convert_magi1_transformer(model_type): + """ + Convert MAGI-1 transformer for a specific model type. + + Args: + model_type: The model type (e.g., "MAGI-1-T2V-4.5B-distill", "MAGI-1-T2V-24B-distill", etc.) + + Returns: + The converted transformer model. + """ + + model_type_mapping = { + "MAGI-1-T2V-4.5B-distill": "4.5B_distill", + "MAGI-1-T2V-24B-distill": "24B_distill", + "MAGI-1-T2V-4.5B": "4.5B", + "MAGI-1-T2V-24B": "24B", + "4.5B_distill": "4.5B_distill", + "24B_distill": "24B_distill", + "4.5B": "4.5B", + "24B": "24B", + } + + repo_path = model_type_mapping.get(model_type, model_type) + + temp_dir = tempfile.mkdtemp() + transformer_ckpt_dir = os.path.join(temp_dir, "transformer_checkpoint") + os.makedirs(transformer_ckpt_dir, exist_ok=True) + + checkpoint_files = [] + shard_index = 1 + while True: + try: + if shard_index == 1: + shard_filename = f"model-{shard_index:05d}-of-00002.safetensors" + shard_path = hf_hub_download( + "sand-ai/MAGI-1", f"ckpt/magi/{repo_path}/inference_weight.distill/{shard_filename}" + ) + checkpoint_files.append(shard_path) + print(f"Downloaded {shard_filename}") + shard_index += 1 + elif shard_index == 2: + shard_filename = f"model-{shard_index:05d}-of-00002.safetensors" + shard_path = hf_hub_download( + "sand-ai/MAGI-1", f"ckpt/magi/{repo_path}/inference_weight.distill/{shard_filename}" + ) + checkpoint_files.append(shard_path) + print(f"Downloaded {shard_filename}") + break + else: + break + except Exception as e: + print(f"No more shards found or error downloading shard {shard_index}: {e}") + break + + if not checkpoint_files: + raise ValueError(f"No checkpoint files found for model type: {model_type}") + + for i, shard_path in enumerate(checkpoint_files): + dest_path = os.path.join(transformer_ckpt_dir, f"model-{i + 1:05d}-of-{len(checkpoint_files):05d}.safetensors") + shutil.copy2(shard_path, dest_path) + + transformer = convert_magi1_transformer_checkpoint(transformer_ckpt_dir) + + return transformer + + +def convert_magi1_vae(): + vae_ckpt_path = hf_hub_download("sand-ai/MAGI-1", 
"ckpt/vae/diffusion_pytorch_model.safetensors") + checkpoint = load_file(vae_ckpt_path) + + config = { + "patch_size": (4, 8, 8), + "num_attention_heads": 16, + "attention_head_dim": 64, + "z_dim": 16, + "height": 256, + "width": 256, + "num_frames": 16, + "ffn_dim": 4 * 1024, + "num_layers": 24, + "eps": 1e-6, + } + + vae = AutoencoderKLMagi1( + patch_size=config["patch_size"], + num_attention_heads=config["num_attention_heads"], + attention_head_dim=config["attention_head_dim"], + z_dim=config["z_dim"], + height=config["height"], + width=config["width"], + num_frames=config["num_frames"], + ffn_dim=config["ffn_dim"], + num_layers=config["num_layers"], + eps=config["eps"], + ) + + converted_state_dict = convert_vae_state_dict(checkpoint) + + vae.load_state_dict(converted_state_dict, strict=True) + + return vae + + +def convert_vae_state_dict(checkpoint): + """ + Convert MAGI-1 VAE state dict to diffusers format. + + Maps the keys from the MAGI-1 VAE state dict to the diffusers VAE state dict. + """ + state_dict = {} + + state_dict["encoder.patch_embedding.weight"] = checkpoint["encoder.patch_embed.proj.weight"] + state_dict["encoder.patch_embedding.bias"] = checkpoint["encoder.patch_embed.proj.bias"] + + state_dict["encoder.pos_embed"] = checkpoint["encoder.pos_embed"] + + state_dict["encoder.cls_token"] = checkpoint["encoder.cls_token"] + + for i in range(24): + qkv_weight = checkpoint[f"encoder.blocks.{i}.attn.qkv.weight"] + qkv_bias = checkpoint[f"encoder.blocks.{i}.attn.qkv.bias"] + + q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) + q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0) + + state_dict[f"encoder.blocks.{i}.attn.to_q.weight"] = q_weight + state_dict[f"encoder.blocks.{i}.attn.to_q.bias"] = q_bias + state_dict[f"encoder.blocks.{i}.attn.to_k.weight"] = k_weight + state_dict[f"encoder.blocks.{i}.attn.to_k.bias"] = k_bias + state_dict[f"encoder.blocks.{i}.attn.to_v.weight"] = v_weight + state_dict[f"encoder.blocks.{i}.attn.to_v.bias"] = v_bias + + state_dict[f"encoder.blocks.{i}.attn.to_out.0.weight"] = checkpoint[f"encoder.blocks.{i}.attn.proj.weight"] + state_dict[f"encoder.blocks.{i}.attn.to_out.0.bias"] = checkpoint[f"encoder.blocks.{i}.attn.proj.bias"] + + state_dict[f"encoder.blocks.{i}.norm.weight"] = checkpoint[f"encoder.blocks.{i}.norm2.weight"] + state_dict[f"encoder.blocks.{i}.norm.bias"] = checkpoint[f"encoder.blocks.{i}.norm2.bias"] + + state_dict[f"encoder.blocks.{i}.proj_out.net.0.proj.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.weight"] + state_dict[f"encoder.blocks.{i}.proj_out.net.0.proj.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.bias"] + state_dict[f"encoder.blocks.{i}.proj_out.net.2.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.weight"] + + state_dict[f"encoder.blocks.{i}.proj_out.net.2.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.bias"] + + state_dict["encoder.norm_out.weight"] = checkpoint["encoder.norm.weight"] + state_dict["encoder.norm_out.bias"] = checkpoint["encoder.norm.bias"] + + state_dict["encoder.linear_out.weight"] = checkpoint["encoder.last_layer.weight"] + state_dict["encoder.linear_out.bias"] = checkpoint["encoder.last_layer.bias"] + + state_dict["decoder.proj_in.weight"] = checkpoint["decoder.proj_in.weight"] + state_dict["decoder.proj_in.bias"] = checkpoint["decoder.proj_in.bias"] + + state_dict["decoder.pos_embed"] = checkpoint["decoder.pos_embed"] + + state_dict["decoder.cls_token"] = checkpoint["decoder.cls_token"] + + for i in range(24): + qkv_weight = checkpoint[f"decoder.blocks.{i}.attn.qkv.weight"] + 
qkv_bias = checkpoint[f"decoder.blocks.{i}.attn.qkv.bias"] + + q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) + q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0) + + state_dict[f"decoder.blocks.{i}.attn.to_q.weight"] = q_weight + state_dict[f"decoder.blocks.{i}.attn.to_q.bias"] = q_bias + state_dict[f"decoder.blocks.{i}.attn.to_k.weight"] = k_weight + state_dict[f"decoder.blocks.{i}.attn.to_k.bias"] = k_bias + state_dict[f"decoder.blocks.{i}.attn.to_v.weight"] = v_weight + state_dict[f"decoder.blocks.{i}.attn.to_v.bias"] = v_bias + + state_dict[f"decoder.blocks.{i}.attn.to_out.0.weight"] = checkpoint[f"decoder.blocks.{i}.attn.proj.weight"] + state_dict[f"decoder.blocks.{i}.attn.to_out.0.bias"] = checkpoint[f"decoder.blocks.{i}.attn.proj.bias"] + + state_dict[f"decoder.blocks.{i}.norm.weight"] = checkpoint[f"decoder.blocks.{i}.norm2.weight"] + state_dict[f"decoder.blocks.{i}.norm.bias"] = checkpoint[f"decoder.blocks.{i}.norm2.bias"] + + state_dict[f"decoder.blocks.{i}.proj_out.net.0.proj.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.weight"] + state_dict[f"decoder.blocks.{i}.proj_out.net.0.proj.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.bias"] + state_dict[f"decoder.blocks.{i}.proj_out.net.2.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.weight"] + state_dict[f"decoder.blocks.{i}.proj_out.net.2.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.bias"] + + state_dict["decoder.norm_out.weight"] = checkpoint["decoder.norm.weight"] + state_dict["decoder.norm_out.bias"] = checkpoint["decoder.norm.bias"] + + state_dict["decoder.conv_out.weight"] = checkpoint["decoder.last_layer.weight"] + state_dict["decoder.conv_out.bias"] = checkpoint["decoder.last_layer.bias"] + + return state_dict + + +def load_magi1_transformer_checkpoint(checkpoint_path): + """ + Load a MAGI-1 transformer checkpoint. + + Args: + checkpoint_path: Path to the MAGI-1 transformer checkpoint. + + Returns: + The loaded checkpoint state dict. 
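    Note:
        Accepts a single ``.safetensors`` file, a directory containing ``.safetensors`` shards (or ``.pt``/``.pth``
        files), or a raw ``torch.save`` checkpoint whose weights may be nested under ``"model"`` or ``"state_dict"``.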
+ """ + if checkpoint_path.endswith(".safetensors"): + state_dict = load_file(checkpoint_path) + elif os.path.isdir(checkpoint_path): + safetensors_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".safetensors")] + if safetensors_files: + state_dict = {} + for safetensors_file in sorted(safetensors_files): + file_path = os.path.join(checkpoint_path, safetensors_file) + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + else: + checkpoint_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".pt") or f.endswith(".pth")] + if not checkpoint_files: + raise ValueError(f"No checkpoint files found in {checkpoint_path}") + + checkpoint_file = os.path.join(checkpoint_path, checkpoint_files[0]) + checkpoint_data = torch.load(checkpoint_file, map_location="cpu") + + if isinstance(checkpoint_data, dict): + if "model" in checkpoint_data: + state_dict = checkpoint_data["model"] + elif "state_dict" in checkpoint_data: + state_dict = checkpoint_data["state_dict"] + else: + state_dict = checkpoint_data + else: + state_dict = checkpoint_data + else: + checkpoint_data = torch.load(checkpoint_path, map_location="cpu") + + if isinstance(checkpoint_data, dict): + if "model" in checkpoint_data: + state_dict = checkpoint_data["model"] + elif "state_dict" in checkpoint_data: + state_dict = checkpoint_data["state_dict"] + else: + state_dict = checkpoint_data + else: + state_dict = checkpoint_data + + return state_dict + + +def convert_magi1_transformer_checkpoint(checkpoint_path, transformer_config_file=None, dtype=None, allow_partial=False): + """ + Convert a MAGI-1 transformer checkpoint to a diffusers Magi1Transformer3DModel. + + Args: + checkpoint_path: Path to the MAGI-1 transformer checkpoint. + transformer_config_file: Optional path to a transformer config file. + dtype: Optional dtype for the model. + + Returns: + A diffusers Magi1Transformer3DModel model. 
+ """ + if transformer_config_file is not None: + with open(transformer_config_file, "r") as f: + config = json.load(f) + else: + config = { + "in_channels": 16, + "out_channels": 16, + "num_layers": 34, + "num_attention_heads": 24, + "num_kv_heads": 8, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "freq_dim": 256, + "ffn_dim": 12288, + "patch_size": (1, 2, 2), + "eps": 1e-6, + } + + transformer = Magi1Transformer3DModel( + in_channels=config["in_channels"], + out_channels=config["out_channels"], + num_layers=config["num_layers"], + num_attention_heads=config["num_attention_heads"], + num_kv_heads=config["num_kv_heads"], + attention_head_dim=config["attention_head_dim"], + cross_attention_dim=config["cross_attention_dim"], + freq_dim=config["freq_dim"], + ffn_dim=config["ffn_dim"], + patch_size=config["patch_size"], + eps=config["eps"], + ) + + checkpoint = load_magi1_transformer_checkpoint(checkpoint_path) + + converted_state_dict, report = convert_transformer_state_dict(checkpoint, transformer, allow_partial=allow_partial) + + # Verify mapping coverage & shapes + print("\n=== MAGI-1 -> Diffusers mapping report ===") + print(f"Source keys used: {report['used_src_keys']} / {report['total_src_keys']}") + if report["missing_src_keys"]: + print(f"Missing source keys referenced: {len(report['missing_src_keys'])}") + print("Examples:", report["missing_src_keys"][:20]) + + # Target verifications + expected = transformer.state_dict() + expected_keys = set(expected.keys()) + got_keys = set(converted_state_dict.keys()) + missing_target = sorted(list(expected_keys - got_keys)) + unexpected_target = sorted(list(got_keys - expected_keys)) + + shape_mismatches = [] + for k in sorted(list(expected_keys & got_keys)): + if tuple(expected[k].shape) != tuple(converted_state_dict[k].shape): + shape_mismatches.append((k, tuple(converted_state_dict[k].shape), tuple(expected[k].shape))) + + if missing_target: + print(f"Missing target keys: {len(missing_target)}") + print("Examples:", missing_target[:20]) + if unexpected_target: + print(f"Unexpected converted keys: {len(unexpected_target)}") + print("Examples:", unexpected_target[:20]) + if shape_mismatches: + print(f"Shape mismatches: {len(shape_mismatches)}") + print("Examples:", shape_mismatches[:5]) + + if (report["missing_src_keys"] or missing_target or shape_mismatches): + raise ValueError("Conversion verification failed. See report above.") + + # Enforce strict=True per requirement + transformer.load_state_dict(converted_state_dict, strict=True) + + if dtype is not None: + transformer = transformer.to(dtype=dtype) + + return transformer + + +def convert_transformer_state_dict(checkpoint, transformer=None, allow_partial=False): + """ + Convert MAGI-1 transformer state dict to diffusers format. + + Maps the original MAGI-1 parameter names to diffusers' standard transformer naming. + Handles all shape mismatches and key mappings based on actual checkpoint analysis. 
+ """ + print("Converting MAGI-1 checkpoint keys...") + + converted_state_dict = {} + used_src_keys = set() + missing_src_keys = [] + + def require(key: str) -> torch.Tensor: + if key not in checkpoint: + missing_src_keys.append(key) + if allow_partial: + return None # will be skipped by caller + raise KeyError(f"Missing source key: {key}") + used_src_keys.add(key) + return checkpoint[key] + + def assign(src: str, dst: str): + val = require(src) + if val is not None: + converted_state_dict[dst] = val + + def split_assign(src: str, dst_k: str, dst_v: str): + kv = require(src) + if kv is not None: + k, v = kv.chunk(2, dim=0) + converted_state_dict[dst_k] = k + converted_state_dict[dst_v] = v + + # Simple top-level mappings + simple_maps = [ + ("x_embedder.weight", "patch_embedding.weight"), + ("t_embedder.mlp.0.weight", "condition_embedder.time_embedder.linear_1.weight"), + ("t_embedder.mlp.0.bias", "condition_embedder.time_embedder.linear_1.bias"), + ("t_embedder.mlp.2.weight", "condition_embedder.time_embedder.linear_2.weight"), + ("t_embedder.mlp.2.bias", "condition_embedder.time_embedder.linear_2.bias"), + ("y_embedder.y_proj_xattn.0.weight", "condition_embedder.text_embedder.y_proj_xattn.0.weight"), + ("y_embedder.y_proj_xattn.0.bias", "condition_embedder.text_embedder.y_proj_xattn.0.bias"), + ("y_embedder.y_proj_adaln.0.weight", "condition_embedder.text_embedder.y_proj_adaln.weight"), + ("y_embedder.y_proj_adaln.0.bias", "condition_embedder.text_embedder.y_proj_adaln.bias"), + ("videodit_blocks.final_layernorm.weight", "norm_out.weight"), + ("videodit_blocks.final_layernorm.bias", "norm_out.bias"), + ("final_linear.linear.weight", "proj_out.weight"), + ("rope.bands", "rope.bands"), + ] + + for src, dst in simple_maps: + try: + assign(src, dst) + except KeyError: + if not allow_partial: + raise + + # Determine number of layers + if transformer is not None and hasattr(transformer, "config"): + num_layers = transformer.config.num_layers + else: + # Fallback: infer from checkpoint keys + num_layers = 0 + for k in checkpoint.keys(): + if k.startswith("videodit_blocks.layers."): + try: + idx = int(k.split(".")[3]) + num_layers = max(num_layers, idx + 1) + except Exception: + pass + + # Per-layer mappings + for i in range(num_layers): + layer_prefix = f"videodit_blocks.layers.{i}" + block_prefix = f"blocks.{i}" + + layer_maps = [ + (f"{layer_prefix}.self_attention.linear_qkv.layer_norm.weight", f"{block_prefix}.norm1.weight"), + (f"{layer_prefix}.self_attention.linear_qkv.layer_norm.bias", f"{block_prefix}.norm1.bias"), + (f"{layer_prefix}.self_attention.linear_qkv.q.weight", f"{block_prefix}.attn1.to_q.weight"), + (f"{layer_prefix}.self_attention.linear_qkv.k.weight", f"{block_prefix}.attn1.to_k.weight"), + (f"{layer_prefix}.self_attention.linear_qkv.v.weight", f"{block_prefix}.attn1.to_v.weight"), + (f"{layer_prefix}.self_attention.q_layernorm.weight", f"{block_prefix}.attn1.norm_q.weight"), + (f"{layer_prefix}.self_attention.q_layernorm.bias", f"{block_prefix}.attn1.norm_q.bias"), + (f"{layer_prefix}.self_attention.k_layernorm.weight", f"{block_prefix}.attn1.norm_k.weight"), + (f"{layer_prefix}.self_attention.k_layernorm.bias", f"{block_prefix}.attn1.norm_k.bias"), + (f"{layer_prefix}.self_attention.linear_qkv.qx.weight", f"{block_prefix}.attn2.to_q.weight"), + (f"{layer_prefix}.self_attention.q_layernorm_xattn.weight", f"{block_prefix}.attn2.norm_q.weight"), + (f"{layer_prefix}.self_attention.q_layernorm_xattn.bias", f"{block_prefix}.attn2.norm_q.bias"), + 
(f"{layer_prefix}.self_attention.k_layernorm_xattn.weight", f"{block_prefix}.attn2.norm_k.weight"), + (f"{layer_prefix}.self_attention.k_layernorm_xattn.bias", f"{block_prefix}.attn2.norm_k.bias"), + # Combined projection for concatenated [self_attn, cross_attn] outputs + (f"{layer_prefix}.self_attention.linear_proj.weight", f"{block_prefix}.attn_proj.weight"), + (f"{layer_prefix}.self_attn_post_norm.weight", f"{block_prefix}.norm2.weight"), + (f"{layer_prefix}.self_attn_post_norm.bias", f"{block_prefix}.norm2.bias"), + (f"{layer_prefix}.mlp.layer_norm.weight", f"{block_prefix}.norm3.weight"), + (f"{layer_prefix}.mlp.layer_norm.bias", f"{block_prefix}.norm3.bias"), + (f"{layer_prefix}.mlp.linear_fc1.weight", f"{block_prefix}.ffn.net.0.proj.weight"), + (f"{layer_prefix}.mlp.linear_fc2.weight", f"{block_prefix}.ffn.net.2.weight"), + (f"{layer_prefix}.mlp_post_norm.weight", f"{block_prefix}.norm4.weight"), + (f"{layer_prefix}.mlp_post_norm.bias", f"{block_prefix}.norm4.bias"), + (f"{layer_prefix}.ada_modulate_layer.proj.0.weight", f"{block_prefix}.ada_modulate_layer.1.weight"), + (f"{layer_prefix}.ada_modulate_layer.proj.0.bias", f"{block_prefix}.ada_modulate_layer.1.bias"), + ] + + for src, dst in layer_maps: + try: + assign(src, dst) + except KeyError: + if not allow_partial: + raise + + # special split for kv + try: + split_assign( + f"{layer_prefix}.self_attention.linear_kv_xattn.weight", + f"{block_prefix}.attn2.to_k.weight", + f"{block_prefix}.attn2.to_v.weight", + ) + except KeyError: + if not allow_partial: + raise + + print(f"Converted {len(converted_state_dict)} parameters") + report = { + "total_src_keys": len(checkpoint), + "used_src_keys": len(used_src_keys), + "missing_src_keys": missing_src_keys, + } + return converted_state_dict, report + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_type", type=str, default=None) + parser.add_argument("--checkpoint_path", type=str, default=None, help="Local MAGI-1 transformer checkpoint path") + parser.add_argument("--config_path", type=str, default=None, help="Optional JSON config for transformer") + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"]) + parser.add_argument("--push_to_hub", action="store_true", help="If set, push to the Hub after conversion") + parser.add_argument("--repo_id", type=str, default=None, help="Repo ID to push to (when --push_to_hub is set)") + parser.add_argument("--allow_partial", action="store_true", help="Allow partial/loose state dict loading") + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +if __name__ == "__main__": + args = get_args() + + if args.model_type is not None: + transformer = convert_magi1_transformer(args.model_type) + elif args.checkpoint_path is not None: + transformer = convert_magi1_transformer_checkpoint( + args.checkpoint_path, transformer_config_file=args.config_path, allow_partial=args.allow_partial + ) + else: + raise ValueError("Provide either --model_type for HF download or --checkpoint_path for local conversion.") + + # If user has specified "none", we keep the original dtypes of the state dict without any conversion + if args.dtype != "none": + dtype = DTYPE_MAPPING[args.dtype] + transformer.to(dtype) + + # Save transformer directly to output path (subfolder 'transformer') + save_kwargs = {"safe_serialization": True, "max_shard_size": "5GB"} + save_dir = 
os.path.join(args.output_path, "transformer") + os.makedirs(save_dir, exist_ok=True) + if args.push_to_hub: + save_kwargs.update( + { + "push_to_hub": True, + "repo_id": ( + args.repo_id + if args.repo_id is not None + else (f"tolgacangoz/{args.model_type}-Magi1Transformer" if args.model_type else "tolgacangoz/Magi1Transformer") + ), + } + ) + transformer.save_pretrained(save_dir, **save_kwargs) + + # Also write a minimal model_index.json for convenience when composing a pipeline later + index_path = os.path.join(args.output_path, "model_index.json") + index = { + "_class_name": "Magi1Pipeline", + "_diffusers_version": "0.0.0", + "transformer": ["transformer"], + "vae": None, + "text_encoder": None, + "tokenizer": None, + "scheduler": None, + } + with open(index_path, "w") as f: + json.dump(index, f, indent=2) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index aa500b149441..2906cbc6c9d2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -186,6 +186,7 @@ "AutoencoderKLCosmos", "AutoencoderKLHunyuanVideo", "AutoencoderKLLTXVideo", + "AutoencoderKLMagi1", "AutoencoderKLMagvit", "AutoencoderKLMochi", "AutoencoderKLQwenImage", @@ -225,6 +226,7 @@ "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", + "Magi1Transformer3DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", @@ -508,6 +510,9 @@ "Lumina2Text2ImgPipeline", "LuminaPipeline", "LuminaText2ImgPipeline", + "Magi1ImageToVideoPipeline", + "Magi1Pipeline", + "Magi1VideoToVideoPipeline", "MarigoldDepthPipeline", "MarigoldIntrinsicsPipeline", "MarigoldNormalsPipeline", @@ -880,6 +885,7 @@ AutoencoderKLCosmos, AutoencoderKLHunyuanVideo, AutoencoderKLLTXVideo, + AutoencoderKLMagi1, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, @@ -919,6 +925,7 @@ LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, + Magi1Transformer3DModel, MochiTransformer3DModel, ModelMixin, MotionAdapter, @@ -1172,6 +1179,9 @@ Lumina2Text2ImgPipeline, LuminaPipeline, LuminaText2ImgPipeline, + Magi1ImageToVideoPipeline, + Magi1Pipeline, + Magi1VideoToVideoPipeline, MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, MarigoldNormalsPipeline, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 48507aae038c..8072db3b8073 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -72,6 +72,7 @@ def text_encoder_attn_modules(text_encoder): "FluxLoraLoaderMixin", "CogVideoXLoraLoaderMixin", "CogView4LoraLoaderMixin", + "Magi1LoraLoaderMixin", "Mochi1LoraLoaderMixin", "HunyuanVideoLoraLoaderMixin", "SanaLoraLoaderMixin", @@ -120,6 +121,7 @@ def text_encoder_attn_modules(text_encoder): LoraLoaderMixin, LTXVideoLoraLoaderMixin, Lumina2LoraLoaderMixin, + Magi1LoraLoaderMixin, Mochi1LoraLoaderMixin, QwenImageLoraLoaderMixin, SanaLoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 2bb6c0ea026e..97573024d664 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3924,6 +3924,240 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class Magi1LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Magi1Transformer3DModel`]. Specific to [`Magi1Pipeline`]. 
+ """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
+ kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Magi1Transformer3DModel + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with StableDiffusion->Magi1 + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
+ """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an irreversible operation. If you need to unfuse, you'll need to reload the model. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + super().unfuse_lora(components=components, **kwargs) + + class WanLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`]. 
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 8d029bf5d31c..168f66f2d376 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -37,6 +37,7 @@ _import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"] _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] + _import_structure["autoencoders.autoencoder_kl_magi1"] = ["AutoencoderKLMagi1"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] @@ -94,6 +95,7 @@ _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] + _import_structure["transformers.transformer_magi1"] = ["Magi1Transformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] @@ -134,6 +136,7 @@ AutoencoderKLCosmos, AutoencoderKLHunyuanVideo, AutoencoderKLLTXVideo, + AutoencoderKLMagi1, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, @@ -188,6 +191,7 @@ LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, + Magi1Transformer3DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, PixArtTransformer2DModel, diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index e1694910997a..ff1c970ece65 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -31,6 +31,8 @@ is_flash_attn_available, is_flash_attn_version, is_kernels_available, + is_magi_attn_available, + is_magi_attn_version, is_sageattention_available, is_sageattention_version, is_torch_npu_available, @@ -51,6 +53,7 @@ _REQUIRED_FLEX_VERSION = "2.5.0" _REQUIRED_XLA_VERSION = "2.2" _REQUIRED_XFORMERS_VERSION = "0.0.29" +_REQUIRED_MAGI_VERSION = "1.0.3" _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() @@ -59,7 +62,7 @@ _CAN_USE_NPU_ATTN = is_torch_npu_available() _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) - +_CAN_USE_MAGI_ATTN = is_magi_attn_available() and is_magi_attn_version(">=", _REQUIRED_MAGI_VERSION) if _CAN_USE_FLASH_ATTN: from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -132,6 +135,11 @@ else: xops = None +if _CAN_USE_MAGI_ATTN: + from magi_attention.functional import flex_flash_attn_func +else: + flex_flash_attn_func = None + # Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4 if torch.__version__ >= "2.4.0": _custom_op = torch.library.custom_op @@ -202,6 +210,9 @@ class AttentionBackendName(str, Enum): # `xformers` XFORMERS = "xformers" + # `magi-attention` + MAGI = "magi" + class _AttentionBackendRegistry: _backends = {} @@ -450,6 +461,11 @@ def 
_check_attention_backend_requirements(backend: AttentionBackendName) -> None raise RuntimeError( f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`." ) + elif backend == AttentionBackendName.MAGI: + if not _CAN_USE_MAGI_ATTN: + raise RuntimeError( + f"Magi Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `magi-attention>={_REQUIRED_MAGI_VERSION}`." + ) @functools.lru_cache(maxsize=128) @@ -1952,3 +1968,176 @@ def _xformers_attention( out = out.flatten(2, 3) return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.MAGI, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _magi_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + q_ranges: torch.Tensor, + k_ranges: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + cu_seqlens_q: torch.Tensor | None = None, + cu_seqlens_k: torch.Tensor | None = None, + attn_mask: torch.Tensor | None = None, + attn_type_map: torch.Tensor | None = None, + softmax_scale: float | None = None, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, + disable_fwd_atomic_reduction: bool = False, + auto_range_merge: bool = False, + ref_block_size: tuple[int, int] | None = None, +) -> torch.Tensor: + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = flex_flash_attn_func( + q=query_packed, + k=key_packed, + v=value_packed, + q_ranges=q_ranges, + k_ranges=k_ranges, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + attn_type_map=attn_type_map, + softmax_scale=softmax_scale, + softcap=softcap, + deterministic=deterministic, + sm_margin=sm_margin, + disable_fwd_atomic_reduction=disable_fwd_atomic_reduction, + auto_range_merge=auto_range_merge, + ref_block_size=ref_block_size, + ) + out = out.unflatten(0, (batch_size, -1)) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.MAGI, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape, _check_magi_attn_backend], +) +def _magi_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + return_lse: bool = False, + q_ranges: Optional[torch.Tensor] = None, + k_ranges: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: 
Optional[int] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + _parallel_config: Optional["ParallelConfig"] = None, +) -> torch.Tensor: + """ + MAGI varlen attention backend using flex_flash_attn_func from magi-attention library. + + This backend supports variable-length sequences with q_ranges and k_ranges for + MAGI-1's autoregressive video generation pattern. + + Args: + query: Query tensor [B, S_q, H, D] + key: Key tensor [B, S_kv, H, D] or packed [total_tokens, H, D] + value: Value tensor [B, S_kv, H, D] or packed [total_tokens, H, D] + q_ranges: Tensor of shape [num_ranges, 2] specifying [start, end) for each query range + k_ranges: Tensor of shape [num_ranges, 2] specifying [start, end) for each key range + max_seqlen_q: Maximum sequence length for queries + max_seqlen_k: Maximum sequence length for keys + cu_seqlens_q: Cumulative sequence lengths for queries (alternative to q_ranges) + cu_seqlens_k: Cumulative sequence lengths for keys (alternative to k_ranges) + """ + # If q_ranges/k_ranges are provided, use them directly (MAGI-1 style) + if q_ranges is not None and k_ranges is not None: + # Flatten query/key/value if they're not already packed + if query.ndim == 4: # [B, S, H, D] + batch_size, seq_len_q, num_heads, head_dim = query.shape + query = query.flatten(0, 1) # [B*S, H, D] + key = key.flatten(0, 1) + value = value.flatten(0, 1) + + # Call flex_flash_attn_func with q_ranges/k_ranges + out, _ = flex_flash_attn_func( + query, + key, + value, + q_ranges, + k_ranges, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=scale, + deterministic=torch.are_deterministic_algorithms_enabled(), + disable_fwd_atomic_reduction=True, + ) + + # Unflatten output back to [B, S, H, D] + if query.ndim == 3: # was flattened from [B, S, H, D] + out = out.unflatten(0, (batch_size, seq_len_q)) + + return out + + # Fallback to cu_seqlens if ranges not provided + elif cu_seqlens_q is not None and cu_seqlens_k is not None: + # Convert cu_seqlens to ranges + q_ranges = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1) + k_ranges = torch.cat([cu_seqlens_k[:-1].unsqueeze(1), cu_seqlens_k[1:].unsqueeze(1)], dim=1) + + batch_size, seq_len_q, num_heads, head_dim = query.shape + query = query.flatten(0, 1) + key = key.flatten(0, 1) + value = value.flatten(0, 1) + + out, _ = flex_flash_attn_func( + query, + key, + value, + q_ranges, + k_ranges, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=scale, + deterministic=torch.are_deterministic_algorithms_enabled(), + disable_fwd_atomic_reduction=True, + ) + + out = out.unflatten(0, (batch_size, seq_len_q)) + return out + + else: + raise ValueError( + "MAGI attention backend requires either (q_ranges and k_ranges) or (cu_seqlens_q and cu_seqlens_k)" + ) diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index c008a45298e8..02f6826f51d5 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -6,6 +6,7 @@ from .autoencoder_kl_cosmos import AutoencoderKLCosmos from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_magi1 import AutoencoderKLMagi1 from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage diff --git 
a/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py new file mode 100644 index 000000000000..7b01c685e8a3 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py @@ -0,0 +1,1024 @@ +# Copyright 2025 The Sand AI Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def resize_pos_embed(posemb, src_shape, target_shape): + posemb = posemb.reshape(1, src_shape[0], src_shape[1], src_shape[2], -1) + posemb = posemb.permute(0, 4, 1, 2, 3) + posemb = nn.functional.interpolate(posemb, size=target_shape, mode="trilinear", align_corners=False) + posemb = posemb.permute(0, 2, 3, 4, 1) + posemb = posemb.reshape(1, target_shape[0] * target_shape[1] * target_shape[2], -1) + return posemb + + +# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections +def _get_qkv_projections(attn: "Magi1VAEAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +class Magi1VAELayerNorm(nn.Module): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(Magi1VAELayerNorm, self).__init__() + self.normalized_shape = normalized_shape + self.eps = eps + self.elementwise_affine = elementwise_affine + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + std = x.std(dim=-1, keepdim=True, unbiased=False) + + x_normalized = (x - mean) / (std + self.eps) + + return x_normalized + + +class Magi1VAEAttnProcessor: + _attention_backend = None + + def __init__(self): + if not hasattr(F, 
"scaled_dot_product_attention"): + raise ImportError("Magi1VAEAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "Magi1VAEAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + batch_size, time_height_width, channels = hidden_states.size() + + query, key, value = _get_qkv_projections(attn, hidden_states, None) + + qkv = torch.cat((query, key, value), dim=2) + qkv = qkv.reshape(batch_size, time_height_width, 3, attn.heads, channels // attn.heads) + qkv = attn.qkv_norm(qkv) + query, key, value = qkv.chunk(3, dim=2) + + # Remove the extra dimension from chunking + # Shape: (batch_size, time_height_width, num_heads, head_dim) + query = query.squeeze(2) + key = key.squeeze(2) + value = value.squeeze(2) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class Magi1VAEAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = Magi1VAEAttnProcessor + _available_processors = [Magi1VAEAttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: Optional[int] = None, + cross_attention_dim_head: Optional[int] = None, + processor=None, + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = torch.nn.ModuleList( + [ + torch.nn.Linear(self.inner_dim, dim, bias=True), + torch.nn.Dropout(dropout), + ] + ) + self.qkv_norm = Magi1VAELayerNorm(dim // heads, elementwise_affine=False) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + + self.set_processor(processor) + + # Copied from diffusers.models.transformers.transformer_wan.WanAttention.fuse_projections + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + if self.cross_attention_dim_head is None: + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_qkv = nn.Linear(in_features, out_features, bias=True) + self.to_qkv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = 
torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_kv = nn.Linear(in_features, out_features, bias=True) + self.to_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + if self.added_kv_proj_dim is not None: + concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @torch.no_grad() + # Copied from diffusers.models.transformers.transformer_wan.WanAttention.unfuse_projections + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + if hasattr(self, "to_added_kv"): + delattr(self, "to_added_kv") + + self.fused_projections = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, attention_mask, rotary_emb, **kwargs) + + +class Magi1VAETransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + ffn_dim: int = 4 * 1024, + eps: float = 1e-6, + ): + super().__init__() + self.attn = Magi1VAEAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + processor=Magi1VAEAttnProcessor(), + ) + + self.norm = nn.LayerNorm(dim) + self.proj_out = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu") + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor): + hidden_states = hidden_states + self.attn(hidden_states) + hidden_states = hidden_states + self.proj_out(self.norm(hidden_states)) + return hidden_states + + +class Magi1Encoder3d(nn.Module): + def __init__( + self, + inner_dim=128, + z_dim=4, + patch_size: Tuple[int] = (1, 2, 2), + num_frames: int = 16, + height: int = 256, + width: int = 256, + num_attention_heads: int = 40, + ffn_dim: int = 4 * 1024, + num_layers: int = 24, + eps: float = 1e-6, + ): + super().__init__() + self.z_dim = z_dim + self.height = height + self.width = width + self.num_frames = num_frames + + # 1. Patch & position embedding + self.patch_embedding = nn.Conv3d(3, inner_dim, kernel_size=patch_size, stride=patch_size) + self.patch_size = patch_size + + self.cls_token_nums = 1 + self.cls_token = nn.Parameter(torch.zeros(1, 1, inner_dim)) + # `generator` as a parameter? + nn.init.trunc_normal_(self.cls_token, std=0.02) + + p_t, p_h, p_w = patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + num_patches = post_patch_num_frames * post_patch_height * post_patch_width + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.cls_token_nums, inner_dim)) + self.pos_drop = nn.Dropout(p=0.0) + + # 3. 
Transformer blocks + self.blocks = nn.ModuleList( + [ + Magi1VAETransformerBlock( + inner_dim, + num_attention_heads, + ffn_dim, + eps, + ) + for _ in range(num_layers) + ] + ) + + # output blocks + self.norm_out = nn.LayerNorm(inner_dim) + self.linear_out = nn.Linear(inner_dim, z_dim * 2) + + # `generator` as a parameter? + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + self.gradient_checkpointing = False + + def forward(self, x): + B = x.shape[0] + # B C T H W -> B C T/pT H/pH W//pW + x = self.patch_embedding(x) + latentT, latentH, latentW = x.shape[2], x.shape[3], x.shape[4] + # B C T/pT H/pH W//pW -> B (T/pT H/pH W//pW) C + x = x.flatten(2).transpose(1, 2) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if latentT != self.patch_size[0] or latentH != self.patch_size[1] or latentW != self.patch_size[2]: + pos_embed = resize_pos_embed( + self.pos_embed[:, 1:, :], + src_shape=( + self.num_frames // self.patch_size[0], + self.height // self.patch_size[1], + self.width // self.patch_size[2], + ), + target_shape=(latentT, latentH, latentW), + ) + pos_embed = torch.cat((self.pos_embed[:, 0:1, :], pos_embed), dim=1) + else: + pos_embed = self.pos_embed + + x = x + pos_embed + x = self.pos_drop(x) + + ## transformer blocks + for block in self.blocks: + x = block(x) + + ## head + x = self.norm_out(x) + x = x[:, 1:] # remove cls_token + x = self.linear_out(x) + + # B L C - > B , lT, lH, lW, zC (where zC is now z_dim * 2) + x = x.reshape(B, latentT, latentH, latentW, self.z_dim * 2) + + # B , lT, lH, lW, zC -> B, zC, lT, lH, lW + x = x.permute(0, 4, 1, 2, 3) + + return x + + +class Magi1Decoder3d(nn.Module): + def __init__( + self, + inner_dim=1024, + z_dim=16, + patch_size: Tuple[int] = (4, 8, 8), + num_frames: int = 16, + height: int = 256, + width: int = 256, + num_attention_heads: int = 16, + ffn_dim: int = 4 * 1024, + num_layers: int = 24, + eps: float = 1e-6, + ): + super().__init__() + self.z_dim = z_dim + self.patch_size = patch_size + self.height = height + self.width = width + self.num_frames = num_frames + + # init block + self.proj_in = nn.Linear(z_dim, inner_dim) + + self.cls_token_nums = 1 + self.cls_token = nn.Parameter(torch.zeros(1, 1, inner_dim)) + # `generator` as a parameter? + nn.init.trunc_normal_(self.cls_token, std=0.02) + + p_t, p_h, p_w = patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + num_patches = post_patch_num_frames * post_patch_height * post_patch_width + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.cls_token_nums, inner_dim)) + self.pos_drop = nn.Dropout(p=0.0) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + Magi1VAETransformerBlock( + inner_dim, + num_attention_heads, + ffn_dim, + eps, + ) + for _ in range(num_layers) + ] + ) + + # output blocks + self.norm_out = nn.LayerNorm(inner_dim) + self.unpatch_channels = inner_dim // (patch_size[0] * patch_size[1] * patch_size[2]) + self.conv_out = nn.Conv3d(self.unpatch_channels, 3, 3, padding=1) + + # `generator` as a parameter? 
+ nn.init.trunc_normal_(self.pos_embed, std=0.02) + + self.gradient_checkpointing = False + + def forward(self, x): + B, C, latentT, latentH, latentW = x.shape + x = x.permute(0, 2, 3, 4, 1) + + x = x.reshape(B, -1, C) + + x = self.proj_in(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if latentT != self.patch_size[0] or latentH != self.patch_size[1] or latentW != self.patch_size[2]: + pos_embed = resize_pos_embed( + self.pos_embed[:, 1:, :], + src_shape=( + self.num_frames // self.patch_size[0], + self.height // self.patch_size[1], + self.width // self.patch_size[2], + ), + target_shape=(latentT, latentH, latentW), + ) + pos_embed = torch.cat((self.pos_embed[:, 0:1, :], pos_embed), dim=1) + else: + pos_embed = self.pos_embed + + x = x + pos_embed + x = self.pos_drop(x) + + ## transformer blocks + for block in self.blocks: + x = block(x) + + ## head + x = self.norm_out(x) + x = x[:, 1:] # remove cls_token + + x = x.reshape( + B, + latentT, + latentH, + latentW, + self.patch_size[0], + self.patch_size[1], + self.patch_size[2], + self.unpatch_channels, + ) + # Rearrange from (B, lT, lH, lW, pT, pH, pW, C) to (B, C, lT*pT, lH*pH, lW*pW) + x = x.permute(0, 7, 1, 4, 2, 5, 3, 6) # (B, C, lT, pT, lH, pH, lW, pW) + x = x.reshape( + B, + self.unpatch_channels, + latentT * self.patch_size[0], + latentH * self.patch_size[1], + latentW * self.patch_size[2], + ) + + x = self.conv_out(x) + return x + + +class AutoencoderKLMagi1(ModelMixin, ConfigMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [Magi1](https://arxiv.org/abs/2505.13211). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
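+
+ Sliced and tiled decoding can be enabled with `enable_slicing` and `enable_tiling` to reduce peak memory
+ usage when decoding large or long videos.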
+ """ + + _supports_gradient_checkpointing = False + _skip_layerwise_casting_patterns = ["patch_embedding", "norm"] + _no_split_modules = ["Magi1VAETransformerBlock"] + # _keep_in_fp32_modules = ["qkv_norm", "norm1", "norm2"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (4, 8, 8), + num_attention_heads: int = 16, + attention_head_dim: int = 64, + z_dim: int = 16, + height: int = 256, + width: int = 256, + num_frames: int = 16, + ffn_dim: int = 4 * 1024, + num_layers: int = 24, + eps: float = 1e-6, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + self.z_dim = z_dim + + self.encoder = Magi1Encoder3d( + inner_dim, + z_dim, + patch_size, + num_frames, + height, + width, + num_attention_heads, + ffn_dim, + num_layers, + eps, + ) + + self.decoder = Magi1Decoder3d( + inner_dim, + z_dim, + patch_size, + num_frames, + height, + width, + num_attention_heads, + ffn_dim, + num_layers, + eps, + ) + + self.spatial_compression_ratio = patch_size[1] or patch_size[2] + self.temporal_compression_ratio = patch_size[0] + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal tile length for temporal tiling to be used + self.tile_sample_min_length = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # The minimal distance between two temporal tiles + self.tile_sample_stride_length = 16 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_length: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_length: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. 
+ tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_length = tile_sample_min_length or self.tile_sample_min_length + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_length = tile_sample_stride_length or self.tile_sample_stride_length + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor): + _, _, num_frames, height, width = x.shape + + if self.use_tiling and ( + width > self.tile_sample_min_width + or height > self.tile_sample_min_height + or num_frames > self.tile_sample_min_length + ): + return self.tiled_encode(x) + + out = self.encoder(x) + + return out + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_length = self.tile_sample_min_length // self.temporal_compression_ratio + + if self.use_tiling and ( + width > tile_latent_min_width or height > tile_latent_min_height or num_frames > tile_latent_min_length + ): + return self.tiled_decode(z, return_dict=return_dict) + + out = self.decoder(z) + + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int, power: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + w_a, w_b = 1 - y / blend_extent, y / blend_extent + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (w_a**power) + b[:, :, :, y, :] * (w_b**power) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int, power: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + w_a, w_b = 1.0 - x / blend_extent, x / blend_extent + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (w_a**power) + b[:, :, :, :, x] * (w_b**power) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int, power: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for t in range(blend_extent): + w_a, w_b = 1.0 - t / blend_extent, t / blend_extent + b[:, :, t, :, :] = a[:, :, -blend_extent + t, :, :] * (w_a**power) + b[:, :, t, :, :] * (w_b**power) + return b + + def _encode_tile(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode a single tile. + """ + N, C, T, H, W = x.shape + + if T == 1 and self.temporal_compression_ratio > 1: + x = x.expand(-1, -1, self.temporal_compression_ratio, -1, -1) + x = self.encoder(x) + # After temporal expansion and encoding, select only the first temporal frame. + x = x[:, :, :1] + return x + else: + x = self.encoder(x) + return x + + def _decode_tile(self, x: torch.Tensor) -> torch.Tensor: + """ + Decode a single tile. 
+ """ + N, C, T, H, W = x.shape + + x = self.decoder(x) + return x[:, :, :1, :, :] if T == 1 else x + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + B, C, num_frames, height, width = x.shape + + # Latent sizes after compression + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + latent_length = num_frames // self.temporal_compression_ratio + + # Tile latent sizes / strides + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_length = self.tile_sample_min_length // self.temporal_compression_ratio + + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + tile_latent_stride_length = self.tile_sample_stride_length // self.temporal_compression_ratio + + # Overlap (blend) sizes in latent space + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + blend_length = tile_latent_min_length - tile_latent_stride_length + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + times = [] + for t in range(0, num_frames, self.tile_sample_stride_length): + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + tile = x[ + :, + :, + t : t + self.tile_sample_min_length, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + h_tile = self._encode_tile(tile) + # Original implementation samples here and blends the latents. + # Instead we're keeping moments (mu and logvar) and blend them. + row.append(h_tile) + rows.append(row) + times.append(rows) + + # Calculate global blending order because blending is not commutative here. 
+ nT = len(times) + nH = len(times[0]) if nT else 0 + nW = len(times[0][0]) if nH else 0 + idx_numel = [] + for tt in range(nT): + for ii in range(nH): + for jj in range(nW): + idx_numel.append(((tt, ii, jj), times[tt][ii][jj].numel())) + global_order = [idx for (idx, _) in sorted(idx_numel, key=lambda kv: kv[1], reverse=True)] + + result_grid = [[[None for _ in range(nW)] for _ in range(nH)] for _ in range(nT)] + for t_idx, i_idx, j_idx in global_order: + rows = times[t_idx] + row = rows[i_idx] + h = row[j_idx] + + # Separate the mu and the logvar because mu needs to be blended linearly + # but logvar needs to be blended quadratically to obtain numerical equivalence + # so that the overall distribution is preserved + mu, logvar = h[:, : self.z_dim], h[:, self.z_dim :] + var = logvar.exp() + + # Blend the prev tile, the above tile and the left tile + # to the current tile and add the current tile to the result grid + if t_idx > 0: + h_tile = times[t_idx - 1][i_idx][j_idx] + mu_prev, logvar_prev = h_tile[:, : self.z_dim], h_tile[:, self.z_dim :] + var_prev = logvar_prev.exp() + mu = self.blend_t(mu_prev, mu, blend_length, power=1) + var = self.blend_t(var_prev, var, blend_length, power=2) + + if i_idx > 0: + h_tile = rows[i_idx - 1][j_idx] + mu_up, logvar_up = h_tile[:, : self.z_dim], h_tile[:, self.z_dim :] + var_up = logvar_up.exp() + mu = self.blend_v(mu_up, mu, blend_height, power=1) + var = self.blend_v(var_up, var, blend_height, power=2) + + if j_idx > 0: + h_tile = row[j_idx - 1] + mu_left, logvar_left = h_tile[:, : self.z_dim], h_tile[:, self.z_dim :] + var_left = logvar_left.exp() + mu = self.blend_h(mu_left, mu, blend_width, power=1) + var = self.blend_h(var_left, var, blend_width, power=2) + + logvar = var.clamp_min(1e-12).log() + h_blended = torch.cat([mu, logvar], dim=1) + h_core = h_blended[ + :, + :, + :tile_latent_stride_length, + :tile_latent_stride_height, + :tile_latent_stride_width, + ] + result_grid[t_idx][i_idx][j_idx] = h_core + + # Stitch blended tiles to spatially correct places + result_times = [] + for t_idx in range(nT): + result_rows = [] + for i_idx in range(nH): + result_row = [] + for j_idx in range(nW): + result_row.append(result_grid[t_idx][i_idx][j_idx]) + result_rows.append(torch.cat(result_row, dim=4)) + result_times.append(torch.cat(result_rows, dim=3)) + + h_full = torch.cat(result_times, dim=2)[:, :, :latent_length, :latent_height, :latent_width] + return h_full + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + B, C, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + sample_length = num_frames * self.temporal_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_length = self.tile_sample_min_length // self.temporal_compression_ratio + + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + tile_latent_stride_length = self.tile_sample_stride_length // self.temporal_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + blend_length = self.tile_sample_min_length - self.tile_sample_stride_length + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + times = [] + for t in range(0, num_frames, tile_latent_stride_length): + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[ + :, + :, + t : t + tile_latent_min_length, + i : i + tile_latent_min_height, + j : j + tile_latent_min_width, + ] + decoded = self._decode_tile(tile) + row.append(decoded) + rows.append(row) + times.append(rows) + + result_times = [] + for t in range(len(times)): + result_rows = [] + for i in range(len(times[t])): + result_row = [] + for j in range(len(times[t][i])): + # Clone the current decoded tile to ensure blending uses an unmodified copy of the tile. + tile = times[t][i][j].clone() + + if t > 0: + tile = self.blend_t(times[t - 1][i][j], tile, blend_length, power=1) + if i > 0: + tile = self.blend_v(times[t][i - 1][j], tile, blend_height, power=1) + if j > 0: + tile = self.blend_h(times[t][i][j - 1], tile, blend_width, power=1) + + result_row.append( + tile[ + :, + :, + : self.tile_sample_stride_length, + : self.tile_sample_stride_height, + : self.tile_sample_stride_width, + ] + ) + + result_rows.append(torch.cat(result_row, dim=4)) + result_times.append(torch.cat(result_rows, dim=3)) + dec = torch.cat(result_times, dim=2)[:, :, :sample_length, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = True, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
+ """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b51f5d7aec25..c91d7004d3f2 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -830,6 +830,8 @@ def get_3d_rotary_pos_embed( grid_type: str = "linspace", max_size: Optional[Tuple[int, int]] = None, device: Optional[torch.device] = None, + center_grid_hw_indices: bool = False, + equal_split_ratio: Optional[int] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ RoPE for video tokens with 3D structure. @@ -876,10 +878,19 @@ def get_3d_rotary_pos_embed( else: raise ValueError("Invalid value passed for `grid_type`.") - # Compute dimensions for each axis - dim_t = embed_dim // 4 - dim_h = embed_dim // 8 * 3 - dim_w = embed_dim // 8 * 3 + if center_grid_hw_indices: + # Center the grid height and width indices around zero + grid_h = grid_h - grid_h.max() / 2 + grid_w = grid_w - grid_w.max() / 2 + + if equal_split_ratio is None: + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + else: + dim_t = embed_dim // equal_split_ratio + dim_h = embed_dim // equal_split_ratio + dim_w = embed_dim // equal_split_ratio # Temporal frequencies freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 6b80ea6c82a5..94a80ba8daac 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -30,6 +30,7 @@ from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel + from .transformer_magi1 import Magi1Transformer3DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_magi1.py b/src/diffusers/models/transformers/transformer_magi1.py new file mode 100644 index 000000000000..2040e986be83 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_magi1.py @@ -0,0 +1,896 @@ +# Copyright 2025 The MAGI Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import AttentionBackendName, dispatch_attention_fn +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..cache_utils import CacheMixin +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin, get_parameter_dtype +from ..normalization import FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_qkv_projections(attn: "Magi1Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +def range_mod_pytorch(x, c_mapping, gatings): + """ + PyTorch implementation of range_mod_triton. # TODO: Ensure that this implementation is correct and matches the + range_mod_triton implementation. + + Inputs: + x: (s, b, h). Tensor of inputs embedding (images or latent representations of images) c_mapping: (s, b). Tensor + of condition map gatings: (b, denoising_range_num, h). Tensor of condition embedding + """ + s, b, h = x.shape + + # Flatten x and c_mapping to 2D for easier indexing + x_flat = x.transpose(0, 1).flatten(0, 1) # (s*b, h) + c_mapping_flat = c_mapping.transpose(0, 1).flatten(0, 1) # (s*b,) + gatings_flat = gatings.flatten(0, 1) # (b*denoising_range_num, h) + + # Use advanced indexing to select the appropriate gating for each row + # c_mapping_flat contains indices into gatings_flat + selected_gatings = gatings_flat[c_mapping_flat] # (s*b, h) + + # Element-wise multiplication + y_flat = x_flat * selected_gatings # (s*b, h) + + # Reshape back to original dimensions + y = y_flat.reshape(b, s, h).transpose(0, 1) # (s, b, h) + + return y + + +if is_kernels_available(): + from kernels import use_kernel_forward_from_hub + + range_mod_pytorch = use_kernel_forward_from_hub("range_mod_triton")(range_mod_pytorch) + + +class Magi1AttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("Magi1AttnProcessor requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "Magi1Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + attention_kwargs = attention_kwargs or {} + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + kv_h = attn.kv_heads if attn.kv_heads is not None else attn.heads + key = key.unflatten(2, (kv_h, -1)) + value = value.unflatten(2, (kv_h, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if rotary_emb is not None and attn.cross_attention_dim_head is None: + + def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor: + # x: [B, S, H, D] + x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1) # [B, S, H, D/2] + # Broadcast cos/sin to [1, S, 1, D/2] + cos = freqs_cos.view(1, -1, 1, freqs_cos.shape[-1])[..., 0::2] + sin = freqs_sin.view(1, -1, 1, freqs_sin.shape[-1])[..., 1::2] + out = torch.empty_like(x) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(x) + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + kv_heads = kv_h + n_rep = attn.heads // kv_heads + if n_rep > 1: + key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + # Use MAGI backend if varlen parameters are provided + backend = self._attention_backend + if attention_kwargs.get("q_ranges") is not None: + backend = AttentionBackendName.MAGI + + out = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + enable_gqa=True, + backend=backend, + parallel_config=self._parallel_config, + attention_kwargs=attention_kwargs, + ) + out = out.flatten(2, 3).type_as(hidden_states) + return out + + +class Magi1Attention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = Magi1AttnProcessor + _available_processors = [Magi1AttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 24, + kv_heads: Optional[int] = None, + dim_head: int = 128, + eps: float = 1e-6, + dropout: float = 0.0, + added_kv_proj_dim: Optional[int] = None, + cross_attention_dim_head: Optional[int] = None, + processor=None, + is_cross_attention=None, + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.kv_heads = kv_heads if kv_heads is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = dim_head * self.kv_heads + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=False) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=False) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=False) + # Note: Output projection is handled in Magi1TransformerBlock to match original architecture + # where [self_attn, cross_attn] outputs are concatenated, rearranged, then projected together + self.norm_q = FP32LayerNorm(dim_head, eps) + self.norm_k = FP32LayerNorm(dim_head, eps) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, 
self.inner_dim, bias=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + + self.is_cross_attention = cross_attention_dim_head is not None + + self.set_processor(processor) + + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + if self.cross_attention_dim_head is None: + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_qkv = nn.Linear(in_features, out_features, bias=False) + self.to_qkv.load_state_dict({"weight": concatenated_weights}, strict=True, assign=True) + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_kv = nn.Linear(in_features, out_features, bias=False) + self.to_kv.load_state_dict({"weight": concatenated_weights}, strict=True, assign=True) + + if self.added_kv_proj_dim is not None: + concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + if hasattr(self, "to_added_kv"): + delattr(self, "to_added_kv") + + self.fused_projections = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) + + +class Magi1TextProjection(nn.Module): + """ + Projects caption embeddings. + """ + + def __init__(self, in_features, hidden_size, adaln_dim): + super().__init__() + self.y_proj_xattn = nn.Sequential(nn.Linear(in_features, hidden_size), nn.SiLU()) + self.y_proj_adaln = nn.Linear(in_features, adaln_dim) + + def forward(self, caption): + caption_xattn = self.y_proj_xattn(caption) + caption_adaln = self.y_proj_adaln(caption) + return caption_xattn, caption_adaln + + +class Magi1TimeTextEmbedding(nn.Module): + """ + Combined time, text embedding module for the MAGI-1 model. + + This module handles the encoding of two types of conditioning inputs: + 1. Timestep embeddings for diffusion process control + 2. Text embeddings for text-to-video generation + + Args: + dim (`int`): Hidden dimension of the transformer model. + time_freq_dim (`int`): Dimension for sinusoidal time embeddings. + text_embed_dim (`int`): Input dimension of text embeddings. + enable_distillation (`bool`, optional): Enable distillation timestep adjustments. 
+ """ + + def __init__( + self, + dim: int, + time_freq_dim: int, + text_embed_dim: int, + enable_distillation: bool = False, + ): + super().__init__() + + # NOTE: timestep_rescale_factor=1000 to match original implementation (dit_module.py:71) + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=int(dim * 0.25)) + self.text_embedder = Magi1TextProjection(text_embed_dim, dim, adaln_dim=int(dim * 0.25)) + + self.enable_distillation = enable_distillation + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + num_steps: Optional[int] = None, + distill_interval: Optional[int] = None, + ): + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = get_parameter_dtype(self.time_embedder) + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + y_xattn, y_adaln = self.text_embedder(encoder_hidden_states) + + # Apply distillation logic if enabled + if self.enable_distillation and num_steps is not None: + distill_dt_scalar = 2 + if num_steps == 12 and distill_interval is not None: + base_chunk_step = 4 + distill_dt_factor = base_chunk_step / distill_interval * distill_dt_scalar + else: + distill_dt_factor = num_steps / 4 * distill_dt_scalar + + distill_dt = torch.ones_like(timestep) * distill_dt_factor + distill_dt_embed = self.time_embedder(distill_dt) + temb = temb + distill_dt_embed + + return temb, y_xattn, y_adaln + + +class Magi1RotaryPosEmbed(nn.Module): + """ + Rotary Position Embedding for MAGI-1 model. + + Args: + dim (`int`): The embedding dimension. + theta (`float`, *optional*, defaults to 10000.0): Base for the geometric progression. + """ + + def __init__( + self, + dim: int, + theta: float = 10000.0, + ): + super().__init__() + + num_bands = dim // 8 + exp = torch.arange(0, num_bands, dtype=torch.float32) / num_bands + bands = 1.0 / (theta**exp) + self.bands = nn.Parameter(bands) + + def forward(self, hidden_states: torch.Tensor, T_total: int) -> torch.Tensor: + # Rebuild bands and embeddings every call, use if target shape changes + device = hidden_states.device + batch_size, num_channels, num_frames, post_patch_height, post_patch_width = hidden_states.shape + feat_shape = [T_total, post_patch_height, post_patch_width] + + # Calculate rescale_factor for multi-resolution & multi aspect-ratio training + # the base_size [16*16] is A predefined size based on data:(256x256) vae: (8,8,4) patch size: (1,1,2) + # This definition do not have any relationship with the actual input/model/setting. + # ref_feat_shape is used to calculate innner rescale factor, so it can be float. 
+ rescale_factor = math.sqrt((post_patch_height * post_patch_width) / (16 * 16)) + ref_feat_shape = [T_total, post_patch_height / rescale_factor, post_patch_width / rescale_factor] + + f = torch.arange(num_frames, device=device, dtype=torch.float32) + h = torch.arange(post_patch_height, device=device, dtype=torch.float32) + w = torch.arange(post_patch_width, device=device, dtype=torch.float32) + + # Align spatial center (H/2, W/2) to (0,0) + h = h - (post_patch_height - 1) / 2 + w = w - (post_patch_width - 1) / 2 + + if ref_feat_shape is not None: + # eva's scheme for resizing rope embeddings (ref shape = pretrain) + # aligning to the endpoint e.g [0,1,2] -> [0, 0.4, 0.8, 1.2, 1.6, 2] + fhw_rescaled = [] + fhw = [f, h, w] + for x, shape, ref_shape in zip(fhw, feat_shape, ref_feat_shape): + if shape == 1: # Deal with image input + if ref_shape != 1: + raise ValueError("ref_feat_shape must be 1 when feat_shape is 1") + fhw_rescaled.append(x) + else: + fhw_rescaled.append(x / (shape - 1) * (ref_shape - 1)) + f, h, w = fhw_rescaled + + # Create 3D meshgrid & stack into grid tensor: [T, H, W, 3] + grid = torch.stack(torch.meshgrid(f, h, w, indexing="ij"), dim=-1) + grid = grid.unsqueeze(-1) # [T, H, W, 3, 1] + + # Apply frequency bands + freqs = grid * self.bands # [T, H, W, 3, num_bands] + + freqs_cos = freqs.cos() + freqs_sin = freqs.sin() + + # This would be much nicer as a .numel() call to torch.Size(), but torchscript sucks + num_spatial_dim = 1 + for x in feat_shape: + num_spatial_dim *= x + + freqs_cos = freqs_cos.reshape(num_spatial_dim, -1) + freqs_sin = freqs_sin.reshape(num_spatial_dim, -1) + + return freqs_cos, freqs_sin + + +class Magi1TransformerBlock(nn.Module): + """ + A transformer block used in the MAGI-1 model. + + Args: + dim (`int`): The number of channels in the input and output. + ffn_dim (`int`): The number of channels in the feed-forward layer. + num_heads (`int`): The number of attention heads. + num_kv_heads (`int`): The number of key-value attention heads. + eps (`float`): The epsilon value for layer normalization. + gated_linear_unit (`bool`, defaults to `False`): + Whether to use gated linear units (SwiGLU) in the feed-forward network. + """ + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + num_kv_heads: int, + eps: float = 1e-6, + gated_linear_unit: bool = False, + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps) + self.attn1 = Magi1Attention( + dim=dim, + heads=num_heads, + kv_heads=num_kv_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + processor=Magi1AttnProcessor(), + ) + + # 2. Cross-attention + self.attn2 = Magi1Attention( + dim=dim, + heads=num_heads, + kv_heads=num_kv_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=dim // num_kv_heads, + processor=Magi1AttnProcessor(), + ) + + # Combined output projection for concatenated [self_attn, cross_attn] outputs + # Matches original architecture: concat -> rearrange -> project + self.attn_proj = nn.Linear(2 * dim, dim, bias=False) + + self.ada_modulate_layer = nn.Sequential( + nn.SiLU(), + nn.Linear( + int(dim * 0.25), + int(dim * 2), + ), + ) + self.norm2 = FP32LayerNorm(dim, eps) + self.norm3 = FP32LayerNorm(dim, eps) + + # 3. 
Feed-forward + # Use SwiGLU activation for gated linear units, GELU otherwise + activation_fn = "swiglu" if gated_linear_unit else "gelu" + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn=activation_fn, bias=False) + + self.norm4 = FP32LayerNorm(dim, eps) + with torch.no_grad(): + self.norm2.weight.add_(1.0) + self.norm4.weight.add_(1.0) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + temb_map: torch.Tensor, + encoder_attention_mask: Optional[torch.Tensor] = None, + self_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + residual = hidden_states.float() + + mixed_qqkv = self.norm1(hidden_states) + + self_attn_output = self.attn1(mixed_qqkv, None, None, rotary_emb, attention_kwargs=self_attention_kwargs) + + cross_attn_output = self.attn2( + mixed_qqkv, encoder_hidden_states, encoder_attention_mask, None, attention_kwargs=encoder_attention_kwargs + ) + + # 3. Concatenate attention outputs + # Shape: [sq, b, num_heads * head_dim + num_heads * head_dim] = [sq, b, 2 * dim] + hidden_states = torch.concat([self_attn_output, cross_attn_output], dim=2) + + # 4. Rearrange to interleave query groups from self and cross attention + # This matches the original: rearrange(attn_out, "sq b (n hn hd) -> sq b (hn n hd)", n=2, hn=num_query_groups) + # The interleaving is done at the query group level (not per-head level) + # For 48 heads with 8 query groups: each group has 6 heads = 768 dims + # Interleaving pattern: [self_g0, cross_g0, self_g1, cross_g1, ..., self_g7, cross_g7] + batch_size, seq_len, _ = hidden_states.shape + num_query_groups = self.attn1.kv_heads if self.attn1.kv_heads is not None else self.attn1.heads + group_dim = self_attn_output.shape[2] // num_query_groups + hidden_states = hidden_states.reshape(batch_size, seq_len, 2, num_query_groups, group_dim) + hidden_states = hidden_states.permute(0, 1, 3, 2, 4) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1) + + # 5. Apply combined projection + hidden_states = self.attn_proj(hidden_states) + + gate_output = self.ada_modulate_layer(temb) + # Softcap with 1.0 + gate_output = torch.tanh(gate_output.float()).to(gate_output.dtype) + gate_msa, gate_mlp = gate_output.chunk(2, dim=-1) + + # Residual connection for self-attention + original_dtype = hidden_states.dtype + hidden_states = range_mod_pytorch(hidden_states.float().transpose(0, 1), temb_map, gate_msa).transpose(0, 1) + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = hidden_states.to(original_dtype) + + hidden_states = self.norm3(hidden_states) + hidden_states = self.ffn(hidden_states) + + # Residual connection for MLP + original_dtype = hidden_states.dtype + hidden_states = range_mod_pytorch(hidden_states.float().transpose(0, 1), temb_map, gate_mlp).transpose(0, 1) + hidden_states = self.norm4(hidden_states) + hidden_states = hidden_states + residual + hidden_states = hidden_states.to(original_dtype) + return hidden_states + + +class Magi1Transformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + r""" + A Transformer model for video-like data used in the Magi1 model. + + This model implements a 3D transformer architecture for video generation with support for text conditioning and + optional image conditioning. 
The model uses rotary position embeddings and adaptive layer normalization for + temporal and spatial modeling. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `16`): + The number of attention heads in each transformer block. + attention_head_dim (`int`, defaults to `64`): + The dimension of each attention head. + in_channels (`int`, defaults to `16`): + The number of input channels (from VAE latent space). + out_channels (`int`, defaults to `16`): + The number of output channels (to VAE latent space). + cross_attention_dim (`int`, defaults to `4096`): + The dimension of cross-attention (text encoder hidden size). + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `4096`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `34`): + The number of transformer layers to use. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + gated_linear_unit (`bool`, defaults to `False`): + Whether to use gated linear units (SwiGLU activation) in the feed-forward network. If True, uses SwiGLU + activation; if False, uses GELU activation. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "rope"] + _no_split_modules = ["Magi1TransformerBlock"] + _keep_in_fp32_modules = [ + "condition_embedder", + "scale_shift_table", + "norm_out", + "norm_q", + "norm_k", + "patch_embedding", + "rope", + ] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["Magi1TransformerBlock"] + _cp_plan = { + "rope": { + 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + }, + "blocks.0": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*": { + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_kv_heads: int = 8, + in_channels: int = 16, + out_channels: int = 16, + cross_attention_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 12288, + num_layers: int = 34, + eps: float = 1e-6, + x_rescale_factor: int = 1, + half_channel_vae: bool = False, + enable_distillation: bool = False, + gated_linear_unit: bool = False, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.rope = Magi1RotaryPosEmbed(inner_dim // num_attention_heads) + + # 2. Condition embeddings + self.condition_embedder = Magi1TimeTextEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + text_embed_dim=cross_attention_dim, + enable_distillation=enable_distillation, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + Magi1TransformerBlock( + inner_dim, + ffn_dim, + num_attention_heads, + num_kv_heads, + eps, + gated_linear_unit, + ) + for _ in range(num_layers) + ] + ) + + # 4. 
Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size), bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + denoising_range_num: Optional[int] = None, + range_num: Optional[int] = None, + slice_point: Optional[int] = 0, + kv_range: Optional[Tuple[int, int]] = None, + num_steps: Optional[int] = None, + distill_interval: Optional[int] = None, + extract_prefix_video_feature: Optional[bool] = False, + fwd_extra_1st_chunk: Optional[bool] = False, + distill_nearly_clean_chunk: Optional[bool] = False, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Forward pass of the MAGI-1 transformer. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of shape `(batch_size, num_channels, num_frames, height, width)`. + timestep (`torch.LongTensor`): + Timesteps for diffusion process. Shape: `(batch_size, denoising_range_num)`. + encoder_hidden_states (`torch.Tensor`): + Text embeddings from the text encoder for cross-attention conditioning. + encoder_attention_mask (`torch.Tensor`, *optional*): + Attention mask for text embeddings to handle variable-length sequences. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a dictionary or a tuple. + attention_kwargs (`dict`, *optional*): + Additional keyword arguments for attention processors (e.g., LoRA scale). + denoising_range_num (`int`, *optional*): + Number of denoising ranges for autoregressive video generation. Each range represents a chunk of video + frames being denoised in parallel. + range_num (`int`, *optional*): + Total number of ranges in the video generation process. + slice_point (`int`, *optional*, defaults to 0): + Index indicating how many clean (already generated) frames precede the current denoising chunks. Used + for autoregressive context. + kv_range (`Tuple[int, int]`, *optional*): + Key-value attention ranges for each denoising chunk, defining which frames each chunk can attend to. + Required for MAGI-1's autoregressive attention pattern. + num_steps (`int`, *optional*): + Number of diffusion sampling steps. Used for distillation timestep adjustments. + distill_interval (`int`, *optional*): + Interval for distillation when using distilled models. Used with `num_steps`. + extract_prefix_video_feature (`bool`, *optional*, defaults to `False`): + Whether to extract features from prefix (clean) video frames. + fwd_extra_1st_chunk (`bool`, *optional*, defaults to `False`): + Whether to forward an extra first chunk in the current iteration. + distill_nearly_clean_chunk (`bool`, *optional*, defaults to `False`): + Whether to apply distillation to nearly clean chunks during generation. + + Returns: + `Transformer2DModelOutput` or `tuple`: + If `return_dict` is True, returns a `Transformer2DModelOutput` containing the sample. Otherwise, + returns a tuple with the sample as the first element. + + Note: + MAGI-1 uses an autoregressive video generation approach where video frames are generated in chunks. The + `denoising_range_num`, `kv_range`, and related parameters control this autoregressive pattern, allowing + each chunk to attend to previously generated (clean) frames while maintaining causal constraints. 
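+ For example, with `denoising_range_num=4` and `slice_point=1`, a single forward pass denoises four chunks
+ in parallel while attending to one preceding chunk of already-clean context, with `kv_range` defining
+ which tokens each chunk may attend to.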
+ """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # `kv_range` is optional when not using MAGI varlen backend. No requirement here. + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + frame_in_range = post_patch_num_frames // denoising_range_num + prev_clean_T = frame_in_range * slice_point + T_total = post_patch_num_frames + prev_clean_T + + hidden_states = hidden_states * self.config.x_rescale_factor + + if self.config.half_channel_vae: + if hidden_states.shape[1] != 16: + raise ValueError( + "When `config.half_channel_vae` is True, the input `hidden_states` must have 16 channels." + ) + hidden_states = torch.cat([hidden_states, hidden_states], dim=1) + + # Patch & position embedding + hidden_states = self.patch_embedding(hidden_states) + freqs_cos, freqs_sin = self.rope(hidden_states, T_total) + # The shape of freqs_* is (total_seq_length, head_dim). Keep only the last tokens corresponding to current window. + keep = post_patch_num_frames * post_patch_height * post_patch_width + freqs_cos = freqs_cos[-keep:] + freqs_sin = freqs_sin[-keep:] + rotary_emb = (freqs_cos, freqs_sin) + + hidden_states = hidden_states.flatten(2).transpose( + 1, 2 + ) # (B, post_patch_num_frames * post_patch_height * post_patch_width, C) + + temb, y_encoder_attention, y_adaln = self.condition_embedder( + timestep.flatten(), + encoder_hidden_states, + num_steps, + distill_interval, + ) + + # Pool AdaLN text conditioning over valid tokens (mask) to get per-batch vector, then broadcast per range + if encoder_attention_mask is not None: + mask_2d = encoder_attention_mask.squeeze(1).squeeze(1).to(torch.bool) # [B, L] + denom = mask_2d.sum(dim=1, keepdim=True).clamp(min=1) + y_adaln_pooled = (y_adaln * mask_2d.unsqueeze(-1)).sum(dim=1) / denom + else: + mask_2d = None + y_adaln_pooled = y_adaln.mean(dim=1) + + temb = temb.reshape(batch_size, denoising_range_num, -1) + y_adaln_pooled.unsqueeze(1).expand( + -1, denoising_range_num, -1 + ) + + seqlen_per_chunk = (post_patch_num_frames * post_patch_height * post_patch_width) // denoising_range_num + temb_map = torch.arange(batch_size * denoising_range_num, device=hidden_states.device) + temb_map = torch.repeat_interleave(temb_map, seqlen_per_chunk) + temb_map = temb_map.reshape(batch_size, -1).transpose(0, 1).contiguous() + + # Build varlen metadata for MAGI backend + clip_token_nums = post_patch_height * post_patch_width * frame_in_range + + self_attention_kwargs = None + if kv_range is not None: + cu_seqlens_q = torch.tensor( + [0] + ([clip_token_nums] * denoising_range_num * batch_size), dtype=torch.int64, device=hidden_states.device + ).cumsum(-1).to(torch.int32) + # q_ranges pairs from cu_seqlens_q + q_ranges = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1) + flat_kv = torch.unique(kv_range, sorted=True) + max_seqlen_k = int((flat_kv[-1] - flat_kv[0]).item()) + self_attention_kwargs = { + "q_ranges": 
q_ranges, + "k_ranges": kv_range, + "max_seqlen_q": clip_token_nums, + "max_seqlen_k": max_seqlen_k, + } + + encoder_attention_kwargs = None + if mask_2d is not None: + y_index = mask_2d.sum(dim=-1).to(torch.int32) + cu_seqlens_q = torch.tensor( + [0] + ([clip_token_nums] * denoising_range_num * batch_size), dtype=torch.int64, device=hidden_states.device + ).cumsum(-1).to(torch.int32) + cu_seqlens_k = torch.cat([y_index.new_zeros(1, dtype=torch.int32), y_index.to(torch.int32)]).cumsum(-1) + q_ranges = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1) + k_ranges = torch.cat([cu_seqlens_k[:-1].unsqueeze(1), cu_seqlens_k[1:].unsqueeze(1)], dim=1) + max_seqlen_kv = int(y_index.max().item()) if y_index.numel() > 0 else 0 + encoder_attention_kwargs = { + "q_ranges": q_ranges, + "k_ranges": k_ranges, + "max_seqlen_q": clip_token_nums, + "max_seqlen_k": max_seqlen_kv, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + } + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + y_encoder_attention, + temb, + rotary_emb, + temb_map, + mask_2d, + self_attention_kwargs, + encoder_attention_kwargs, + ) + else: + for block in self.blocks: + hidden_states = block( + hidden_states, + y_encoder_attention, + temb, + rotary_emb, + temb_map, + mask_2d, + self_attention_kwargs, + encoder_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c438caed571f..a3b8936f8fb4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -28,6 +28,7 @@ "deprecated": [], "latent_diffusion": [], "ledits_pp": [], + "magi1": [], "marigold": [], "pag": [], "stable_diffusion": [], @@ -293,6 +294,7 @@ "MarigoldNormalsPipeline", ] ) + _import_structure["magi1"] = ["Magi1Pipeline", "Magi1ImageToVideoPipeline", "Magi1VideoToVideoPipeline"] _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["omnigen"] = ["OmniGenPipeline"] @@ -689,6 +691,7 @@ from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline + from .magi1 import Magi1ImageToVideoPipeline, Magi1Pipeline, Magi1VideoToVideoPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, diff --git a/src/diffusers/pipelines/magi1/__init__.py b/src/diffusers/pipelines/magi1/__init__.py new file mode 100644 index 000000000000..5fc6735f6357 --- /dev/null +++ b/src/diffusers/pipelines/magi1/__init__.py @@ -0,0 +1,53 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} 
+_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_magi1"] = ["Magi1Pipeline"] + _import_structure["pipeline_magi1_i2v"] = ["Magi1ImageToVideoPipeline"] + _import_structure["pipeline_magi1_v2v"] = ["Magi1VideoToVideoPipeline"] + _import_structure["pipeline_output"] = ["Magi1PipelineOutput"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_magi1 import Magi1Pipeline + from .pipeline_magi1_i2v import Magi1ImageToVideoPipeline + from .pipeline_magi1_v2v import Magi1VideoToVideoPipeline + from .pipeline_output import Magi1PipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1.py b/src/diffusers/pipelines/magi1/pipeline_magi1.py new file mode 100644 index 000000000000..1948c9c7ca39 --- /dev/null +++ b/src/diffusers/pipelines/magi1/pipeline_magi1.py @@ -0,0 +1,1169 @@ +# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MAGI-1 T2V Pipeline with Autoregressive Chunked Generation +# +# ✅ IMPLEMENTED: +# - Autoregressive chunked generation (always enabled, matching original MAGI-1) +# - Window-based scheduling: chunk_width=6, window_size=4 +# - Progressive denoising across overlapping temporal windows +# - Proper CFG with separate forward passes (diffusers style) +# +# ⚠️ CURRENT LIMITATION: +# - No KV caching: attention is recomputed for previous chunks +# - This is less efficient than the original but fully functional +# +# ⏳ FUTURE OPTIMIZATIONS (when diffusers adds generic KV caching): +# 1. **KV Cache Management**: +# - Cache attention keys/values for previously denoised chunks +# - Reuse cached computations instead of recomputing +# - Will significantly speed up generation (2-3x faster expected) +# +# 2. **Special Token Support** (optional enhancement): +# - Duration tokens: indicate how many chunks remain to generate +# - Quality tokens: HQ_TOKEN for high-quality generation +# - Style tokens: THREE_D_MODEL_TOKEN, TWO_D_ANIME_TOKEN +# - Motion tokens: STATIC_FIRST_FRAMES_TOKEN, DYNAMIC_FIRST_FRAMES_TOKEN +# +# 3. 
**Streaming Generation**: +# - Yield clean chunks as they complete (generator pattern) +# - Enable real-time preview during generation +# +# Reference: https://github.com/SandAI/MAGI-1/blob/main/inference/pipeline/video_generate.py + +import html +import re +from typing import Any, Callable, Dict, List, Optional, Union + +import ftfy +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import Magi1LoraLoaderMixin +from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import Magi1PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def generate_chunk_sequences(chunk_num: int, window_size: int, chunk_offset: int = 0): + """ + Generate chunk scheduling sequences for autoregressive video generation. + + Args: + chunk_num: Total number of chunks to generate + window_size: Number of chunks to process in each window + chunk_offset: Number of clean prefix chunks (for I2V/V2V) + + Returns: + ``` + clip_start: Start index of chunks to process + clip_end: End index of chunks to process + t_start: Start index in time dimension + t_end: End index in time dimension + ``` + + Examples: + ``` + chunk_num=8, window_size=4, chunk_offset=0 + Stage 0: Process chunks [0:1], denoise chunk 0 + Stage 1: Process chunks [0:2], denoise chunk 1 + Stage 2: Process chunks [0:3], denoise chunk 2 + Stage 3: Process chunks [0:4], denoise chunk 3 + Stage 4: Process chunks [1:5], denoise chunk 4 + ... + ``` + """ + start_index = chunk_offset + end_index = chunk_num + window_size - 1 + + clip_start = [max(chunk_offset, i - window_size + 1) for i in range(start_index, end_index)] + clip_end = [min(chunk_num, i + 1) for i in range(start_index, end_index)] + + t_start = [max(0, i - chunk_num + 1) for i in range(start_index, end_index)] + t_end = [ + min(window_size, i - chunk_offset + 1) if i - chunk_offset < window_size else window_size + for i in range(start_index, end_index) + ] + + return clip_start, clip_end, t_start, t_end + + +def load_special_tokens(special_tokens_path: Optional[str] = None) -> Optional[Dict[str, torch.Tensor]]: + """ + Load special conditioning tokens from numpy file. + + Args: + special_tokens_path: Path to special_tokens.npz file. If None, returns None (no special tokens). + + Returns: + Dictionary mapping token names to embeddings, or None if path not provided or file doesn't exist. 
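+
+    Example:
+        A small usage sketch; the file path is hypothetical, and the keys match the dictionary built by this
+        function.
+
+        ```python
+        tokens = load_special_tokens("/path/to/special_tokens.npz")
+        if tokens is not None:
+            hq_token = tokens["HQ_TOKEN"]  # prepended to text embeddings when `use_hq_token=True`
+            duration_token = tokens["DURATION_TOKEN_1"]  # one of DURATION_TOKEN_1 ... DURATION_TOKEN_8
+        ```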
+ """ + if special_tokens_path is None: + return None + + try: + import os + + import numpy as np + + if not os.path.exists(special_tokens_path): + logger.warning(f"Special tokens file not found at {special_tokens_path}, skipping special token loading.") + return None + + special_token_data = np.load(special_tokens_path) + caption_token = torch.tensor(special_token_data["caption_token"].astype(np.float16)) + logo_token = torch.tensor(special_token_data["logo_token"].astype(np.float16)) + other_tokens = special_token_data["other_tokens"] + + tokens = { + "CAPTION_TOKEN": caption_token, + "LOGO_TOKEN": logo_token, + "TRANS_TOKEN": torch.tensor(other_tokens[:1].astype(np.float16)), + "HQ_TOKEN": torch.tensor(other_tokens[1:2].astype(np.float16)), + "STATIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[2:3].astype(np.float16)), + "DYNAMIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[3:4].astype(np.float16)), + "BORDERNESS_TOKEN": torch.tensor(other_tokens[4:5].astype(np.float16)), + "THREE_D_MODEL_TOKEN": torch.tensor(other_tokens[15:16].astype(np.float16)), + "TWO_D_ANIME_TOKEN": torch.tensor(other_tokens[16:17].astype(np.float16)), + } + + # Duration tokens (8 total, representing 1-8 chunks remaining) + for i in range(8): + tokens[f"DURATION_TOKEN_{i + 1}"] = torch.tensor(other_tokens[i + 7 : i + 8].astype(np.float16)) + + logger.info(f"Loaded {len(tokens)} special tokens from {special_tokens_path}") + return tokens + except Exception as e: + logger.warning(f"Failed to load special tokens: {e}") + return None + + +def prepend_special_tokens( + prompt_embeds: torch.Tensor, + special_tokens: Optional[Dict[str, torch.Tensor]], + use_hq_token: bool = False, + use_3d_style: bool = False, + use_2d_anime_style: bool = False, + use_static_first_frames: bool = False, + use_dynamic_first_frames: bool = False, + max_sequence_length: int = 800, +) -> torch.Tensor: + """ + Prepend special conditioning tokens to text embeddings. 
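+
+    A minimal usage sketch (not a full pipeline call): `special_tokens` would typically come from
+    `load_special_tokens`, and the flags mirror the pipeline's `use_*` arguments.
+
+    ```python
+    prompt_embeds = prepend_special_tokens(prompt_embeds, special_tokens, use_hq_token=True)
+    ```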
+ + Args: + prompt_embeds: Text embeddings [batch, seq_len, hidden_dim] + special_tokens: Dictionary of special token embeddings + use_hq_token: Whether to add high-quality token + use_3d_style: Whether to add 3D model style token + use_2d_anime_style: Whether to add 2D anime style token + use_static_first_frames: Whether to add static motion token + use_dynamic_first_frames: Whether to add dynamic motion token + max_sequence_length: Maximum sequence length after prepending + + Returns: + Text embeddings with special tokens prepended + """ + if special_tokens is None: + return prompt_embeds + + device = prompt_embeds.device + dtype = prompt_embeds.dtype + batch_size, seq_len, hidden_dim = prompt_embeds.shape + + # Collect tokens to prepend (in order: motion, quality, style) + tokens_to_add = [] + if use_static_first_frames and "STATIC_FIRST_FRAMES_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["STATIC_FIRST_FRAMES_TOKEN"]) + if use_dynamic_first_frames and "DYNAMIC_FIRST_FRAMES_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["DYNAMIC_FIRST_FRAMES_TOKEN"]) + if use_hq_token and "HQ_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["HQ_TOKEN"]) + if use_3d_style and "THREE_D_MODEL_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["THREE_D_MODEL_TOKEN"]) + if use_2d_anime_style and "TWO_D_ANIME_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["TWO_D_ANIME_TOKEN"]) + + # Prepend tokens + for token in tokens_to_add: + token = token.to(device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1) + prompt_embeds = torch.cat([token, prompt_embeds], dim=1) + + # Truncate to max length + if prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, :max_sequence_length, :] + + return prompt_embeds + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Magi1Pipeline, AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler + >>> from diffusers.utils import export_to_video + + >>> model_id = "SandAI/Magi1-T2V-14B-480P-Diffusers" + >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + + >>> # IMPORTANT: MAGI-1 requires shift=3.0 for the scheduler (SD3-style time resolution transform) + >>> scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3.0) + + >>> pipe = Magi1Pipeline.from_pretrained(model_id, vae=vae, scheduler=scheduler, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=720, + ... width=1280, + ... num_frames=81, + ... guidance_scale=5.0, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +class Magi1Pipeline(DiffusionPipeline, Magi1LoraLoaderMixin): + r""" + Pipeline for text-to-video generation using Magi1. + + MAGI-1 is a DiT-based video generation model that supports autoregressive chunked generation for long videos. + + **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the + original MAGI-1 paper, with support for special conditioning tokens for quality, style, and motion control. + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`Magi1Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A flow matching scheduler with Euler discretization, using SD3-style time resolution transform. + vae ([`AutoencoderKLMagi1`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: Magi1Transformer3DModel, + vae: AutoencoderKLMagi1, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Special tokens for conditioning (optional) + self.special_tokens = None + + def load_special_tokens_from_file(self, special_tokens_path: str): + """ + Load special conditioning tokens from a numpy file. + + Args: + special_tokens_path: Path to special_tokens.npz file + """ + self.special_tokens = load_special_tokens(special_tokens_path) + if self.special_tokens is not None: + logger.info("Special tokens loaded successfully. 
You can now use quality, style, and motion control.") + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 800, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # Repeat mask the same way as embeddings and keep size [B*num, L] + mask = mask.repeat(1, num_videos_per_prompt) + mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device) + + return prompt_embeds, mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 800, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + prompt_mask = None + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + negative_mask = None + return prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 800, + use_hq_token: bool = False, + use_3d_style: bool = False, + use_2d_anime_style: bool = False, + use_static_first_frames: bool = False, + use_dynamic_first_frames: bool = False, + enable_distillation: bool = False, + distill_nearly_clean_chunk_threshold: float = 0.3, + ): + r""" + The call function to the pipeline for generation. + + **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the + original MAGI-1 paper. The implementation currently works without KV caching (attention is recomputed for + previous chunks), which is less efficient than the original but still functional. 
KV caching optimization will + be added when diffusers implements generic caching support for transformers. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, negative_prompt_embeds will be generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `"latent"`, `"pt"`, or `"np"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`Magi1PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. 
`callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `800`): + The maximum sequence length for the text encoder. Sequences longer than this will be truncated. MAGI-1 + uses a max length of 800 tokens. + use_hq_token (`bool`, *optional*, defaults to `False`): + Whether to prepend the high-quality control token to the text embeddings. This token conditions the + model to generate higher quality outputs. Requires special tokens to be loaded via + `load_special_tokens_from_file`. + use_3d_style (`bool`, *optional*, defaults to `False`): + Whether to prepend the 3D model style token to the text embeddings. This token conditions the model to + generate outputs with 3D modeling aesthetics. Requires special tokens to be loaded. + use_2d_anime_style (`bool`, *optional*, defaults to `False`): + Whether to prepend the 2D anime style token to the text embeddings. This token conditions the model to + generate outputs with 2D anime aesthetics. Requires special tokens to be loaded. + use_static_first_frames (`bool`, *optional*, defaults to `False`): + Whether to prepend the static first frames token to the text embeddings. This token conditions the + model to start the video with minimal motion in the first few frames. Requires special tokens to be + loaded. + use_dynamic_first_frames (`bool`, *optional*, defaults to `False`): + Whether to prepend the dynamic first frames token to the text embeddings. This token conditions the + model to start the video with significant motion in the first few frames. Requires special tokens to be + loaded. + enable_distillation (`bool`, *optional*, defaults to `False`): + Whether to enable distillation mode. In distillation mode, the model uses modified timestep embeddings + to support distilled (faster) inference. This requires a distilled model checkpoint. + distill_nearly_clean_chunk_threshold (`float`, *optional*, defaults to `0.3`): + Threshold for identifying nearly-clean chunks in distillation mode. Chunks with timestep > threshold + are considered nearly clean and processed differently. Only used when `enable_distillation=True`. + + Examples: + + Returns: + [`~Magi1PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated videos. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." 
+ ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + dtype=self.text_encoder.dtype, + ) + + # 3.5. Prepend special tokens if requested + if self.special_tokens is not None and any( + [use_hq_token, use_3d_style, use_2d_anime_style, use_static_first_frames, use_dynamic_first_frames] + ): + prompt_embeds = prepend_special_tokens( + prompt_embeds=prompt_embeds, + special_tokens=self.special_tokens, + use_hq_token=use_hq_token, + use_3d_style=use_3d_style, + use_2d_anime_style=use_2d_anime_style, + use_static_first_frames=use_static_first_frames, + use_dynamic_first_frames=use_dynamic_first_frames, + max_sequence_length=max_sequence_length, + ) + if negative_prompt_embeds is not None: + negative_prompt_embeds = prepend_special_tokens( + prompt_embeds=negative_prompt_embeds, + special_tokens=self.special_tokens, + use_hq_token=use_hq_token, + use_3d_style=use_3d_style, + use_2d_anime_style=use_2d_anime_style, + use_static_first_frames=use_static_first_frames, + use_dynamic_first_frames=use_dynamic_first_frames, + max_sequence_length=max_sequence_length, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. 
Denoising loop (autoregressive chunked generation) + # MAGI-1 always uses autoregressive generation with chunk_width=6 and window_size=4 + # Note: num_warmup_steps is calculated for compatibility but not used in progress bar logic + # because autoregressive generation has a different iteration structure (stages × steps) + # For FlowMatchEulerDiscreteScheduler (order=1), this doesn't affect the results + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + # Autoregressive chunked generation parameters + chunk_width = 6 # Original MAGI-1 default + window_size = 4 # Original MAGI-1 default + + num_latent_frames = latents.shape[2] + num_chunks = (num_latent_frames + chunk_width - 1) // chunk_width + + # Calculate chunk scheduling: which chunks to process at each stage + clip_start, clip_end, t_start, t_end = generate_chunk_sequences(num_chunks, window_size, chunk_offset=0) + num_stages = len(clip_start) + + # Number of denoising steps per stage + denoise_step_per_stage = len(timesteps) // window_size + + # Track how many times each chunk has been denoised + chunk_denoise_count = {i: 0 for i in range(num_chunks)} + + with self.progress_bar(total=num_stages * denoise_step_per_stage) as progress_bar: + for stage_idx in range(num_stages): + # Determine which chunks to process in this stage + chunk_start_idx = clip_start[stage_idx] + chunk_end_idx = clip_end[stage_idx] + t_start_idx = t_start[stage_idx] + t_end_idx = t_end[stage_idx] + + # Extract chunk range in latent space + latent_start = chunk_start_idx * chunk_width + latent_end = min(chunk_end_idx * chunk_width, num_latent_frames) + + # Number of chunks in current window + num_chunks_in_window = chunk_end_idx - chunk_start_idx + + # Prepare per-chunk conditioning with duration/borderness tokens + # Duration tokens indicate how many chunks remain in the video + # Borderness tokens condition on chunk boundaries + chunk_prompt_embeds_list = [] + chunk_negative_prompt_embeds_list = [] + chunk_prompt_masks_list = [] + chunk_negative_masks_list = [] + + if self.special_tokens is not None: + # Prepare per-chunk embeddings with duration tokens + # Each chunk gets a different duration token based on chunks remaining + for i, chunk_idx in enumerate(range(chunk_start_idx, chunk_end_idx)): + chunks_remaining = num_chunks - chunk_idx - 1 + # Duration token ranges from 1-8 chunks + duration_idx = min(chunks_remaining, 7) + 1 + + # Add duration and borderness tokens for this chunk + token_embeds = prompt_embeds.clone() + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"] + duration_token = duration_token.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1) + token_embeds = torch.cat([duration_token, token_embeds], dim=1) + + if "BORDERNESS_TOKEN" in self.special_tokens: + borderness_token = self.special_tokens["BORDERNESS_TOKEN"] + borderness_token = borderness_token.to( + device=prompt_embeds.device, dtype=prompt_embeds.dtype + ) + borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1) + token_embeds = torch.cat([borderness_token, token_embeds], dim=1) + + # Build per-chunk mask by prepending ones for each added token and truncating + token_mask = prompt_mask + add_count = 0 + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + add_count += 1 + if "BORDERNESS_TOKEN" in self.special_tokens: + 
add_count += 1 + if add_count > 0: + prepend = torch.ones( + token_mask.shape[0], add_count, dtype=token_mask.dtype, device=token_mask.device + ) + token_mask = torch.cat([prepend, token_mask], dim=1) + if token_embeds.shape[1] > max_sequence_length: + token_embeds = token_embeds[:, :max_sequence_length, :] + token_mask = token_mask[:, :max_sequence_length] + + chunk_prompt_embeds_list.append(token_embeds) + chunk_prompt_masks_list.append(token_mask) + + # Same for negative prompts + if self.do_classifier_free_guidance: + neg_token_embeds = negative_prompt_embeds.clone() + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"] + duration_token = duration_token.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1) + neg_token_embeds = torch.cat([duration_token, neg_token_embeds], dim=1) + + if "BORDERNESS_TOKEN" in self.special_tokens: + borderness_token = self.special_tokens["BORDERNESS_TOKEN"] + borderness_token = borderness_token.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1) + neg_token_embeds = torch.cat([borderness_token, neg_token_embeds], dim=1) + + # Build negative per-chunk mask + neg_mask = negative_mask + add_count = 0 + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + add_count += 1 + if "BORDERNESS_TOKEN" in self.special_tokens: + add_count += 1 + if add_count > 0: + prepend = torch.ones( + neg_mask.shape[0], add_count, dtype=neg_mask.dtype, device=neg_mask.device + ) + neg_mask = torch.cat([prepend, neg_mask], dim=1) + if neg_token_embeds.shape[1] > max_sequence_length: + neg_token_embeds = neg_token_embeds[:, :max_sequence_length, :] + neg_mask = neg_mask[:, :max_sequence_length] + + chunk_negative_prompt_embeds_list.append(neg_token_embeds) + chunk_negative_masks_list.append(neg_mask) + + # Denoise this chunk range for denoise_step_per_stage steps + for denoise_idx in range(denoise_step_per_stage): + if self.interrupt: + break + + # Calculate timestep index for each chunk in the current window + # Chunks at different stages get different timesteps based on their denoise progress + # Original MAGI-1: get_timestep() and get_denoise_step_of_each_chunk() + timestep_indices = [] + for offset in range(num_chunks_in_window): + # Map offset within window to time index + t_idx_within_window = t_start_idx + offset + t_idx = t_idx_within_window * denoise_step_per_stage + denoise_idx + timestep_indices.append(t_idx) + + # Reverse order: later chunks in window are noisier (higher timestep index) + timestep_indices.reverse() + + # Clamp indices to valid range + timestep_indices = [min(idx, len(timesteps) - 1) for idx in timestep_indices] + + # Get actual timesteps + current_timesteps = timesteps[timestep_indices] + + # Create per-chunk timestep tensor: [batch_size, num_chunks_in_window] + # Each chunk gets its own timestep based on how many times it's been denoised + timestep_per_chunk = current_timesteps.unsqueeze(0).expand(batch_size, -1) + + # Store first timestep for progress tracking + self._current_timestep = current_timesteps[0] + + # Extract chunk + latent_chunk = latents[:, :, latent_start:latent_end].to(transformer_dtype) + + # Prepare distillation parameters if enabled + num_steps = None + distill_interval = None + distill_nearly_clean_chunk = None + + if 
enable_distillation: + # Distillation mode uses modified timestep embeddings for faster inference + # The interval represents the step size in the distilled schedule + num_steps = num_inference_steps + distill_interval = len(timesteps) / num_inference_steps + + # Determine if chunks are nearly clean (low noise) based on their timesteps + # Check the first chunk's timestep (after reversing, this is the noisiest actively denoising chunk) + # Original: checks t[0, int(fwd_extra_1st_chunk)].item() + nearly_clean_chunk_t = current_timesteps[0].item() / self.scheduler.config.num_train_timesteps + distill_nearly_clean_chunk = nearly_clean_chunk_t < distill_nearly_clean_chunk_threshold + + # Prepare per-chunk embeddings + # The transformer expects embeddings in shape [batch_size * num_chunks_in_window, seq_len, hidden_dim] + # Each chunk gets its own embedding with appropriate duration/borderness tokens + if chunk_prompt_embeds_list: + # Stack per-chunk embeddings: [num_chunks_in_window, batch_size, seq_len, hidden_dim] + chunk_prompt_embeds = torch.stack(chunk_prompt_embeds_list, dim=0) + # Reshape to [batch_size * num_chunks_in_window, seq_len, hidden_dim] + chunk_prompt_embeds = chunk_prompt_embeds.transpose(0, 1).flatten(0, 1) + + if chunk_negative_prompt_embeds_list: + chunk_negative_prompt_embeds = torch.stack(chunk_negative_prompt_embeds_list, dim=0) + chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.transpose(0, 1).flatten(0, 1) + else: + chunk_negative_prompt_embeds = None + else: + # Fallback: repeat shared embeddings for each chunk + chunk_prompt_embeds = prompt_embeds.unsqueeze(1).repeat(1, num_chunks_in_window, 1, 1) + chunk_prompt_embeds = chunk_prompt_embeds.flatten(0, 1) + prompt_mask_rep = prompt_mask.unsqueeze(1).repeat(1, num_chunks_in_window, 1) + prompt_mask_rep = prompt_mask_rep.flatten(0, 1) + + if negative_prompt_embeds is not None: + chunk_negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1).repeat( + 1, num_chunks_in_window, 1, 1 + ) + chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.flatten(0, 1) + negative_mask_rep = negative_mask.unsqueeze(1).repeat(1, num_chunks_in_window, 1) + negative_mask_rep = negative_mask_rep.flatten(0, 1) + else: + chunk_negative_prompt_embeds = None + negative_mask_rep = None + + # Create encoder attention mask(s) for per-chunk embeddings from tokenizer masks + if chunk_prompt_masks_list: + chunk_prompt_masks = torch.stack(chunk_prompt_masks_list, dim=0).transpose(0, 1).flatten(0, 1) + else: + chunk_prompt_masks = prompt_mask_rep + encoder_attention_mask = chunk_prompt_masks.to(dtype=chunk_prompt_embeds.dtype).view( + batch_size * num_chunks_in_window, 1, 1, -1 + ) + if chunk_negative_masks_list: + chunk_negative_masks = torch.stack(chunk_negative_masks_list, dim=0).transpose(0, 1).flatten(0, 1) + encoder_attention_mask_neg = chunk_negative_masks.to(dtype=chunk_prompt_embeds.dtype).view( + batch_size * num_chunks_in_window, 1, 1, -1 + ) + else: + encoder_attention_mask_neg = ( + negative_mask_rep.to(dtype=chunk_prompt_embeds.dtype).view( + batch_size * num_chunks_in_window, 1, 1, -1 + ) + if negative_mask_rep is not None + else encoder_attention_mask + ) + + # Generate KV range for autoregressive attention + # Each chunk can attend to itself and all previous chunks in the sequence + # Shape: [batch_size * num_chunks_in_window, 2] where each row is [start_token_idx, end_token_idx] + chunk_token_nums = ( + (latent_chunk.shape[2] // num_chunks_in_window) # frames per chunk + * (latent_chunk.shape[3] // 
self.transformer.config.patch_size[1]) # height tokens + * (latent_chunk.shape[4] // self.transformer.config.patch_size[2]) # width tokens + ) + kv_range = [] + for b in range(batch_size): + # Calculate proper batch offset based on total number of chunks in video + batch_offset = b * num_chunks + for c in range(num_chunks_in_window): + # This chunk can attend from the start of its batch up to its own end + chunk_global_idx = chunk_start_idx + c + k_start = batch_offset * chunk_token_nums + k_end = (batch_offset + chunk_global_idx + 1) * chunk_token_nums + kv_range.append([k_start, k_end]) + kv_range = torch.tensor(kv_range, dtype=torch.int32, device=device) + + # Predict noise (conditional) + noise_pred = self.transformer( + hidden_states=latent_chunk, + timestep=timestep_per_chunk, + encoder_hidden_states=chunk_prompt_embeds, + encoder_attention_mask=encoder_attention_mask, + attention_kwargs=attention_kwargs, + denoising_range_num=num_chunks_in_window, + range_num=chunk_end_idx, + slice_point=chunk_start_idx, + kv_range=kv_range, + num_steps=num_steps, + distill_interval=distill_interval, + distill_nearly_clean_chunk=distill_nearly_clean_chunk, + return_dict=False, + )[0] + + # Classifier-free guidance + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_chunk, + timestep=timestep_per_chunk, + encoder_hidden_states=chunk_negative_prompt_embeds, + encoder_attention_mask=encoder_attention_mask_neg, + attention_kwargs=attention_kwargs, + denoising_range_num=num_chunks_in_window, + range_num=chunk_end_idx, + slice_point=chunk_start_idx, + kv_range=kv_range, + num_steps=num_steps, + distill_interval=distill_interval, + distill_nearly_clean_chunk=distill_nearly_clean_chunk, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + # Update latent chunk using scheduler step + # FlowMatchEulerDiscreteScheduler implements: x_new = x_old + dt * velocity + # For per-chunk timesteps, we use manual integration since scheduler doesn't support + # different timesteps for different chunks in a single call + + # Calculate dt for each chunk + # Get next timesteps for integration + next_timestep_indices = [min(idx + 1, len(timesteps) - 1) for idx in timestep_indices] + next_timesteps = timesteps[next_timestep_indices] + + # Convert to sigmas (time in [0, 1] range) + current_sigmas = current_timesteps / self.scheduler.config.num_train_timesteps + next_sigmas = next_timesteps / self.scheduler.config.num_train_timesteps + + # Reshape for per-chunk application: [batch_size, num_chunks_in_window, 1, 1, 1] + dt = (next_sigmas - current_sigmas).view(batch_size, num_chunks_in_window, 1, 1, 1) + + # Reshape latents and velocity to separate chunks: [batch_size, channels, chunks, frames_per_chunk, h, w] + B, C, T, H, W = latent_chunk.shape + frames_per_chunk = T // num_chunks_in_window + latent_chunk = latent_chunk.view(B, C, num_chunks_in_window, frames_per_chunk, H, W) + noise_pred = noise_pred.view(B, C, num_chunks_in_window, frames_per_chunk, H, W) + + # Apply Euler integration per chunk: x_new = x_old + dt * velocity + latent_chunk = latent_chunk + dt * noise_pred + + # Reshape back to original shape + latent_chunk = latent_chunk.view(B, C, T, H, W) + + # Write back to full latents + latents[:, :, latent_start:latent_end] = latent_chunk + + # Update chunk denoise counts + for chunk_idx in range(chunk_start_idx, chunk_end_idx): + chunk_denoise_count[chunk_idx] += 1 + + if callback_on_step_end is not None: + 
callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, + stage_idx * denoise_step_per_stage + denoise_idx, + current_timesteps[0], + callback_kwargs, + ) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return Magi1PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_i2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_i2v.py new file mode 100644 index 000000000000..e875d15a4948 --- /dev/null +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_i2v.py @@ -0,0 +1,1530 @@ +# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MAGI-1 I2V Pipeline with Autoregressive Chunked Generation +# +# ✅ IMPLEMENTED: +# - Image-to-Video generation using prefix video conditioning +# - Autoregressive chunked generation (always enabled, matching original MAGI-1) +# - Window-based scheduling: chunk_width=6, window_size=4 +# - Progressive denoising across overlapping temporal windows +# - Proper CFG with separate forward passes (diffusers style) +# - Input image encoding to VAE latent as clean prefix chunk +# +# ⚠️ CURRENT LIMITATION: +# - No KV caching: attention is recomputed for previous chunks +# - This is less efficient than the original but fully functional +# +# ⏳ FUTURE OPTIMIZATIONS (when diffusers adds generic KV caching): +# 1. **KV Cache Management**: +# - Cache attention keys/values for previously denoised chunks +# - Reuse cached computations instead of recomputing +# - Will significantly speed up generation (2-3x faster expected) +# +# 2. **Special Token Support** (optional enhancement): +# - Duration tokens: indicate how many chunks remain to generate +# - Quality tokens: HQ_TOKEN for high-quality generation +# - Style tokens: THREE_D_MODEL_TOKEN, TWO_D_ANIME_TOKEN +# - Motion tokens: STATIC_FIRST_FRAMES_TOKEN, DYNAMIC_FIRST_FRAMES_TOKEN +# +# 3. 
**Streaming Generation**: +# - Yield clean chunks as they complete (generator pattern) +# - Enable real-time preview during generation +# +# Reference: https://github.com/SandAI/MAGI-1/blob/main/inference/pipeline/video_generate.py + +import html +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import Magi1LoraLoaderMixin +from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import Magi1PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def generate_chunk_sequences(chunk_num: int, window_size: int, chunk_offset: int = 0): + """ + Generate chunk scheduling sequences for autoregressive video generation. + + Args: + chunk_num: Total number of chunks to generate + window_size: Number of chunks to process in each window + chunk_offset: Number of clean prefix chunks (for I2V/V2V) + + Returns: + ``` + clip_start: Start index of chunks to process + clip_end: End index of chunks to process + t_start: Start index in time dimension + t_end: End index in time dimension + ``` + + Examples: + ``` + chunk_num=8, window_size=4, chunk_offset=0 + Stage 0: Process chunks [0:1], denoise chunk 0 + Stage 1: Process chunks [0:2], denoise chunk 1 + Stage 2: Process chunks [0:3], denoise chunk 2 + Stage 3: Process chunks [0:4], denoise chunk 3 + Stage 4: Process chunks [1:5], denoise chunk 4 + ... 
+ ``` + """ + start_index = chunk_offset + end_index = chunk_num + window_size - 1 + + clip_start = [max(chunk_offset, i - window_size + 1) for i in range(start_index, end_index)] + clip_end = [min(chunk_num, i + 1) for i in range(start_index, end_index)] + + t_start = [max(0, i - chunk_num + 1) for i in range(start_index, end_index)] + t_end = [ + min(window_size, i - chunk_offset + 1) if i - chunk_offset < window_size else window_size + for i in range(start_index, end_index) + ] + + return clip_start, clip_end, t_start, t_end + + +def load_special_tokens(special_tokens_path: Optional[str] = None) -> Optional[Dict[str, torch.Tensor]]: + """ + Load special conditioning tokens from numpy file. + + Args: + special_tokens_path: Path to special_tokens.npz file. If None, returns None (no special tokens). + + Returns: + Dictionary mapping token names to embeddings, or None if path not provided or file doesn't exist. + """ + if special_tokens_path is None: + return None + + try: + import os + + import numpy as np + + if not os.path.exists(special_tokens_path): + logger.warning(f"Special tokens file not found at {special_tokens_path}, skipping special token loading.") + return None + + special_token_data = np.load(special_tokens_path) + caption_token = torch.tensor(special_token_data["caption_token"].astype(np.float16)) + logo_token = torch.tensor(special_token_data["logo_token"].astype(np.float16)) + other_tokens = special_token_data["other_tokens"] + + tokens = { + "CAPTION_TOKEN": caption_token, + "LOGO_TOKEN": logo_token, + "TRANS_TOKEN": torch.tensor(other_tokens[:1].astype(np.float16)), + "HQ_TOKEN": torch.tensor(other_tokens[1:2].astype(np.float16)), + "STATIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[2:3].astype(np.float16)), + "DYNAMIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[3:4].astype(np.float16)), + "BORDERNESS_TOKEN": torch.tensor(other_tokens[4:5].astype(np.float16)), + "THREE_D_MODEL_TOKEN": torch.tensor(other_tokens[15:16].astype(np.float16)), + "TWO_D_ANIME_TOKEN": torch.tensor(other_tokens[16:17].astype(np.float16)), + } + + # Duration tokens (8 total, representing 1-8 chunks remaining) + for i in range(8): + tokens[f"DURATION_TOKEN_{i + 1}"] = torch.tensor(other_tokens[i + 7 : i + 8].astype(np.float16)) + + logger.info(f"Loaded {len(tokens)} special tokens from {special_tokens_path}") + return tokens + except Exception as e: + logger.warning(f"Failed to load special tokens: {e}") + return None + + +def prepare_i2v_embeddings( + prompt_embeds: torch.Tensor, + negative_prompt_embeds: Optional[torch.Tensor], + num_chunks: int, + clean_chunk_num: int, + max_sequence_length: int = 800, + prompt_mask: Optional[torch.Tensor] = None, + negative_mask: Optional[torch.Tensor] = None, +) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + """ + Prepare per-chunk text embeddings for I2V generation. + + In I2V, clean prefix chunks (from the input image) use null embeddings, while chunks to be denoised use the actual + text embeddings. 
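+
+    Illustrative layout (a sketch, assuming ``num_chunks=4`` and ``clean_chunk_num=1``):
+
+    ```
+    chunk 0 (clean prefix)   -> zero ("null") embeddings
+    chunks 1..3 (to denoise) -> the text embeddings, repeated once per chunk
+    returned embeddings      -> [B, 4, L, D]
+    ```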
+ + Args: + prompt_embeds: Text embeddings [batch_size, seq_len, hidden_dim] + negative_prompt_embeds: Negative text embeddings (optional) + num_chunks: Total number of chunks + clean_chunk_num: Number of clean prefix chunks (typically 1 for I2V single image) + max_sequence_length: Maximum sequence length + + Returns: + - prompt_embeds_per_chunk: [B, num_chunks, L, D] + - negative_prompt_embeds_per_chunk: [B, num_chunks, L, D] or None + - prompt_masks_per_chunk: [B, num_chunks, L] or None + - negative_masks_per_chunk: [B, num_chunks, L] or None + """ + batch_size = prompt_embeds.shape[0] + seq_len = prompt_embeds.shape[1] + hidden_dim = prompt_embeds.shape[2] + device = prompt_embeds.device + dtype = prompt_embeds.dtype + + # Number of chunks that need denoising + denoise_chunk_num = num_chunks - clean_chunk_num + + # Create null embeddings (zeros) for clean chunks + null_embeds = torch.zeros(batch_size, 1, seq_len, hidden_dim, device=device, dtype=dtype) + + # Expand prompt embeddings for denoise chunks + # Shape: [batch_size, denoise_chunk_num, seq_len, hidden_dim] + denoise_embeds = prompt_embeds.unsqueeze(1).repeat(1, denoise_chunk_num, 1, 1) + + # Concatenate: [null_embeds for clean chunks] + [text_embeds for denoise chunks] + # Shape: [batch_size, num_chunks, seq_len, hidden_dim] + if clean_chunk_num > 0: + null_embeds_expanded = null_embeds.repeat(1, clean_chunk_num, 1, 1) + prompt_embeds_per_chunk = torch.cat([null_embeds_expanded, denoise_embeds], dim=1) + else: + prompt_embeds_per_chunk = denoise_embeds + + # Build masks per chunk (zeros for clean chunks, prompt_mask for denoise chunks) + prompt_masks_per_chunk = None + negative_masks_per_chunk = None + + if prompt_mask is not None: + denoise_masks = prompt_mask.unsqueeze(1).repeat(1, denoise_chunk_num, 1) + if clean_chunk_num > 0: + null_masks = torch.zeros(prompt_mask.shape[0], clean_chunk_num, prompt_mask.shape[1], device=device, dtype=prompt_mask.dtype) + prompt_masks_per_chunk = torch.cat([null_masks, denoise_masks], dim=1) + else: + prompt_masks_per_chunk = denoise_masks + + # Same for negative embeddings + if negative_prompt_embeds is not None: + denoise_neg_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, denoise_chunk_num, 1, 1) + if clean_chunk_num > 0: + null_neg_embeds_expanded = null_embeds.repeat(1, clean_chunk_num, 1, 1) + negative_prompt_embeds_per_chunk = torch.cat([null_neg_embeds_expanded, denoise_neg_embeds], dim=1) + else: + negative_prompt_embeds_per_chunk = denoise_neg_embeds + else: + negative_prompt_embeds_per_chunk = None + + if negative_mask is not None: + denoise_neg_masks = negative_mask.unsqueeze(1).repeat(1, denoise_chunk_num, 1) + if clean_chunk_num > 0: + null_neg_masks = torch.zeros(negative_mask.shape[0], clean_chunk_num, negative_mask.shape[1], device=device, dtype=negative_mask.dtype) + negative_masks_per_chunk = torch.cat([null_neg_masks, denoise_neg_masks], dim=1) + else: + negative_masks_per_chunk = denoise_neg_masks + + return ( + prompt_embeds_per_chunk, + negative_prompt_embeds_per_chunk, + prompt_masks_per_chunk, + negative_masks_per_chunk, + ) + + +def prepend_special_tokens( + prompt_embeds: torch.Tensor, + special_tokens: Optional[Dict[str, torch.Tensor]], + use_hq_token: bool = False, + use_3d_style: bool = False, + use_2d_anime_style: bool = False, + use_static_first_frames: bool = False, + use_dynamic_first_frames: bool = False, + max_sequence_length: int = 800, +) -> torch.Tensor: + """ + Prepend special conditioning tokens to text embeddings. 
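+
+    Usage sketch (illustrative only; assumes ``special_tokens`` was already loaded via ``load_special_tokens``):
+
+    ```python
+    embeds = prepend_special_tokens(prompt_embeds, special_tokens, use_hq_token=True)
+    # The HQ token embedding is prepended along the sequence dimension, and the result is
+    # truncated back to ``max_sequence_length`` if it grows too long.
+    ```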
+ + Args: + prompt_embeds: Text embeddings [batch, seq_len, hidden_dim] + special_tokens: Dictionary of special token embeddings + use_hq_token: Whether to add high-quality token + use_3d_style: Whether to add 3D model style token + use_2d_anime_style: Whether to add 2D anime style token + use_static_first_frames: Whether to add static motion token + use_dynamic_first_frames: Whether to add dynamic motion token + max_sequence_length: Maximum sequence length after prepending + + Returns: + Text embeddings with special tokens prepended + """ + if special_tokens is None: + return prompt_embeds + + device = prompt_embeds.device + dtype = prompt_embeds.dtype + batch_size, seq_len, hidden_dim = prompt_embeds.shape + + # Collect tokens to prepend (in order: motion, quality, style) + tokens_to_add = [] + if use_static_first_frames and "STATIC_FIRST_FRAMES_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["STATIC_FIRST_FRAMES_TOKEN"]) + if use_dynamic_first_frames and "DYNAMIC_FIRST_FRAMES_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["DYNAMIC_FIRST_FRAMES_TOKEN"]) + if use_hq_token and "HQ_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["HQ_TOKEN"]) + if use_3d_style and "THREE_D_MODEL_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["THREE_D_MODEL_TOKEN"]) + if use_2d_anime_style and "TWO_D_ANIME_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["TWO_D_ANIME_TOKEN"]) + + # Prepend tokens + for token in tokens_to_add: + token = token.to(device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1) + prompt_embeds = torch.cat([token, prompt_embeds], dim=1) + + # Truncate to max length + if prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, :max_sequence_length, :] + + return prompt_embeds + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Magi1ImageToVideoPipeline, AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler + >>> from diffusers.utils import export_to_video, load_image + + >>> model_id = "SandAI/Magi1-I2V-14B-480P-Diffusers" + >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + + >>> # IMPORTANT: MAGI-1 requires shift=3.0 for the scheduler (SD3-style time resolution transform) + >>> scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3.0) + + >>> pipe = Magi1ImageToVideoPipeline.from_pretrained(model_id, vae=vae, scheduler=scheduler, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> prompt = ( + ... "An astronaut walking on the moon's surface, with the Earth visible in the background. " + ... "The astronaut moves slowly in a low-gravity environment." + ... ) + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality" + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=720, + ... width=1280, + ... num_frames=81, + ... guidance_scale=5.0, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +class Magi1ImageToVideoPipeline(DiffusionPipeline, Magi1LoraLoaderMixin): + r""" + Pipeline for image-to-video generation using Magi1. 
+ + MAGI-1 is a DiT-based video generation model that supports autoregressive chunked generation for long videos. This + I2V pipeline takes an input image and generates a video animation starting from that image. + + **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the + original MAGI-1 paper. The input image is encoded to a latent representation and used as a clean prefix chunk to + condition the generation. Text prompts provide additional semantic guidance for the animation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`Magi1Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A flow matching scheduler with Euler discretization, using SD3-style time resolution transform. + vae ([`AutoencoderKLMagi1`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: Magi1Transformer3DModel, + vae: AutoencoderKLMagi1, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Special tokens for conditioning (optional) + self.special_tokens = None + + def load_special_tokens_from_file(self, special_tokens_path: str): + """ + Load special conditioning tokens from a numpy file. + + Args: + special_tokens_path: Path to special_tokens.npz file + """ + self.special_tokens = load_special_tokens(special_tokens_path) + if self.special_tokens is not None: + logger.info("Special tokens loaded successfully. 
You can now use quality, style, and motion control.") + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 800, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # Repeat mask similarly and keep [B*num, L] + mask = mask.repeat(1, num_videos_per_prompt) + mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device) + + return prompt_embeds, mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 800, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + prompt_mask = None + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + negative_mask = None + + return prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if image is None: + raise ValueError( + "Provide `image` for image-to-video generation. Cannot leave `image` undefined for I2V pipeline." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + image: Optional[PipelineImageInput], + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepare latents for I2V generation, including encoding the input image as prefix_video. + + Args: + image: Input image for I2V generation + batch_size: Batch size + num_channels_latents: Number of latent channels + height: Video height + width: Video width + num_frames: Total number of frames to generate + dtype: Data type + device: Device + generator: Random generator + latents: Pre-generated latents (optional) + + Returns: + Tuple of (latents, prefix_video) where: + - latents: Random noise tensor for generation [batch, channels, num_latent_frames, H, W] + - prefix_video: Encoded image as clean latent [batch, channels, 1, H, W] (or None if no image) + """ + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + # Prepare random latents for generation + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # Encode input image to latent as prefix_video + prefix_video = None + if image is not None: + # Preprocess image to target size + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + + # Add temporal dimension: [batch, channels, height, width] -> [batch, channels, 1, height, width] + if image.ndim == 4: + image = image.unsqueeze(2) + + # Encode to latent space using VAE + # VAE expects [batch, channels, frames, height, width] + if isinstance(generator, list): + prefix_video = [ + retrieve_latents(self.vae.encode(image), sample_mode="sample", generator=g) for g in generator + ] + prefix_video = torch.cat(prefix_video) + else: + prefix_video = retrieve_latents(self.vae.encode(image), sample_mode="sample", generator=generator) + prefix_video = prefix_video.repeat(batch_size, 1, 1, 1, 1) + + # Normalize latent using VAE's latent statistics + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(prefix_video.device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + prefix_video.device, dtype + ) + prefix_video = prefix_video.to(dtype) + prefix_video = (prefix_video - latents_mean) * latents_std + + return latents, prefix_video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 800, + use_hq_token: bool = False, + use_3d_style: bool = False, + use_2d_anime_style: bool = False, + use_static_first_frames: bool = False, + use_dynamic_first_frames: bool = False, + enable_distillation: bool = False, + distill_nearly_clean_chunk_threshold: float = 0.3, + ): + r""" + The call function to the pipeline for image-to-video generation. + + **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the + original MAGI-1 paper. The input image is encoded to a VAE latent and used as a clean prefix chunk to condition + the video generation. 
The implementation currently works without KV caching (attention is recomputed for + previous chunks), which is less efficient than the original but still functional. KV caching optimization will + be added when diffusers implements generic caching support for transformers. + + Args: + image (`PipelineImageInput`): + The input image to condition the video generation on. Must be an image, a list of images, or a + `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, negative_prompt_embeds will be generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `"latent"`, `"pt"`, or `"np"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`Magi1PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `800`): + The maximum sequence length for the text encoder. Sequences longer than this will be truncated. MAGI-1 + uses a max length of 800 tokens. + use_hq_token (`bool`, *optional*, defaults to `False`): + Whether to prepend the high-quality control token to the text embeddings. This token conditions the + model to generate higher quality outputs. Requires special tokens to be loaded via + `load_special_tokens_from_file`. + use_3d_style (`bool`, *optional*, defaults to `False`): + Whether to prepend the 3D model style token to the text embeddings. This token conditions the model to + generate outputs with 3D modeling aesthetics. Requires special tokens to be loaded. + use_2d_anime_style (`bool`, *optional*, defaults to `False`): + Whether to prepend the 2D anime style token to the text embeddings. This token conditions the model to + generate outputs with 2D anime aesthetics. Requires special tokens to be loaded. + use_static_first_frames (`bool`, *optional*, defaults to `False`): + Whether to prepend the static first frames token to the text embeddings. This token conditions the + model to start the video with minimal motion in the first few frames. Requires special tokens to be + loaded. + use_dynamic_first_frames (`bool`, *optional*, defaults to `False`): + Whether to prepend the dynamic first frames token to the text embeddings. This token conditions the + model to start the video with significant motion in the first few frames. Requires special tokens to be + loaded. + enable_distillation (`bool`, *optional*, defaults to `False`): + Whether to enable distillation mode. In distillation mode, the model uses modified timestep embeddings + to support distilled (faster) inference. This requires a distilled model checkpoint. + distill_nearly_clean_chunk_threshold (`float`, *optional*, defaults to `0.3`): + Threshold for identifying nearly-clean chunks in distillation mode. Chunks with timestep > threshold + are considered nearly clean and processed differently. Only used when `enable_distillation=True`. + + Examples: + + Returns: + [`~Magi1PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated videos. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + dtype=self.text_encoder.dtype, + ) + + # 3.5. Prepend special tokens if requested + if self.special_tokens is not None and any( + [use_hq_token, use_3d_style, use_2d_anime_style, use_static_first_frames, use_dynamic_first_frames] + ): + prompt_embeds = prepend_special_tokens( + prompt_embeds=prompt_embeds, + special_tokens=self.special_tokens, + use_hq_token=use_hq_token, + use_3d_style=use_3d_style, + use_2d_anime_style=use_2d_anime_style, + use_static_first_frames=use_static_first_frames, + use_dynamic_first_frames=use_dynamic_first_frames, + max_sequence_length=max_sequence_length, + ) + if negative_prompt_embeds is not None: + negative_prompt_embeds = prepend_special_tokens( + prompt_embeds=negative_prompt_embeds, + special_tokens=self.special_tokens, + use_hq_token=use_hq_token, + use_3d_style=use_3d_style, + use_2d_anime_style=use_2d_anime_style, + use_static_first_frames=use_static_first_frames, + use_dynamic_first_frames=use_dynamic_first_frames, + max_sequence_length=max_sequence_length, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents, prefix_video = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. 
Denoising loop (autoregressive chunked generation with I2V prefix conditioning) + # MAGI-1 I2V uses autoregressive generation with chunk_width=6 and window_size=4 + # The input image is encoded as a clean prefix chunk and used to condition the generation + # Note: num_warmup_steps is calculated for compatibility but not used in progress bar logic + # because autoregressive generation has a different iteration structure (stages × steps) + # For FlowMatchEulerDiscreteScheduler (order=1), this doesn't affect the results + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + # Autoregressive chunked generation parameters + chunk_width = 6 # Original MAGI-1 default + window_size = 4 # Original MAGI-1 default + + num_latent_frames = latents.shape[2] + num_chunks = (num_latent_frames + chunk_width - 1) // chunk_width + + # Calculate chunk_offset from prefix_video (for I2V, this is the clean image chunk) + chunk_offset = 0 + if prefix_video is not None: + # prefix_video has shape [batch, channels, 1, height, width] for I2V + # Calculate how many chunks are covered by the prefix + prefix_latent_frames = prefix_video.shape[2] + chunk_offset = prefix_latent_frames // chunk_width + + # Pad prefix_video into latents at the beginning + # The prefix frames are already clean and don't need denoising + if prefix_latent_frames > 0: + prefix_video = prefix_video.to(latents.dtype) + latents[:, :, :prefix_latent_frames] = prefix_video + + # Calculate chunk scheduling: which chunks to process at each stage + # chunk_offset skips the clean prefix chunks + clip_start, clip_end, t_start, t_end = generate_chunk_sequences(num_chunks, window_size, chunk_offset) + num_stages = len(clip_start) + + # Prepare per-chunk text embeddings for I2V + # Clean chunks (from input image) use null embeddings, denoise chunks use text embeddings + ( + prompt_embeds_per_chunk, + negative_prompt_embeds_per_chunk, + prompt_masks_per_chunk, + negative_masks_per_chunk, + ) = prepare_i2v_embeddings( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + num_chunks=num_chunks, + clean_chunk_num=chunk_offset, + max_sequence_length=max_sequence_length, + prompt_mask=prompt_mask, + negative_mask=negative_mask, + ) + + # Number of denoising steps per stage + denoise_step_per_stage = len(timesteps) // window_size + + # Track how many times each chunk has been denoised + chunk_denoise_count = {i: 0 for i in range(num_chunks)} + + with self.progress_bar(total=num_stages * denoise_step_per_stage) as progress_bar: + for stage_idx in range(num_stages): + # Determine which chunks to process in this stage + chunk_start_idx = clip_start[stage_idx] + chunk_end_idx = clip_end[stage_idx] + t_start_idx = t_start[stage_idx] + t_end_idx = t_end[stage_idx] + + # Extract chunk range in latent space + latent_start = chunk_start_idx * chunk_width + latent_end = min(chunk_end_idx * chunk_width, num_latent_frames) + + # Number of chunks in current window + num_chunks_in_window = chunk_end_idx - chunk_start_idx + + # Prepare per-chunk conditioning with duration/borderness tokens + # Duration tokens indicate how many chunks remain in the video + # Borderness tokens condition on chunk boundaries + chunk_prompt_embeds_list = [] + chunk_negative_prompt_embeds_list = [] + chunk_prompt_masks_list = [] + chunk_negative_masks_list = [] + + if self.special_tokens is not None: + # Prepare per-chunk embeddings with duration tokens + # Each chunk gets a different duration token based 
on chunks remaining + for i, chunk_idx in enumerate(range(chunk_start_idx, chunk_end_idx)): + chunks_remaining = num_chunks - chunk_idx - 1 + # Duration token ranges from 1-8 chunks + duration_idx = min(chunks_remaining, 7) + 1 + + # Get base embeddings for this chunk (clean chunks have null embeds, denoise chunks have text embeds) + token_embeds = prompt_embeds_per_chunk[:, chunk_idx].clone() + + # Add duration and borderness tokens for this chunk + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"] + duration_token = duration_token.to(device=token_embeds.device, dtype=token_embeds.dtype) + duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1) + token_embeds = torch.cat([duration_token, token_embeds], dim=1) + + if "BORDERNESS_TOKEN" in self.special_tokens: + borderness_token = self.special_tokens["BORDERNESS_TOKEN"] + borderness_token = borderness_token.to( + device=token_embeds.device, dtype=token_embeds.dtype + ) + borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1) + token_embeds = torch.cat([borderness_token, token_embeds], dim=1) + + # Build mask for this chunk from base per-chunk mask + token_mask = ( + prompt_masks_per_chunk[:, chunk_idx] + if prompt_masks_per_chunk is not None + else torch.ones(batch_size, token_embeds.shape[1], device=token_embeds.device, dtype=torch.int64) + ) + add_count = 0 + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + add_count += 1 + if "BORDERNESS_TOKEN" in self.special_tokens: + add_count += 1 + if add_count > 0: + prepend = torch.ones(batch_size, add_count, dtype=token_mask.dtype, device=token_mask.device) + token_mask = torch.cat([prepend, token_mask], dim=1) + # Truncate to max length + if token_embeds.shape[1] > max_sequence_length: + token_embeds = token_embeds[:, :max_sequence_length, :] + token_mask = token_mask[:, :max_sequence_length] + + chunk_prompt_embeds_list.append(token_embeds) + chunk_prompt_masks_list.append(token_mask) + + # Same for negative prompts + if self.do_classifier_free_guidance and negative_prompt_embeds_per_chunk is not None: + neg_token_embeds = negative_prompt_embeds_per_chunk[:, chunk_idx].clone() + + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"] + duration_token = duration_token.to( + device=neg_token_embeds.device, dtype=neg_token_embeds.dtype + ) + duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1) + neg_token_embeds = torch.cat([duration_token, neg_token_embeds], dim=1) + + if "BORDERNESS_TOKEN" in self.special_tokens: + borderness_token = self.special_tokens["BORDERNESS_TOKEN"] + borderness_token = borderness_token.to( + device=neg_token_embeds.device, dtype=neg_token_embeds.dtype + ) + borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1) + neg_token_embeds = torch.cat([borderness_token, neg_token_embeds], dim=1) + + # Build negative mask for this chunk + neg_mask = ( + negative_masks_per_chunk[:, chunk_idx] + if negative_masks_per_chunk is not None + else torch.ones(batch_size, neg_token_embeds.shape[1], device=neg_token_embeds.device, dtype=torch.int64) + ) + add_count = 0 + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + add_count += 1 + if "BORDERNESS_TOKEN" in self.special_tokens: + add_count += 1 + if add_count > 0: + prepend = torch.ones(batch_size, add_count, dtype=neg_mask.dtype, device=neg_mask.device) + neg_mask = 
torch.cat([prepend, neg_mask], dim=1) + if neg_token_embeds.shape[1] > max_sequence_length: + neg_token_embeds = neg_token_embeds[:, :max_sequence_length, :] + neg_mask = neg_mask[:, :max_sequence_length] + + chunk_negative_prompt_embeds_list.append(neg_token_embeds) + chunk_negative_masks_list.append(neg_mask) + + # Denoise this chunk range for denoise_step_per_stage steps + for denoise_idx in range(denoise_step_per_stage): + if self.interrupt: + break + + # Calculate timestep index for each chunk in the current window + # Chunks at different stages get different timesteps based on their denoise progress + timestep_indices = [] + for offset in range(num_chunks_in_window): + # Map offset within window to time index + t_idx_within_window = t_start_idx + offset + if t_idx_within_window < t_end_idx: + # This chunk is actively being denoised in this window + t_idx = t_idx_within_window * denoise_step_per_stage + denoise_idx + else: + # This chunk is beyond the active window, use max timestep (it's already cleaner) + t_idx = min((window_size - 1) * denoise_step_per_stage + denoise_idx, len(timesteps) - 1) + timestep_indices.append(t_idx) + + # Reverse order: chunks further from start are noisier + timestep_indices.reverse() + + # Get actual timesteps (reversed order: high noise to low noise) + current_timesteps = timesteps[timestep_indices] + + # Create per-chunk timestep tensor: [batch_size, num_chunks_in_window] + # Each chunk gets its own timestep based on how many times it's been denoised + timestep_per_chunk = current_timesteps.unsqueeze(0).expand(batch_size, -1) + + # Store first timestep for progress tracking + self._current_timestep = current_timesteps[0] + + # Extract chunk + latent_chunk = latents[:, :, latent_start:latent_end].to(transformer_dtype) + + # Prepare distillation parameters if enabled + num_steps = None + distill_interval = None + distill_nearly_clean_chunk = None + + if enable_distillation: + # distill_interval represents the time interval between denoising steps + distill_interval = len(timesteps) / num_inference_steps + + # Determine if chunks are nearly clean (low noise) based on their timesteps + # Check the first active chunk's timestep (after reversing, this is the noisiest chunk being actively denoised) + # Normalize timestep to [0, 1] range where 0=clean, 1=noise + nearly_clean_chunk_t = current_timesteps[0].item() / self.scheduler.config.num_train_timesteps + distill_nearly_clean_chunk = nearly_clean_chunk_t < distill_nearly_clean_chunk_threshold + + num_steps = num_inference_steps + + # Prepare per-chunk embeddings + # The transformer expects embeddings in shape [batch_size * num_chunks_in_window, seq_len, hidden_dim] + # Each chunk gets its own embedding with appropriate duration/borderness tokens + if chunk_prompt_embeds_list: + # Stack per-chunk embeddings: [num_chunks_in_window, batch_size, seq_len, hidden_dim] + chunk_prompt_embeds = torch.stack(chunk_prompt_embeds_list, dim=0) + # Reshape to [batch_size * num_chunks_in_window, seq_len, hidden_dim] + chunk_prompt_embeds = chunk_prompt_embeds.transpose(0, 1).flatten(0, 1) + + if chunk_negative_prompt_embeds_list: + chunk_negative_prompt_embeds = torch.stack(chunk_negative_prompt_embeds_list, dim=0) + chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.transpose(0, 1).flatten(0, 1) + else: + chunk_negative_prompt_embeds = None + else: + # Fallback: use per-chunk embeddings without special tokens + # Extract embeddings for the current chunk range + chunk_prompt_embeds = prompt_embeds_per_chunk[:, 
chunk_start_idx:chunk_end_idx] + chunk_prompt_embeds = chunk_prompt_embeds.flatten(0, 1) + chunk_prompt_masks = ( + prompt_masks_per_chunk[:, chunk_start_idx:chunk_end_idx].flatten(0, 1) + if prompt_masks_per_chunk is not None + else None + ) + + if negative_prompt_embeds_per_chunk is not None: + chunk_negative_prompt_embeds = negative_prompt_embeds_per_chunk[ + :, chunk_start_idx:chunk_end_idx + ] + chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.flatten(0, 1) + chunk_negative_masks = ( + negative_masks_per_chunk[:, chunk_start_idx:chunk_end_idx].flatten(0, 1) + if negative_masks_per_chunk is not None + else None + ) + else: + chunk_negative_prompt_embeds = None + chunk_negative_masks = None + + # Create encoder attention mask(s) from tokenizer masks + if chunk_prompt_embeds_list: + prompt_masks_stacked = torch.stack(chunk_prompt_masks_list, dim=0).transpose(0, 1).flatten(0, 1) + else: + prompt_masks_stacked = chunk_prompt_masks + if prompt_masks_stacked is None: + prompt_masks_stacked = torch.ones( + batch_size * num_chunks_in_window, + chunk_prompt_embeds.shape[1], + dtype=torch.int64, + device=chunk_prompt_embeds.device, + ) + encoder_attention_mask = prompt_masks_stacked.to(dtype=chunk_prompt_embeds.dtype).view( + batch_size * num_chunks_in_window, 1, 1, -1 + ) + + if self.do_classifier_free_guidance: + if chunk_negative_prompt_embeds_list: + negative_masks_stacked = ( + torch.stack(chunk_negative_masks_list, dim=0).transpose(0, 1).flatten(0, 1) + ) + else: + negative_masks_stacked = chunk_negative_masks + if negative_masks_stacked is None and chunk_negative_prompt_embeds is not None: + negative_masks_stacked = torch.ones( + batch_size * num_chunks_in_window, + chunk_negative_prompt_embeds.shape[1], + dtype=torch.int64, + device=chunk_negative_prompt_embeds.device, + ) + encoder_attention_mask_neg = ( + negative_masks_stacked.to(dtype=chunk_prompt_embeds.dtype).view( + batch_size * num_chunks_in_window, 1, 1, -1 + ) + if negative_masks_stacked is not None + else encoder_attention_mask + ) + + # Pad prefix video into latent_chunk if applicable (I2V) + # This ensures clean prefix frames are maintained during denoising + if prefix_video is not None: + prefix_length = prefix_video.shape[2] + prefix_video_start = chunk_start_idx * chunk_width + + if prefix_length > prefix_video_start: + # Calculate how many frames to pad + padding_length = min(prefix_length - prefix_video_start, latent_chunk.shape[2]) + prefix_video_end = prefix_video_start + padding_length + + # Pad clean prefix frames into latent_chunk + latent_chunk = latent_chunk.clone() + latent_chunk[:, :, :padding_length] = prefix_video[ + :, :, prefix_video_start:prefix_video_end + ] + + # Set timesteps for clean prefix chunks to maximum (indicates "already clean") + # This matches original MAGI-1's try_pad_prefix_video logic + num_clean_chunks_in_window = padding_length // chunk_width + if num_clean_chunks_in_window > 0: + # Get max timestep from scheduler + max_timestep = timesteps[0] + timestep_per_chunk[:, :num_clean_chunks_in_window] = max_timestep + + # Generate KV range for autoregressive attention + # Each chunk can attend to itself and all previous chunks in the sequence + # Shape: [batch_size * num_chunks_in_window, 2] where each row is [start_token_idx, end_token_idx] + chunk_token_nums = ( + (latent_chunk.shape[2] // num_chunks_in_window) # frames per chunk + * (latent_chunk.shape[3] // self.transformer.config.patch_size[1]) # height tokens + * (latent_chunk.shape[4] // self.transformer.config.patch_size[2]) 
# width tokens + ) + kv_range = [] + for b in range(batch_size): + # batch_offset should be based on total chunks in the video, not chunk_end_idx + batch_offset = b * num_chunks + for c in range(num_chunks_in_window): + # This chunk can attend from the start of the video up to its own end + chunk_global_idx = chunk_start_idx + c + k_start = batch_offset * chunk_token_nums + k_end = (batch_offset + chunk_global_idx + 1) * chunk_token_nums + kv_range.append([k_start, k_end]) + kv_range = torch.tensor(kv_range, dtype=torch.int32, device=device) + + # Predict noise (conditional) + # Note: MAGI-1 uses velocity field (flow matching), but following diffusers convention + # we use noise_pred naming for consistency across all pipelines + noise_pred = self.transformer( + hidden_states=latent_chunk, + timestep=timestep_per_chunk, + encoder_hidden_states=chunk_prompt_embeds, + encoder_attention_mask=encoder_attention_mask, + attention_kwargs=attention_kwargs, + denoising_range_num=num_chunks_in_window, + range_num=chunk_end_idx, + slice_point=chunk_start_idx, + kv_range=kv_range, + num_steps=num_steps, + distill_interval=distill_interval, + distill_nearly_clean_chunk=distill_nearly_clean_chunk, + return_dict=False, + )[0] + + # Classifier-free guidance: separate forward pass for unconditional + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_chunk, + timestep=timestep_per_chunk, + encoder_hidden_states=chunk_negative_prompt_embeds, + encoder_attention_mask=encoder_attention_mask_neg, + attention_kwargs=attention_kwargs, + denoising_range_num=num_chunks_in_window, + range_num=chunk_end_idx, + slice_point=chunk_start_idx, + kv_range=kv_range, + num_steps=num_steps, + distill_interval=distill_interval, + distill_nearly_clean_chunk=distill_nearly_clean_chunk, + return_dict=False, + )[0] + # Apply classifier-free guidance + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + # CRITICAL: Apply per-chunk Euler integration with different delta_t for each chunk + # This matches the original MAGI-1's integrate() function + # Each chunk in the window is at a different noise level and needs its own time step + + # Calculate per-chunk timesteps for integration + # Get current timesteps (t_before) + current_timesteps = timesteps[timestep_indices] + + # Get next timesteps (t_after) - one step forward for each chunk + next_timestep_indices = [min(idx + 1, len(timesteps) - 1) for idx in timestep_indices] + next_timesteps = timesteps[next_timestep_indices] + + # Convert timesteps to sigmas (matching FlowMatchEulerDiscreteScheduler) + current_sigmas = current_timesteps / self.scheduler.config.num_train_timesteps + next_sigmas = next_timesteps / self.scheduler.config.num_train_timesteps + + # Calculate delta_t for each chunk: [num_chunks_in_window] + delta_t = next_sigmas - current_sigmas + + # Reshape latent_chunk and noise_pred to separate chunks + # From: [batch, channels, frames, height, width] + # To: [batch, channels, num_chunks, chunk_width, height, width] + batch_size_actual, num_channels, total_frames, height_latent, width_latent = latent_chunk.shape + + # Ensure total_frames is divisible by chunk_width for reshaping + # (it should be, but let's handle edge cases) + num_complete_chunks = total_frames // chunk_width + remainder_frames = total_frames % chunk_width + + if remainder_frames == 0: + # Perfect division: reshape and apply per-chunk delta_t + latent_chunk = latent_chunk.reshape( + batch_size_actual, + num_channels, + 
num_complete_chunks, + chunk_width, + height_latent, + width_latent, + ) + noise_pred = noise_pred.reshape( + batch_size_actual, + num_channels, + num_complete_chunks, + chunk_width, + height_latent, + width_latent, + ) + + # Apply Euler integration: x_chunk = x_chunk + velocity * delta_t + # delta_t shape: [num_chunks] -> broadcast to [1, 1, num_chunks, 1, 1, 1] + delta_t_broadcast = delta_t.reshape(1, 1, -1, 1, 1, 1).to( + latent_chunk.device, latent_chunk.dtype + ) + latent_chunk = latent_chunk + noise_pred * delta_t_broadcast + + # Reshape back to original dimensions + latent_chunk = latent_chunk.reshape( + batch_size_actual, num_channels, total_frames, height_latent, width_latent + ) + else: + # Handle remainder frames separately (edge case for last incomplete chunk) + complete_frames = num_complete_chunks * chunk_width + + # Process complete chunks + latent_chunk_complete = latent_chunk[:, :, :complete_frames] + noise_pred_complete = noise_pred[:, :, :complete_frames] + + latent_chunk_complete = latent_chunk_complete.reshape( + batch_size_actual, + num_channels, + num_complete_chunks, + chunk_width, + height_latent, + width_latent, + ) + noise_pred_complete = noise_pred_complete.reshape( + batch_size_actual, + num_channels, + num_complete_chunks, + chunk_width, + height_latent, + width_latent, + ) + + # Apply per-chunk delta_t to complete chunks + delta_t_broadcast = ( + delta_t[:num_complete_chunks] + .reshape(1, 1, -1, 1, 1, 1) + .to(latent_chunk.device, latent_chunk.dtype) + ) + latent_chunk_complete = latent_chunk_complete + noise_pred_complete * delta_t_broadcast + latent_chunk_complete = latent_chunk_complete.reshape( + batch_size_actual, num_channels, complete_frames, height_latent, width_latent + ) + + # Process remainder frames with last delta_t + if remainder_frames > 0: + latent_chunk_remainder = latent_chunk[:, :, complete_frames:] + noise_pred_remainder = noise_pred[:, :, complete_frames:] + delta_t_remainder = delta_t[-1].to(latent_chunk.device, latent_chunk.dtype) + latent_chunk_remainder = latent_chunk_remainder + noise_pred_remainder * delta_t_remainder + + # Concatenate complete and remainder + latent_chunk = torch.cat([latent_chunk_complete, latent_chunk_remainder], dim=2) + else: + latent_chunk = latent_chunk_complete + + # Write back to full latents + latents[:, :, latent_start:latent_end] = latent_chunk + + # Update chunk denoise counts + for chunk_idx in range(chunk_start_idx, chunk_end_idx): + chunk_denoise_count[chunk_idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + # Use first timestep for callback (most representative) + callback_timestep = current_timesteps[0] + callback_outputs = callback_on_step_end( + self, stage_idx * denoise_step_per_stage + denoise_idx, callback_timestep, callback_kwargs + ) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + 
latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return Magi1PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_v2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_v2v.py new file mode 100644 index 000000000000..99d647e74ea3 --- /dev/null +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_v2v.py @@ -0,0 +1,1527 @@ +# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MAGI-1 V2V Pipeline with Autoregressive Chunked Generation +# +# ✅ IMPLEMENTED: +# - Video-to-Video generation using prefix video conditioning +# - Autoregressive chunked generation (always enabled, matching original MAGI-1) +# - Window-based scheduling: chunk_width=6, window_size=4 +# - Progressive denoising across overlapping temporal windows +# - Proper CFG with separate forward passes (diffusers style) +# - Input video frames encoding to VAE latent as clean prefix chunks +# +# ⚠️ CURRENT LIMITATION: +# - No KV caching: attention is recomputed for previous chunks +# - This is less efficient than the original but fully functional +# +# ⏳ FUTURE OPTIMIZATIONS (when diffusers adds generic KV caching): +# 1. **KV Cache Management**: +# - Cache attention keys/values for previously denoised chunks +# - Reuse cached computations instead of recomputing +# - Will significantly speed up generation (2-3x faster expected) +# +# 2. **Special Token Support** (optional enhancement): +# - Duration tokens: indicate how many chunks remain to generate +# - Quality tokens: HQ_TOKEN for high-quality generation +# - Style tokens: THREE_D_MODEL_TOKEN, TWO_D_ANIME_TOKEN +# - Motion tokens: STATIC_FIRST_FRAMES_TOKEN, DYNAMIC_FIRST_FRAMES_TOKEN +# +# 3. 
**Streaming Generation**: +# - Yield clean chunks as they complete (generator pattern) +# - Enable real-time preview during generation +# +# Reference: https://github.com/SandAI/MAGI-1/blob/main/inference/pipeline/video_generate.py + +import html +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import Magi1LoraLoaderMixin +from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import Magi1PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def generate_chunk_sequences(chunk_num: int, window_size: int, chunk_offset: int = 0): + """ + Generate chunk scheduling sequences for autoregressive video generation. + + Args: + chunk_num: Total number of chunks to generate + window_size: Number of chunks to process in each window + chunk_offset: Number of clean prefix chunks (for I2V/V2V) + + Returns: + ``` + clip_start: Start index of chunks to process + clip_end: End index of chunks to process + t_start: Start index in time dimension + t_end: End index in time dimension + ``` + + Examples: + ``` + chunk_num=8, window_size=4, chunk_offset=0 + Stage 0: Process chunks [0:1], denoise chunk 0 + Stage 1: Process chunks [0:2], denoise chunk 1 + Stage 2: Process chunks [0:3], denoise chunk 2 + Stage 3: Process chunks [0:4], denoise chunk 3 + Stage 4: Process chunks [1:5], denoise chunk 4 + ... 
+ ``` + """ + start_index = chunk_offset + end_index = chunk_num + window_size - 1 + + clip_start = [max(chunk_offset, i - window_size + 1) for i in range(start_index, end_index)] + clip_end = [min(chunk_num, i + 1) for i in range(start_index, end_index)] + + t_start = [max(0, i - chunk_num + 1) for i in range(start_index, end_index)] + t_end = [ + min(window_size, i - chunk_offset + 1) if i - chunk_offset < window_size else window_size + for i in range(start_index, end_index) + ] + + return clip_start, clip_end, t_start, t_end + + +def load_special_tokens(special_tokens_path: Optional[str] = None) -> Optional[Dict[str, torch.Tensor]]: + """ + Load special conditioning tokens from numpy file. + + Args: + special_tokens_path: Path to special_tokens.npz file. If None, returns None (no special tokens). + + Returns: + Dictionary mapping token names to embeddings, or None if path not provided or file doesn't exist. + """ + if special_tokens_path is None: + return None + + try: + import os + + import numpy as np + + if not os.path.exists(special_tokens_path): + logger.warning(f"Special tokens file not found at {special_tokens_path}, skipping special token loading.") + return None + + special_token_data = np.load(special_tokens_path) + caption_token = torch.tensor(special_token_data["caption_token"].astype(np.float16)) + logo_token = torch.tensor(special_token_data["logo_token"].astype(np.float16)) + other_tokens = special_token_data["other_tokens"] + + tokens = { + "CAPTION_TOKEN": caption_token, + "LOGO_TOKEN": logo_token, + "TRANS_TOKEN": torch.tensor(other_tokens[:1].astype(np.float16)), + "HQ_TOKEN": torch.tensor(other_tokens[1:2].astype(np.float16)), + "STATIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[2:3].astype(np.float16)), + "DYNAMIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[3:4].astype(np.float16)), + "BORDERNESS_TOKEN": torch.tensor(other_tokens[4:5].astype(np.float16)), + "THREE_D_MODEL_TOKEN": torch.tensor(other_tokens[15:16].astype(np.float16)), + "TWO_D_ANIME_TOKEN": torch.tensor(other_tokens[16:17].astype(np.float16)), + } + + # Duration tokens (8 total, representing 1-8 chunks remaining) + for i in range(8): + tokens[f"DURATION_TOKEN_{i + 1}"] = torch.tensor(other_tokens[i + 7 : i + 8].astype(np.float16)) + + logger.info(f"Loaded {len(tokens)} special tokens from {special_tokens_path}") + return tokens + except Exception as e: + logger.warning(f"Failed to load special tokens: {e}") + return None + + +def prepend_special_tokens( + prompt_embeds: torch.Tensor, + special_tokens: Optional[Dict[str, torch.Tensor]], + use_hq_token: bool = False, + use_3d_style: bool = False, + use_2d_anime_style: bool = False, + use_static_first_frames: bool = False, + use_dynamic_first_frames: bool = False, + max_sequence_length: int = 800, +) -> torch.Tensor: + """ + Prepend special conditioning tokens to text embeddings. 
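+
+    Each selected token is prepended to the front of ``prompt_embeds`` in turn, so the token added last ends up
+    first in the sequence, followed by the other special tokens and then the original text embeddings; the combined
+    sequence is truncated to ``max_sequence_length``.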
+ + Args: + prompt_embeds: Text embeddings [batch, seq_len, hidden_dim] + special_tokens: Dictionary of special token embeddings + use_hq_token: Whether to add high-quality token + use_3d_style: Whether to add 3D model style token + use_2d_anime_style: Whether to add 2D anime style token + use_static_first_frames: Whether to add static motion token + use_dynamic_first_frames: Whether to add dynamic motion token + max_sequence_length: Maximum sequence length after prepending + + Returns: + Text embeddings with special tokens prepended + """ + if special_tokens is None: + return prompt_embeds + + device = prompt_embeds.device + dtype = prompt_embeds.dtype + batch_size, seq_len, hidden_dim = prompt_embeds.shape + + # Collect tokens to prepend (in order: motion, quality, style) + tokens_to_add = [] + if use_static_first_frames and "STATIC_FIRST_FRAMES_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["STATIC_FIRST_FRAMES_TOKEN"]) + if use_dynamic_first_frames and "DYNAMIC_FIRST_FRAMES_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["DYNAMIC_FIRST_FRAMES_TOKEN"]) + if use_hq_token and "HQ_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["HQ_TOKEN"]) + if use_3d_style and "THREE_D_MODEL_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["THREE_D_MODEL_TOKEN"]) + if use_2d_anime_style and "TWO_D_ANIME_TOKEN" in special_tokens: + tokens_to_add.append(special_tokens["TWO_D_ANIME_TOKEN"]) + + # Prepend tokens + for token in tokens_to_add: + token = token.to(device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1) + prompt_embeds = torch.cat([token, prompt_embeds], dim=1) + + # Truncate to max length + if prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, :max_sequence_length, :] + + return prompt_embeds + + +def prepare_v2v_embeddings( + prompt_embeds: torch.Tensor, + negative_prompt_embeds: Optional[torch.Tensor], + num_chunks: int, + clean_chunk_num: int, + max_sequence_length: int = 800, + prompt_mask: Optional[torch.Tensor] = None, + negative_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Prepare per-chunk text embeddings for V2V generation. + + In V2V, clean prefix chunks (from the input video) use null embeddings, while chunks to be denoised use the actual + text embeddings. 
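+
+    For example, with ``num_chunks=4`` and ``clean_chunk_num=1``, the returned ``prompt_embeds_per_chunk`` has shape
+    ``[B, 4, L, D]``: index 0 holds zero (null) embeddings for the clean prefix chunk, while indices 1-3 hold copies
+    of the text embeddings for the chunks that still need denoising.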
+ + Args: + prompt_embeds: Text embeddings [batch_size, seq_len, hidden_dim] + negative_prompt_embeds: Negative text embeddings (optional) + num_chunks: Total number of chunks + clean_chunk_num: Number of clean prefix chunks + max_sequence_length: Maximum sequence length + + Returns: + - prompt_embeds_per_chunk: [B, num_chunks, L, D] + - negative_prompt_embeds_per_chunk: [B, num_chunks, L, D] or None + - prompt_masks_per_chunk: [B, num_chunks, L] or None + - negative_masks_per_chunk: [B, num_chunks, L] or None + """ + batch_size = prompt_embeds.shape[0] + seq_len = prompt_embeds.shape[1] + hidden_dim = prompt_embeds.shape[2] + device = prompt_embeds.device + dtype = prompt_embeds.dtype + + # Number of chunks that need denoising + denoise_chunk_num = num_chunks - clean_chunk_num + + # Create null embeddings (zeros) for clean chunks + null_embeds = torch.zeros(batch_size, 1, seq_len, hidden_dim, device=device, dtype=dtype) + + # Expand prompt embeddings for denoise chunks + # Shape: [batch_size, denoise_chunk_num, seq_len, hidden_dim] + denoise_embeds = prompt_embeds.unsqueeze(1).repeat(1, denoise_chunk_num, 1, 1) + + # Concatenate: [null_embeds for clean chunks] + [text_embeds for denoise chunks] + # Shape: [batch_size, num_chunks, seq_len, hidden_dim] + if clean_chunk_num > 0: + null_embeds_expanded = null_embeds.repeat(1, clean_chunk_num, 1, 1) + prompt_embeds_per_chunk = torch.cat([null_embeds_expanded, denoise_embeds], dim=1) + else: + prompt_embeds_per_chunk = denoise_embeds + + # Build prompt masks per chunk + prompt_masks_per_chunk = None + negative_masks_per_chunk = None + if prompt_mask is not None: + denoise_masks = prompt_mask.unsqueeze(1).repeat(1, denoise_chunk_num, 1) + if clean_chunk_num > 0: + null_masks = torch.zeros(prompt_mask.shape[0], clean_chunk_num, prompt_mask.shape[1], device=device, dtype=prompt_mask.dtype) + prompt_masks_per_chunk = torch.cat([null_masks, denoise_masks], dim=1) + else: + prompt_masks_per_chunk = denoise_masks + + # Same for negative embeddings + if negative_prompt_embeds is not None: + denoise_neg_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, denoise_chunk_num, 1, 1) + if clean_chunk_num > 0: + null_neg_embeds_expanded = null_embeds.repeat(1, clean_chunk_num, 1, 1) + negative_prompt_embeds_per_chunk = torch.cat([null_neg_embeds_expanded, denoise_neg_embeds], dim=1) + else: + negative_prompt_embeds_per_chunk = denoise_neg_embeds + else: + negative_prompt_embeds_per_chunk = None + if negative_mask is not None: + denoise_neg_masks = negative_mask.unsqueeze(1).repeat(1, denoise_chunk_num, 1) + if clean_chunk_num > 0: + null_neg_masks = torch.zeros(negative_mask.shape[0], clean_chunk_num, negative_mask.shape[1], device=device, dtype=negative_mask.dtype) + negative_masks_per_chunk = torch.cat([null_neg_masks, denoise_neg_masks], dim=1) + else: + negative_masks_per_chunk = denoise_neg_masks + + return ( + prompt_embeds_per_chunk, + negative_prompt_embeds_per_chunk, + prompt_masks_per_chunk, + negative_masks_per_chunk, + ) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Magi1VideoToVideoPipeline, AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler + >>> from diffusers.utils import export_to_video, load_video + + >>> model_id = "SandAI/Magi1-V2V-14B-480P-Diffusers" + >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + + >>> # IMPORTANT: MAGI-1 requires shift=3.0 for the scheduler (SD3-style time resolution transform) + >>> scheduler = 
FlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3.0) + + >>> pipe = Magi1VideoToVideoPipeline.from_pretrained(model_id, vae=vae, scheduler=scheduler, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Load prefix video (e.g., 24 frames) + >>> video = load_video("path/to/input_video.mp4", num_frames=24) + >>> prompt = ( + ... "Continue this video with smooth camera motion and consistent style. " + ... "The scene evolves naturally with coherent motion." + ... ) + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality" + + >>> output = pipe( + ... video=video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... num_frames=81, # Total frames including prefix + ... guidance_scale=5.0, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +class Magi1VideoToVideoPipeline(DiffusionPipeline, Magi1LoraLoaderMixin): + r""" + Pipeline for video-to-video generation using Magi1. + + MAGI-1 is a DiT-based video generation model that supports autoregressive chunked generation for long videos. This + V2V pipeline takes an input video and generates a continuation or extension of that video. + + **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the + original MAGI-1 paper. The input video frames are encoded to latent representations and used as clean prefix chunks + to condition the generation. Text prompts provide additional semantic guidance for the video continuation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`Magi1Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A flow matching scheduler with Euler discretization, using SD3-style time resolution transform. + vae ([`AutoencoderKLMagi1`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
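+
+    If special conditioning tokens are loaded via `load_special_tokens_from_file`, quality, style, and motion tokens
+    can be enabled through the corresponding `__call__` flags, and duration/borderness tokens are prepended to each
+    chunk's text embeddings automatically during the denoising loop.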
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: Magi1Transformer3DModel, + vae: AutoencoderKLMagi1, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Special tokens for conditioning (optional) + self.special_tokens = None + + def load_special_tokens_from_file(self, special_tokens_path: str): + """ + Load special conditioning tokens from a numpy file. + + Args: + special_tokens_path: Path to special_tokens.npz file + """ + self.special_tokens = load_special_tokens(special_tokens_path) + if self.special_tokens is not None: + logger.info("Special tokens loaded successfully. You can now use quality, style, and motion control.") + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 800, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + mask = mask.repeat(1, num_videos_per_prompt) + mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device) + return prompt_embeds, mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 800, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
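+
+        Besides the (optionally duplicated) prompt and negative prompt embeddings, this method also returns the
+        corresponding tokenizer attention masks (``None`` when pre-computed embeddings are passed in), which are
+        later used to build the per-chunk encoder attention masks.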
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + prompt_mask = None + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + negative_mask = None + + return prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask + + def check_inputs( + self, + prompt, + negative_prompt, + video, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if video is None: + raise ValueError( + "Provide `video` for video-to-video generation. Cannot leave `video` undefined for V2V pipeline." 
+ ) + if video is not None and not isinstance(video, list): + raise ValueError(f"`video` has to be of type `list` (list of PIL Images or tensors) but is {type(video)}") + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + video: Optional[List[PIL.Image.Image]], + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepare latents for V2V generation, including encoding the input video frames as prefix_video. + + Args: + video: Input video frames for V2V generation (list of PIL Images) + batch_size: Batch size + num_channels_latents: Number of latent channels + height: Video height + width: Video width + num_frames: Total number of frames to generate (including prefix) + dtype: Data type + device: Device + generator: Random generator + latents: Pre-generated latents (optional) + + Returns: + Tuple of (latents, prefix_video) where: + - latents: Random noise tensor for generation [batch, channels, num_latent_frames, H, W] + - prefix_video: Encoded video frames as clean latent [batch, channels, prefix_frames, H, W] + """ + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + # Prepare random latents for generation + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # Encode input video frames to latent as prefix_video + prefix_video = None + if video is not None and len(video) > 0: + # Preprocess video frames to target size + # video_processor.preprocess_video expects list of PIL Images + video_tensor = self.video_processor.preprocess_video(video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + # video_tensor shape: [batch, channels, num_frames, height, width] + # For single batch, expand if needed + if video_tensor.ndim == 4: + # [channels, num_frames, height, width] -> [1, channels, num_frames, height, width] + video_tensor = video_tensor.unsqueeze(0) + + # Encode to latent space using VAE + # VAE expects [batch, channels, frames, height, width] + if isinstance(generator, list): + prefix_video = [ + retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="sample", generator=g) + for g, vid in zip(generator, video_tensor) + ] + prefix_video = torch.cat(prefix_video) + else: + prefix_video = retrieve_latents( + self.vae.encode(video_tensor), sample_mode="sample", generator=generator + ) + if prefix_video.shape[0] < batch_size: + prefix_video = prefix_video.repeat(batch_size, 1, 1, 1, 1) + + # Normalize latent using VAE's latent statistics + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(prefix_video.device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + prefix_video.device, dtype + ) + prefix_video = prefix_video.to(dtype) + prefix_video = (prefix_video - latents_mean) * latents_std + + return latents, prefix_video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: List[PIL.Image.Image], + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 800, + use_hq_token: bool = False, + use_3d_style: bool = False, + use_2d_anime_style: bool = False, + use_static_first_frames: bool = False, + use_dynamic_first_frames: bool = False, + enable_distillation: bool = False, + distill_nearly_clean_chunk_threshold: float = 0.3, + ): + r""" + The 
call function to the pipeline for video-to-video generation. + + **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the + original MAGI-1 paper. The input video frames are encoded to VAE latents and used as clean prefix chunks to + condition the video generation. The implementation currently works without KV caching (attention is recomputed + for previous chunks), which is less efficient than the original but still functional. KV caching optimization + will be added when diffusers implements generic caching support for transformers. + + Args: + video (`List[PIL.Image.Image]`): + The input video frames to condition the video generation on. Must be a list of PIL Images representing + the prefix video (e.g., first 24 frames of a video). + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, negative_prompt_embeds will be generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `"latent"`, `"pt"`, or `"np"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`Magi1PipelineOutput`] instead of a plain tuple. 
+ attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `800`): + The maximum sequence length for the text encoder. Sequences longer than this will be truncated. MAGI-1 + uses a max length of 800 tokens. + use_hq_token (`bool`, *optional*, defaults to `False`): + Whether to prepend the high-quality control token to the text embeddings. This token conditions the + model to generate higher quality outputs. Requires special tokens to be loaded via + `load_special_tokens_from_file`. + use_3d_style (`bool`, *optional*, defaults to `False`): + Whether to prepend the 3D model style token to the text embeddings. This token conditions the model to + generate outputs with 3D modeling aesthetics. Requires special tokens to be loaded. + use_2d_anime_style (`bool`, *optional*, defaults to `False`): + Whether to prepend the 2D anime style token to the text embeddings. This token conditions the model to + generate outputs with 2D anime aesthetics. Requires special tokens to be loaded. + use_static_first_frames (`bool`, *optional*, defaults to `False`): + Whether to prepend the static first frames token to the text embeddings. This token conditions the + model to start the video with minimal motion in the first few frames. Requires special tokens to be + loaded. + use_dynamic_first_frames (`bool`, *optional*, defaults to `False`): + Whether to prepend the dynamic first frames token to the text embeddings. This token conditions the + model to start the video with significant motion in the first few frames. Requires special tokens to be + loaded. + enable_distillation (`bool`, *optional*, defaults to `False`): + Whether to enable distillation mode. In distillation mode, the model uses modified timestep embeddings + to support distilled (faster) inference. This requires a distilled model checkpoint. + distill_nearly_clean_chunk_threshold (`float`, *optional*, defaults to `0.3`): + Threshold for identifying nearly-clean chunks in distillation mode. Chunks with timestep > threshold + are considered nearly clean and processed differently. Only used when `enable_distillation=True`. + + Examples: + + Returns: + [`~Magi1PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated videos. 
+ """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + video, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + dtype=self.text_encoder.dtype, + ) + + # 3.5. Prepend special tokens if requested + if self.special_tokens is not None and any( + [use_hq_token, use_3d_style, use_2d_anime_style, use_static_first_frames, use_dynamic_first_frames] + ): + prompt_embeds = prepend_special_tokens( + prompt_embeds=prompt_embeds, + special_tokens=self.special_tokens, + use_hq_token=use_hq_token, + use_3d_style=use_3d_style, + use_2d_anime_style=use_2d_anime_style, + use_static_first_frames=use_static_first_frames, + use_dynamic_first_frames=use_dynamic_first_frames, + max_sequence_length=max_sequence_length, + ) + if negative_prompt_embeds is not None: + negative_prompt_embeds = prepend_special_tokens( + prompt_embeds=negative_prompt_embeds, + special_tokens=self.special_tokens, + use_hq_token=use_hq_token, + use_3d_style=use_3d_style, + use_2d_anime_style=use_2d_anime_style, + use_static_first_frames=use_static_first_frames, + use_dynamic_first_frames=use_dynamic_first_frames, + max_sequence_length=max_sequence_length, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents, prefix_video = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. 
Denoising loop (autoregressive chunked generation with V2V prefix conditioning) + # MAGI-1 V2V uses autoregressive generation with chunk_width=6 and window_size=4 + # The input video frames are encoded as clean prefix chunks and used to condition the generation + # Note: num_warmup_steps is calculated for compatibility but not used in progress bar logic + # because autoregressive generation has a different iteration structure (stages × steps) + # For FlowMatchEulerDiscreteScheduler (order=1), this doesn't affect the results + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + # Autoregressive chunked generation parameters + chunk_width = 6 # Original MAGI-1 default + window_size = 4 # Original MAGI-1 default + + num_latent_frames = latents.shape[2] + num_chunks = (num_latent_frames + chunk_width - 1) // chunk_width + + # Calculate chunk_offset from prefix_video (for V2V, these are the clean video frame chunks) + chunk_offset = 0 + if prefix_video is not None: + # prefix_video has shape [batch, channels, num_prefix_frames, height, width] for V2V + # Calculate how many chunks are covered by the prefix + prefix_latent_frames = prefix_video.shape[2] + chunk_offset = prefix_latent_frames // chunk_width + + # Pad prefix_video into latents at the beginning + # The prefix frames are already clean and don't need denoising + if prefix_latent_frames > 0: + prefix_video = prefix_video.to(latents.dtype) + latents[:, :, :prefix_latent_frames] = prefix_video + + # Calculate chunk scheduling: which chunks to process at each stage + # chunk_offset skips the clean prefix chunks + clip_start, clip_end, t_start, t_end = generate_chunk_sequences(num_chunks, window_size, chunk_offset) + num_stages = len(clip_start) + + # Prepare per-chunk text embeddings for V2V + # Clean chunks (from input video) use null embeddings, denoise chunks use text embeddings + ( + prompt_embeds_per_chunk, + negative_prompt_embeds_per_chunk, + prompt_masks_per_chunk, + negative_masks_per_chunk, + ) = prepare_v2v_embeddings( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + num_chunks=num_chunks, + clean_chunk_num=chunk_offset, + max_sequence_length=max_sequence_length, + prompt_mask=prompt_mask, + negative_mask=negative_mask, + ) + + # Number of denoising steps per stage + denoise_step_per_stage = len(timesteps) // window_size + + # Track how many times each chunk has been denoised + chunk_denoise_count = {i: 0 for i in range(num_chunks)} + + with self.progress_bar(total=num_stages * denoise_step_per_stage) as progress_bar: + for stage_idx in range(num_stages): + # Determine which chunks to process in this stage + chunk_start_idx = clip_start[stage_idx] + chunk_end_idx = clip_end[stage_idx] + t_start_idx = t_start[stage_idx] + t_end_idx = t_end[stage_idx] + + # Extract chunk range in latent space + latent_start = chunk_start_idx * chunk_width + latent_end = min(chunk_end_idx * chunk_width, num_latent_frames) + + # Number of chunks in current window + num_chunks_in_window = chunk_end_idx - chunk_start_idx + + # Prepare per-chunk conditioning with duration/borderness tokens + # Duration tokens indicate how many chunks remain in the video + # Borderness tokens condition on chunk boundaries + chunk_prompt_embeds_list = [] + chunk_negative_prompt_embeds_list = [] + chunk_prompt_masks_list = [] + chunk_negative_masks_list = [] + + if self.special_tokens is not None: + # Prepare per-chunk embeddings with duration tokens + # Each chunk gets 
a different duration token based on chunks remaining + for i, chunk_idx in enumerate(range(chunk_start_idx, chunk_end_idx)): + chunks_remaining = num_chunks - chunk_idx - 1 + # Duration token ranges from 1-8 chunks + duration_idx = min(chunks_remaining, 7) + 1 + + # Get base embeddings for this chunk (clean chunks have null embeds, denoise chunks have text embeds) + token_embeds = prompt_embeds_per_chunk[:, chunk_idx].clone() + + # Add duration and borderness tokens for this chunk + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"] + duration_token = duration_token.to(device=token_embeds.device, dtype=token_embeds.dtype) + duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1) + token_embeds = torch.cat([duration_token, token_embeds], dim=1) + + if "BORDERNESS_TOKEN" in self.special_tokens: + borderness_token = self.special_tokens["BORDERNESS_TOKEN"] + borderness_token = borderness_token.to( + device=token_embeds.device, dtype=token_embeds.dtype + ) + borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1) + token_embeds = torch.cat([borderness_token, token_embeds], dim=1) + + # Build mask for this chunk from base per-chunk mask + token_mask = ( + prompt_masks_per_chunk[:, chunk_idx] + if prompt_masks_per_chunk is not None + else torch.ones(batch_size, token_embeds.shape[1], device=token_embeds.device, dtype=torch.int64) + ) + add_count = 0 + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + add_count += 1 + if "BORDERNESS_TOKEN" in self.special_tokens: + add_count += 1 + if add_count > 0: + prepend = torch.ones(batch_size, add_count, dtype=token_mask.dtype, device=token_mask.device) + token_mask = torch.cat([prepend, token_mask], dim=1) + # Truncate to max length + if token_embeds.shape[1] > max_sequence_length: + token_embeds = token_embeds[:, :max_sequence_length, :] + token_mask = token_mask[:, :max_sequence_length] + + chunk_prompt_embeds_list.append(token_embeds) + chunk_prompt_masks_list.append(token_mask) + + # Same for negative prompts + if self.do_classifier_free_guidance and negative_prompt_embeds_per_chunk is not None: + neg_token_embeds = negative_prompt_embeds_per_chunk[:, chunk_idx].clone() + + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"] + duration_token = duration_token.to( + device=neg_token_embeds.device, dtype=neg_token_embeds.dtype + ) + duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1) + neg_token_embeds = torch.cat([duration_token, neg_token_embeds], dim=1) + + if "BORDERNESS_TOKEN" in self.special_tokens: + borderness_token = self.special_tokens["BORDERNESS_TOKEN"] + borderness_token = borderness_token.to( + device=neg_token_embeds.device, dtype=neg_token_embeds.dtype + ) + borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1) + neg_token_embeds = torch.cat([borderness_token, neg_token_embeds], dim=1) + + # Build negative mask for this chunk + neg_mask = ( + negative_masks_per_chunk[:, chunk_idx] + if negative_masks_per_chunk is not None + else torch.ones(batch_size, neg_token_embeds.shape[1], device=neg_token_embeds.device, dtype=torch.int64) + ) + add_count = 0 + if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens: + add_count += 1 + if "BORDERNESS_TOKEN" in self.special_tokens: + add_count += 1 + if add_count > 0: + prepend = torch.ones(batch_size, add_count, dtype=neg_mask.dtype, 
device=neg_mask.device) + neg_mask = torch.cat([prepend, neg_mask], dim=1) + if neg_token_embeds.shape[1] > max_sequence_length: + neg_token_embeds = neg_token_embeds[:, :max_sequence_length, :] + neg_mask = neg_mask[:, :max_sequence_length] + + chunk_negative_prompt_embeds_list.append(neg_token_embeds) + chunk_negative_masks_list.append(neg_mask) + + # Denoise this chunk range for denoise_step_per_stage steps + for denoise_idx in range(denoise_step_per_stage): + if self.interrupt: + break + + # Calculate timestep index for each chunk in the current window + # Chunks at different stages get different timesteps based on their denoise progress + timestep_indices = [] + for offset in range(num_chunks_in_window): + # Map offset within window to time index + t_idx_within_window = t_start_idx + offset + if t_idx_within_window < t_end_idx: + # This chunk is actively being denoised in this window + t_idx = t_idx_within_window * denoise_step_per_stage + denoise_idx + else: + # This chunk is beyond the active window, use max timestep (it's already cleaner) + t_idx = min((window_size - 1) * denoise_step_per_stage + denoise_idx, len(timesteps) - 1) + timestep_indices.append(t_idx) + + # Reverse order: chunks further from start are noisier + timestep_indices.reverse() + + # Get actual timesteps (reversed order: high noise to low noise) + current_timesteps = timesteps[timestep_indices] + + # Create per-chunk timestep tensor: [batch_size, num_chunks_in_window] + # Each chunk gets its own timestep based on how many times it's been denoised + timestep_per_chunk = current_timesteps.unsqueeze(0).expand(batch_size, -1) + + # Store first timestep for progress tracking + self._current_timestep = current_timesteps[0] + + # Extract chunk + latent_chunk = latents[:, :, latent_start:latent_end].to(transformer_dtype) + + # Prepare distillation parameters if enabled + num_steps = None + distill_interval = None + distill_nearly_clean_chunk = None + + if enable_distillation: + # distill_interval represents the time interval between denoising steps + distill_interval = len(timesteps) / num_inference_steps + + # Determine if chunks are nearly clean (low noise) based on their timesteps + # Check the first active chunk's timestep (after reversing, this is the noisiest chunk being actively denoised) + # Normalize timestep to [0, 1] range where 0=clean, 1=noise + nearly_clean_chunk_t = current_timesteps[0].item() / self.scheduler.config.num_train_timesteps + distill_nearly_clean_chunk = nearly_clean_chunk_t < distill_nearly_clean_chunk_threshold + + num_steps = num_inference_steps + + # Prepare per-chunk embeddings + # The transformer expects embeddings in shape [batch_size * num_chunks_in_window, seq_len, hidden_dim] + # Each chunk gets its own embedding with appropriate duration/borderness tokens + if chunk_prompt_embeds_list: + # Stack per-chunk embeddings: [num_chunks_in_window, batch_size, seq_len, hidden_dim] + chunk_prompt_embeds = torch.stack(chunk_prompt_embeds_list, dim=0) + # Reshape to [batch_size * num_chunks_in_window, seq_len, hidden_dim] + chunk_prompt_embeds = chunk_prompt_embeds.transpose(0, 1).flatten(0, 1) + + if chunk_negative_prompt_embeds_list: + chunk_negative_prompt_embeds = torch.stack(chunk_negative_prompt_embeds_list, dim=0) + chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.transpose(0, 1).flatten(0, 1) + else: + chunk_negative_prompt_embeds = None + else: + # Fallback: use per-chunk embeddings without special tokens + # Extract embeddings for the current chunk range + 
chunk_prompt_embeds = prompt_embeds_per_chunk[:, chunk_start_idx:chunk_end_idx] + chunk_prompt_embeds = chunk_prompt_embeds.flatten(0, 1) + chunk_prompt_masks = ( + prompt_masks_per_chunk[:, chunk_start_idx:chunk_end_idx].flatten(0, 1) + if prompt_masks_per_chunk is not None + else None + ) + + if negative_prompt_embeds_per_chunk is not None: + chunk_negative_prompt_embeds = negative_prompt_embeds_per_chunk[ + :, chunk_start_idx:chunk_end_idx + ] + chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.flatten(0, 1) + chunk_negative_masks = ( + negative_masks_per_chunk[:, chunk_start_idx:chunk_end_idx].flatten(0, 1) + if negative_masks_per_chunk is not None + else None + ) + else: + chunk_negative_prompt_embeds = None + chunk_negative_masks = None + + # Create encoder attention mask(s) from tokenizer masks + if chunk_prompt_embeds_list: + prompt_masks_stacked = torch.stack(chunk_prompt_masks_list, dim=0).transpose(0, 1).flatten(0, 1) + else: + prompt_masks_stacked = chunk_prompt_masks + if prompt_masks_stacked is None: + prompt_masks_stacked = torch.ones( + batch_size * num_chunks_in_window, + chunk_prompt_embeds.shape[1], + dtype=torch.int64, + device=chunk_prompt_embeds.device, + ) + encoder_attention_mask = prompt_masks_stacked.to(dtype=chunk_prompt_embeds.dtype).view( + batch_size * num_chunks_in_window, 1, 1, -1 + ) + + if self.do_classifier_free_guidance: + if chunk_negative_prompt_embeds_list: + negative_masks_stacked = ( + torch.stack(chunk_negative_masks_list, dim=0).transpose(0, 1).flatten(0, 1) + ) + else: + negative_masks_stacked = chunk_negative_masks + if negative_masks_stacked is None and chunk_negative_prompt_embeds is not None: + negative_masks_stacked = torch.ones( + batch_size * num_chunks_in_window, + chunk_negative_prompt_embeds.shape[1], + dtype=torch.int64, + device=chunk_negative_prompt_embeds.device, + ) + encoder_attention_mask_neg = ( + negative_masks_stacked.to(dtype=chunk_prompt_embeds.dtype).view( + batch_size * num_chunks_in_window, 1, 1, -1 + ) + if negative_masks_stacked is not None + else encoder_attention_mask + ) + + # Pad prefix video into latent_chunk if applicable (I2V) + # This ensures clean prefix frames are maintained during denoising + if prefix_video is not None: + prefix_length = prefix_video.shape[2] + prefix_video_start = chunk_start_idx * chunk_width + + if prefix_length > prefix_video_start: + # Calculate how many frames to pad + padding_length = min(prefix_length - prefix_video_start, latent_chunk.shape[2]) + prefix_video_end = prefix_video_start + padding_length + + # Pad clean prefix frames into latent_chunk + latent_chunk = latent_chunk.clone() + latent_chunk[:, :, :padding_length] = prefix_video[ + :, :, prefix_video_start:prefix_video_end + ] + + # Set timesteps for clean prefix chunks to maximum (indicates "already clean") + # This matches original MAGI-1's try_pad_prefix_video logic + num_clean_chunks_in_window = padding_length // chunk_width + if num_clean_chunks_in_window > 0: + # Get max timestep from scheduler + max_timestep = timesteps[0] + timestep_per_chunk[:, :num_clean_chunks_in_window] = max_timestep + + # Generate KV range for autoregressive attention + # Each chunk can attend to itself and all previous chunks in the sequence + # Shape: [batch_size * num_chunks_in_window, 2] where each row is [start_token_idx, end_token_idx] + chunk_token_nums = ( + (latent_chunk.shape[2] // num_chunks_in_window) # frames per chunk + * (latent_chunk.shape[3] // self.transformer.config.patch_size[1]) # height tokens + * 
(latent_chunk.shape[4] // self.transformer.config.patch_size[2]) # width tokens + ) + kv_range = [] + for b in range(batch_size): + # batch_offset should be based on total chunks in the video, not chunk_end_idx + batch_offset = b * num_chunks + for c in range(num_chunks_in_window): + # This chunk can attend from the start of the video up to its own end + chunk_global_idx = chunk_start_idx + c + k_start = batch_offset * chunk_token_nums + k_end = (batch_offset + chunk_global_idx + 1) * chunk_token_nums + kv_range.append([k_start, k_end]) + kv_range = torch.tensor(kv_range, dtype=torch.int32, device=device) + + # Predict noise (conditional) + # Note: MAGI-1 uses velocity field (flow matching), but following diffusers convention + # we use noise_pred naming for consistency across all pipelines + noise_pred = self.transformer( + hidden_states=latent_chunk, + timestep=timestep_per_chunk, + encoder_hidden_states=chunk_prompt_embeds, + encoder_attention_mask=encoder_attention_mask, + attention_kwargs=attention_kwargs, + denoising_range_num=num_chunks_in_window, + range_num=chunk_end_idx, + slice_point=chunk_start_idx, + kv_range=kv_range, + num_steps=num_steps, + distill_interval=distill_interval, + distill_nearly_clean_chunk=distill_nearly_clean_chunk, + return_dict=False, + )[0] + + # Classifier-free guidance: separate forward pass for unconditional + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_chunk, + timestep=timestep_per_chunk, + encoder_hidden_states=chunk_negative_prompt_embeds, + encoder_attention_mask=encoder_attention_mask_neg, + attention_kwargs=attention_kwargs, + denoising_range_num=num_chunks_in_window, + range_num=chunk_end_idx, + slice_point=chunk_start_idx, + kv_range=kv_range, + num_steps=num_steps, + distill_interval=distill_interval, + distill_nearly_clean_chunk=distill_nearly_clean_chunk, + return_dict=False, + )[0] + # Apply classifier-free guidance + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + # CRITICAL: Apply per-chunk Euler integration with different delta_t for each chunk + # This matches the original MAGI-1's integrate() function + # Each chunk in the window is at a different noise level and needs its own time step + + # Calculate per-chunk timesteps for integration + # Get current timesteps (t_before) + current_timesteps = timesteps[timestep_indices] + + # Get next timesteps (t_after) - one step forward for each chunk + next_timestep_indices = [min(idx + 1, len(timesteps) - 1) for idx in timestep_indices] + next_timesteps = timesteps[next_timestep_indices] + + # Convert timesteps to sigmas (matching FlowMatchEulerDiscreteScheduler) + current_sigmas = current_timesteps / self.scheduler.config.num_train_timesteps + next_sigmas = next_timesteps / self.scheduler.config.num_train_timesteps + + # Calculate delta_t for each chunk: [num_chunks_in_window] + delta_t = next_sigmas - current_sigmas + + # Reshape latent_chunk and noise_pred to separate chunks + # From: [batch, channels, frames, height, width] + # To: [batch, channels, num_chunks, chunk_width, height, width] + batch_size_actual, num_channels, total_frames, height_latent, width_latent = latent_chunk.shape + + # Ensure total_frames is divisible by chunk_width for reshaping + # (it should be, but let's handle edge cases) + num_complete_chunks = total_frames // chunk_width + remainder_frames = total_frames % chunk_width + + if remainder_frames == 0: + # Perfect division: reshape and apply per-chunk delta_t + latent_chunk 
= latent_chunk.reshape( + batch_size_actual, + num_channels, + num_complete_chunks, + chunk_width, + height_latent, + width_latent, + ) + noise_pred = noise_pred.reshape( + batch_size_actual, + num_channels, + num_complete_chunks, + chunk_width, + height_latent, + width_latent, + ) + + # Apply Euler integration: x_chunk = x_chunk + velocity * delta_t + # delta_t shape: [num_chunks] -> broadcast to [1, 1, num_chunks, 1, 1, 1] + delta_t_broadcast = delta_t.reshape(1, 1, -1, 1, 1, 1).to( + latent_chunk.device, latent_chunk.dtype + ) + latent_chunk = latent_chunk + noise_pred * delta_t_broadcast + + # Reshape back to original dimensions + latent_chunk = latent_chunk.reshape( + batch_size_actual, num_channels, total_frames, height_latent, width_latent + ) + else: + # Handle remainder frames separately (edge case for last incomplete chunk) + complete_frames = num_complete_chunks * chunk_width + + # Process complete chunks + latent_chunk_complete = latent_chunk[:, :, :complete_frames] + noise_pred_complete = noise_pred[:, :, :complete_frames] + + latent_chunk_complete = latent_chunk_complete.reshape( + batch_size_actual, + num_channels, + num_complete_chunks, + chunk_width, + height_latent, + width_latent, + ) + noise_pred_complete = noise_pred_complete.reshape( + batch_size_actual, + num_channels, + num_complete_chunks, + chunk_width, + height_latent, + width_latent, + ) + + # Apply per-chunk delta_t to complete chunks + delta_t_broadcast = ( + delta_t[:num_complete_chunks] + .reshape(1, 1, -1, 1, 1, 1) + .to(latent_chunk.device, latent_chunk.dtype) + ) + latent_chunk_complete = latent_chunk_complete + noise_pred_complete * delta_t_broadcast + latent_chunk_complete = latent_chunk_complete.reshape( + batch_size_actual, num_channels, complete_frames, height_latent, width_latent + ) + + # Process remainder frames with last delta_t + if remainder_frames > 0: + latent_chunk_remainder = latent_chunk[:, :, complete_frames:] + noise_pred_remainder = noise_pred[:, :, complete_frames:] + delta_t_remainder = delta_t[-1].to(latent_chunk.device, latent_chunk.dtype) + latent_chunk_remainder = latent_chunk_remainder + noise_pred_remainder * delta_t_remainder + + # Concatenate complete and remainder + latent_chunk = torch.cat([latent_chunk_complete, latent_chunk_remainder], dim=2) + else: + latent_chunk = latent_chunk_complete + + # Write back to full latents + latents[:, :, latent_start:latent_end] = latent_chunk + + # Update chunk denoise counts + for chunk_idx in range(chunk_start_idx, chunk_end_idx): + chunk_denoise_count[chunk_idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + # Use first timestep for callback (most representative) + callback_timestep = current_timesteps[0] + callback_outputs = callback_on_step_end( + self, stage_idx * denoise_step_per_stage + denoise_idx, callback_timestep, callback_kwargs + ) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / 
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return Magi1PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/magi1/pipeline_output.py b/src/diffusers/pipelines/magi1/pipeline_output.py new file mode 100644 index 000000000000..200156cffac9 --- /dev/null +++ b/src/diffusers/pipelines/magi1/pipeline_output.py @@ -0,0 +1,36 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import torch + +from ...utils import BaseOutput + + +@dataclass +class Magi1PipelineOutput(BaseOutput): + """ + Output class for MAGI-1 pipeline. + + Args: + frames (`torch.Tensor` or `np.ndarray`): + List of denoised frames from the diffusion process, as a NumPy array of shape `(batch_size, num_frames, + height, width, num_channels)` or a PyTorch tensor of shape `(batch_size, num_channels, num_frames, height, + width)`. 
+ """ + + frames: Union[torch.Tensor, np.ndarray, List[List[np.ndarray]]] diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 63932221b207..e5406f040d99 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -86,6 +86,8 @@ is_kernels_available, is_kornia_available, is_librosa_available, + is_magi_attn_available, + is_magi_attn_version, is_matplotlib_available, is_nltk_available, is_note_seq_available, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5d62709c28fd..07dddb0c1b2d 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -408,6 +408,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLMagi1(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLMagvit(metaclass=DummyObject): _backends = ["torch"] @@ -993,6 +1008,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Magi1Transformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class MochiTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3244ef12ef87..fd7371f1c9f8 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1742,6 +1742,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Magi1ImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Magi1Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Magi1VideoToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class MarigoldDepthPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/import_utils.py 
b/src/diffusers/utils/import_utils.py index 97065267b004..7763380af8ab 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -226,6 +226,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _sageattention_available, _sageattention_version = _is_package_available("sageattention") _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") +_magi_attn_available, _magi_attn_version = _is_package_available("magi_attention") _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) @@ -406,6 +407,10 @@ def is_flash_attn_3_available(): return _flash_attn_3_available +def is_magi_attn_available(): + return _magi_attn_available + + def is_kornia_available(): return _kornia_available @@ -911,6 +916,21 @@ def is_flash_attn_version(operation: str, version: str): return compare_versions(parse(_flash_attn_version), operation, version) +def is_magi_attn_version(operation: str, version: str): + """ + Compares the current magi-attention version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _magi_attn_available: + return False + return compare_versions(parse(_magi_attn_version), operation, version) + + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_magi1.py b/tests/models/autoencoders/test_models_autoencoder_kl_magi1.py new file mode 100644 index 000000000000..ab8b45f31837 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_kl_magi1.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import AutoencoderKLMagi1 +from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLMagi1Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLMagi1 + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_magi1_config(self): + return { + "base_dim": 3, + "z_dim": 16, + "dim_mult": [1, 1, 1, 1], + "num_res_blocks": 1, + "temperal_downsample": [False, True, True], + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + return {"sample": image} + + @property + def dummy_input_tiling(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (128, 128) + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + return {"sample": image} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_magi1_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def prepare_init_args_and_inputs_for_tiling(self): + init_dict = self.get_autoencoder_kl_magi1_config() + inputs_dict = self.dummy_input_tiling + return init_dict, inputs_dict + + def test_enable_disable_tiling(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_tiling(96, 96, 64, 64) + output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), + 0.5, + "VAE tiling should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_tiling() + output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_tiling.detach().cpu().numpy().all(), + output_without_tiling_2.detach().cpu().numpy().all(), + "Without tiling outputs should match with the outputs when tiling is manually disabled.", + ) + + def test_enable_disable_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_slicing() + output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), + 0.05, + "VAE slicing should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_slicing() + output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_slicing.detach().cpu().numpy().all(), + 
output_without_slicing_2.detach().cpu().numpy().all(), + "Without slicing outputs should match with the outputs when slicing is manually disabled.", + ) + + @unittest.skip("Gradient checkpointing has not been implemented yet") + def test_gradient_checkpointing_is_applied(self): + pass + + @unittest.skip("Test not supported") + def test_forward_with_norm_groups(self): + pass + + @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'") + def test_layerwise_casting_inference(self): + pass + + @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'") + def test_layerwise_casting_training(self): + pass diff --git a/tests/models/transformers/test_models_transformer_magi1.py b/tests/models/transformers/test_models_transformer_magi1.py new file mode 100644 index 000000000000..ed8d775e6058 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_magi1.py @@ -0,0 +1,91 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import Magi1Transformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +class Magi1Transformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = Magi1Transformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 4, + "out_channels": 4, + "cross_attention_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"Magi1Transformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class Magi1TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = Magi1Transformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return 
Magi1Transformer3DTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/magi1/test_magi1.py b/tests/pipelines/magi1/test_magi1.py new file mode 100644 index 000000000000..3695bcbe5be7 --- /dev/null +++ b/tests/pipelines/magi1/test_magi1.py @@ -0,0 +1,158 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler, Magi1Pipeline, Magi1Transformer3DModel +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class Magi1PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Magi1Pipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLMagi1( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + # TODO: impl FlowDPMSolverMultistepScheduler + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = Magi1Transformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + 
components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + +@slow +@require_torch_accelerator +class Magi1PipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + @unittest.skip("TODO: test needs to be implemented") + def test_Magi1(self): + pass diff --git a/tests/pipelines/magi1/test_magi1_image_to_video.py b/tests/pipelines/magi1/test_magi1_image_to_video.py new file mode 100644 index 000000000000..50ca59f5bbf5 --- /dev/null +++ b/tests/pipelines/magi1/test_magi1_image_to_video.py @@ -0,0 +1,146 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import PIL +import torch +from transformers import AutoTokenizer, CLIPVisionModel, T5EncoderModel + +from diffusers import ( + AutoencoderKLMagi1, + FlowMatchEulerDiscreteScheduler, + Magi1ImageToVideoPipeline, + Magi1Transformer3DModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, +) + +from ..pipeline_params import ( + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class Magi1ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Magi1ImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLMagi1( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + # TODO: impl FlowDPMSolverMultistepScheduler + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + image_encoder = CLIPVisionModel.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + transformer = Magi1Transformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "image_encoder": image_encoder, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + "image": PIL.Image.new("RGB", (16, 16)), + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass diff --git a/tests/pipelines/magi1/test_magi1_video_to_video.py b/tests/pipelines/magi1/test_magi1_video_to_video.py new file mode 100644 index 
000000000000..97efc4da0539 --- /dev/null +++ b/tests/pipelines/magi1/test_magi1_video_to_video.py @@ -0,0 +1,153 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import PIL +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLMagi1, + FlowMatchEulerDiscreteScheduler, + Magi1Transformer3DModel, + Magi1VideoToVideoPipeline, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class Magi1VideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Magi1VideoToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLMagi1( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = Magi1Transformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + # Create a list of PIL images to simulate video input + video_frames = [PIL.Image.new("RGB", (16, 16)) for _ in range(9)] + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + "video": video_frames, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + 
pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip( + "Magi1VideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors" + ) + def test_model_cpu_offload_forward_pass(self): + pass + + @unittest.skip( + "Magi1VideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors" + ) + def test_save_load_float16(self): + pass diff --git a/tests/single_file/test_model_magi_autoencoder_single_file.py b/tests/single_file/test_model_magi_autoencoder_single_file.py new file mode 100644 index 000000000000..8721d884fa25 --- /dev/null +++ b/tests/single_file/test_model_magi_autoencoder_single_file.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AutoencoderKLMagi1 +from diffusers.utils.testing_utils import ( + require_torch_gpu, + slow, + torch_device, +) + + +class AutoencoderKLMagiSingleFileTests(unittest.TestCase): + model_class = AutoencoderKLMagi1 + ckpt_path = "https://huggingface.co/sand-ai/MAGI-1/blob/main/vae/diffusion_pytorch_model.safetensors" + repo_id = "sand-ai/MAGI-1" + + @slow + @require_torch_gpu + def test_single_file_components(self): + model = self.model_class.from_single_file(self.ckpt_path) + model.to(torch_device) + + batch_size = 1 + num_frames = 2 + num_channels = 3 + sizes = (16, 16) + image = torch.randn((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + with torch.no_grad(): + model(image, return_dict=False) + + @slow + @require_torch_gpu + def test_single_file_components_from_hub(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="vae") + model.to(torch_device) + + batch_size = 1 + num_frames = 2 + num_channels = 3 + sizes = (16, 16) + image = torch.randn((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + with torch.no_grad(): + model(image, return_dict=False) diff --git a/tests/single_file/test_model_magi_transformer3d_single_file.py b/tests/single_file/test_model_magi_transformer3d_single_file.py new file mode 100644 index 000000000000..fb6b0ae04622 --- /dev/null +++ b/tests/single_file/test_model_magi_transformer3d_single_file.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import Magi1Transformer3DModel +from diffusers.utils.testing_utils import ( + require_torch_gpu, + slow, + torch_device, +) + + +class Magi1Transformer3DModelText2VideoSingleFileTest(unittest.TestCase): + model_class = Magi1Transformer3DModel + ckpt_path = "https://huggingface.co/sand-ai/MAGI-1/blob/main/transformer/diffusion_pytorch_model.safetensors" + repo_id = "sand-ai/MAGI-1" + + @slow + @require_torch_gpu + def test_single_file_components(self): + model = self.model_class.from_single_file(self.ckpt_path) + model.to(torch_device) + + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + + with torch.no_grad(): + model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + ) + + @slow + @require_torch_gpu + def test_single_file_components_from_hub(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") + model.to(torch_device) + + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + + with torch.no_grad(): + model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + )
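For reference, the chunk-wise denoising loop in `pipeline_magi1.py` above reduces to a per-chunk Euler step of the flow-matching ODE, `x <- x + v * (sigma_next - sigma)`, applied with a separate step size for each chunk in the active window. Below is a minimal standalone sketch of that update; the helper name, signature, and tensor layout are illustrative only and are not part of this patch.

```python
import torch


def per_chunk_euler_step(latents, velocity, sigma_now, sigma_next, num_chunks, chunk_width):
    """Apply x <- x + v * (sigma_next - sigma_now) with a separate delta per chunk.

    latents, velocity: [batch, channels, num_chunks * chunk_width, height, width]
    sigma_now, sigma_next: [num_chunks] current/next noise levels, one entry per chunk
    """
    b, c, f, h, w = latents.shape
    assert f == num_chunks * chunk_width, "frames must split evenly into chunks"

    # Per-chunk step size, broadcast over [batch, channels, chunk_width, height, width]
    delta = (sigma_next - sigma_now).reshape(1, 1, num_chunks, 1, 1, 1).to(latents.device, latents.dtype)

    # Separate the chunk dimension, integrate, then restore the original frame layout
    latents = latents.reshape(b, c, num_chunks, chunk_width, h, w)
    velocity = velocity.reshape(b, c, num_chunks, chunk_width, h, w)
    latents = latents + velocity * delta
    return latents.reshape(b, c, f, h, w)
```

Because the chunks in a window sit at different noise levels, each needs its own step size; this is why the pipeline reshapes `[B, C, F, H, W]` into `[B, C, num_chunks, chunk_width, H, W]` before applying the update and handles a trailing incomplete chunk separately.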