diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index ffd7385d81a5..a8eb2d28ea1c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -372,6 +372,8 @@
title: Lumina2Transformer2DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
+  - local: api/models/magi1_transformer_3d
+ title: Magi1Transformer3DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/omnigen_transformer
@@ -430,6 +432,8 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
+ - local: api/models/autoencoder_kl_magi1
+ title: AutoencoderKLMagi1
- local: api/models/autoencoderkl_magvit
title: AutoencoderKLMagvit
- local: api/models/autoencoderkl_mochi
diff --git a/docs/source/en/api/models/autoencoder_kl_magi1.md b/docs/source/en/api/models/autoencoder_kl_magi1.md
new file mode 100644
index 000000000000..2301d08b72da
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_kl_magi1.md
@@ -0,0 +1,34 @@
+
+
+# AutoencoderKLMagi1
+
+The 3D variational autoencoder (VAE) model with KL loss used in [MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai.
+
+MAGI-1 uses a transformer-based VAE with 8x spatial and 4x temporal compression, providing fast average decoding time and highly competitive reconstruction quality.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import AutoencoderKLMagi1
+
+vae = AutoencoderKLMagi1.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32)
+```
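+
+The snippet below is a minimal sketch of a round trip through the VAE. It assumes the standard diffusers `encode`/`decode` API and a `[batch, channels, frames, height, width]` input layout; the dummy tensor sizes are illustrative only.
+
+```python
+import torch
+from diffusers import AutoencoderKLMagi1
+
+vae = AutoencoderKLMagi1.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+
+# Dummy clip: 1 sample, 3 channels, 16 frames, 256x256 pixels (illustrative sizes)
+video = torch.randn(1, 3, 16, 256, 256, device="cuda")
+
+with torch.no_grad():
+    latents = vae.encode(video).latent_dist.sample()  # 8x spatial, 4x temporal compression
+    reconstruction = vae.decode(latents).sample
+
+print(latents.shape, reconstruction.shape)
+```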
+
+## AutoencoderKLMagi1
+
+[[autodoc]] AutoencoderKLMagi1
+ - decode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/magi1_transformer_3d.md b/docs/source/en/api/models/magi1_transformer_3d.md
new file mode 100644
index 000000000000..8fb369f16253
--- /dev/null
+++ b/docs/source/en/api/models/magi1_transformer_3d.md
@@ -0,0 +1,32 @@
+
+
+# Magi1Transformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai.
+
+MAGI-1 is an autoregressive denoising video generation model that generates videos chunk-by-chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and the generation of the next chunk begins as soon as the current one reaches a certain level of denoising.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import Magi1Transformer3DModel
+
+transformer = Magi1Transformer3DModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## Magi1Transformer3DModel
+
+[[autodoc]] Magi1Transformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/magi1.md b/docs/source/en/api/pipelines/magi1.md
new file mode 100644
index 000000000000..b64c6f57c9b1
--- /dev/null
+++ b/docs/source/en/api/pipelines/magi1.md
@@ -0,0 +1,261 @@
+
+
+
+
+# MAGI-1
+
+[MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai.
+
+*MAGI-1 is an autoregressive video generation model that generates videos chunk-by-chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and the generation of the next chunk begins as soon as the current one reaches a certain level of denoising. This pipeline design enables concurrent processing of up to four chunks for efficient video generation. The model leverages a specialized architecture with a transformer-based VAE with 8x spatial and 4x temporal compression, and a diffusion transformer with several key innovations including Block-Causal Attention, Parallel Attention Block, QK-Norm and GQA, Sandwich Normalization in FFN, SwiGLU, and Softcap Modulation.*
+
+The original codebase can be found at [SandAI-org/MAGI-1](https://github.com/SandAI-org/MAGI-1).
+
+This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
+
+You can find the MAGI-1 checkpoints under the [sand-ai](https://huggingface.co/sand-ai) organization.
+
+The following MAGI-1 models are supported in Diffusers:
+
+**Base Models:**
+- [MAGI-1 24B](https://huggingface.co/sand-ai/MAGI-1)
+- [MAGI-1 4.5B](https://huggingface.co/sand-ai/MAGI-1-4.5B)
+
+**Distilled Models (faster inference):**
+- [MAGI-1 24B Distill](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_distill)
+- [MAGI-1 24B Distill+Quant (FP8)](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_distill_quant)
+- [MAGI-1 4.5B Distill](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/4.5B_distill)
+- [MAGI-1 4.5B Distill+Quant (FP8)](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/4.5B_distill_quant)
+
+> [!TIP]
+> Click on the MAGI-1 models in the right sidebar for more examples of video generation.
+
+### Text-to-Video Generation
+
+The example below demonstrates how to generate a video from text optimized for memory or inference speed.
+
+
+
+
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
+
+The MAGI-1 text-to-video model below requires ~13GB of VRAM.
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoModel, Magi1Pipeline
+from diffusers.hooks.group_offloading import apply_group_offloading
+from diffusers.utils import export_to_video
+from transformers import T5EncoderModel
+
+text_encoder = T5EncoderModel.from_pretrained("sand-ai/MAGI-1", subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32)
+transformer = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16)
+
+# group-offloading
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+apply_group_offloading(text_encoder,
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="block_level",
+ num_blocks_per_group=4
+)
+transformer.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ use_stream=True
+)
+
+pipeline = Magi1Pipeline.from_pretrained(
+ "sand-ai/MAGI-1",
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+prompt = """
+A majestic eagle soaring over a mountain landscape. The eagle's wings are spread wide,
+catching the golden sunlight as it glides through the clear blue sky. Below, snow-capped
+mountains stretch to the horizon, with pine forests and a winding river visible in the valley.
+"""
+negative_prompt = """
+Poor quality, blurry, pixelated, low resolution, distorted proportions, unnatural colors,
+watermark, text overlay, incomplete rendering, glitches, artifacts, unrealistic lighting
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=24,
+ guidance_scale=7.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=8)
+```
+
+
+
+
+[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoModel, Magi1Pipeline
+from diffusers.utils import export_to_video
+from transformers import T5EncoderModel
+
+text_encoder = T5EncoderModel.from_pretrained("sand-ai/MAGI-1", subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32)
+transformer = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16)
+
+pipeline = Magi1Pipeline.from_pretrained(
+ "sand-ai/MAGI-1",
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+# torch.compile
+pipeline.transformer.to(memory_format=torch.channels_last)
+pipeline.transformer = torch.compile(
+ pipeline.transformer, mode="max-autotune", fullgraph=True
+)
+
+prompt = """
+A majestic eagle soaring over a mountain landscape. The eagle's wings are spread wide,
+catching the golden sunlight as it glides through the clear blue sky. Below, snow-capped
+mountains stretch to the horizon, with pine forests and a winding river visible in the valley.
+"""
+negative_prompt = """
+Poor quality, blurry, pixelated, low resolution, distorted proportions, unnatural colors,
+watermark, text overlay, incomplete rendering, glitches, artifacts, unrealistic lighting
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=24,
+ guidance_scale=7.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=8)
+```
+
+
+
+
+### Image-to-Video Generation
+
+The example below demonstrates how to use the image-to-video pipeline to generate a video animation from a single image using text prompts for guidance.
+
+
+
+
+```python
+import torch
+from diffusers import Magi1ImageToVideoPipeline, AutoencoderKLMagi1
+from diffusers.utils import export_to_video, load_image
+
+model_id = "sand-ai/MAGI-1-I2V"
+vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = Magi1ImageToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Load input image
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
+
+prompt = (
+ "An astronaut walking on the moon's surface, with the Earth visible in the background. "
+ "The astronaut moves slowly in a low-gravity environment, kicking up lunar dust with each step."
+)
+negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality"
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=480,
+ width=832,
+ num_frames=81, # Generate 81 frames (~5 seconds at 16fps)
+ guidance_scale=5.0,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(output, "astronaut_animation.mp4", fps=16)
+```
+
+
+
+
+### Video-to-Video Generation
+
+The example below demonstrates how to use the video-to-video pipeline to extend or continue an existing video using text prompts.
+
+
+
+
+```python
+import torch
+from diffusers import Magi1VideoToVideoPipeline, AutoencoderKLMagi1
+from diffusers.utils import export_to_video, load_video
+
+model_id = "sand-ai/MAGI-1-V2V"
+vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = Magi1VideoToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Load the prefix video and keep only its first 24 frames (~1.5 seconds at 16 fps)
+video = load_video("path/to/input_video.mp4")[:24]
+
+prompt = (
+ "Continue this video with smooth camera motion and consistent style. "
+ "The scene evolves naturally with coherent motion and lighting."
+)
+negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality, jumpy motion"
+
+output = pipe(
+ video=video,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=480,
+ width=832,
+ num_frames=81, # Total frames including prefix (24 prefix + 57 generated)
+ guidance_scale=5.0,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(output, "video_continuation.mp4", fps=16)
+```
+
+
+
+
+## Notes
+
+- MAGI-1 uses autoregressive chunked generation with `chunk_width=6` and `window_size=4`, enabling efficient long video generation.
+- The model supports special tokens for quality control (HQ_TOKEN), style (THREE_D_MODEL_TOKEN, TWO_D_ANIME_TOKEN), and motion guidance (STATIC_FIRST_FRAMES_TOKEN, DYNAMIC_FIRST_FRAMES_TOKEN).
+- For I2V, the input image is encoded as a clean prefix chunk to condition the video generation.
+- For V2V, input video frames (typically 24 frames or ~1.5 seconds) are encoded as clean prefix chunks, and the model generates a continuation.
+- MAGI-1 supports LoRAs with [`~loaders.Magi1LoraLoaderMixin.load_lora_weights`]; see the sketch after this list.
+- Distillation mode can be enabled for faster inference with `enable_distillation=True` (requires a distilled model checkpoint).
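+
+The snippet below is a minimal sketch of loading a LoRA on top of the text-to-video pipeline; the LoRA repository id and weight filename are placeholders shown for illustration only.
+
+```python
+import torch
+from diffusers import Magi1Pipeline
+from diffusers.utils import export_to_video
+
+pipe = Magi1Pipeline.from_pretrained("sand-ai/MAGI-1", torch_dtype=torch.bfloat16).to("cuda")
+
+# Hypothetical LoRA repository and weight file, shown for illustration only
+pipe.load_lora_weights("your-username/magi1-style-lora", weight_name="pytorch_lora_weights.safetensors")
+
+output = pipe(prompt="A watercolor fox running through snow", num_frames=24, guidance_scale=7.0).frames[0]
+export_to_video(output, "lora_output.mp4", fps=8)
+```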
diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md
index 8be2c0603009..82547fedceec 100644
--- a/docs/source/en/optimization/attention_backends.md
+++ b/docs/source/en/optimization/attention_backends.md
@@ -149,5 +149,6 @@ Refer to the table below for a complete list of available attention backends and
| `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) |
| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |
| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |
+| `magi` | [MagiAttention](https://github.com/SandAI-org/MagiAttention) | Context-parallel attention with linear scalability and heterogeneous mask support |
-
\ No newline at end of file
+
diff --git a/scripts/convert_magi1_to_diffusers.py b/scripts/convert_magi1_to_diffusers.py
new file mode 100644
index 000000000000..57c5d88a0f42
--- /dev/null
+++ b/scripts/convert_magi1_to_diffusers.py
@@ -0,0 +1,548 @@
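+"""
+Conversion script for MAGI-1 checkpoints to the diffusers format.
+
+Example invocation (the output path is illustrative); use `--checkpoint_path` instead of
+`--model_type` to convert a local checkpoint:
+
+    python scripts/convert_magi1_to_diffusers.py \
+        --model_type 4.5B_distill \
+        --output_path ./magi1-diffusers \
+        --dtype bf16
+"""
+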
+import argparse
+import json
+import os
+import shutil
+import tempfile
+
+import torch
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from safetensors.torch import load_file
+
+from diffusers import Magi1Pipeline, Magi1Transformer3DModel
+from diffusers.models.autoencoders import AutoencoderKLMagi1
+
+
+def convert_magi1_transformer(model_type):
+ """
+ Convert MAGI-1 transformer for a specific model type.
+
+ Args:
+ model_type: The model type (e.g., "MAGI-1-T2V-4.5B-distill", "MAGI-1-T2V-24B-distill", etc.)
+
+ Returns:
+ The converted transformer model.
+ """
+
+ model_type_mapping = {
+ "MAGI-1-T2V-4.5B-distill": "4.5B_distill",
+ "MAGI-1-T2V-24B-distill": "24B_distill",
+ "MAGI-1-T2V-4.5B": "4.5B",
+ "MAGI-1-T2V-24B": "24B",
+ "4.5B_distill": "4.5B_distill",
+ "24B_distill": "24B_distill",
+ "4.5B": "4.5B",
+ "24B": "24B",
+ }
+
+ repo_path = model_type_mapping.get(model_type, model_type)
+
+ temp_dir = tempfile.mkdtemp()
+ transformer_ckpt_dir = os.path.join(temp_dir, "transformer_checkpoint")
+ os.makedirs(transformer_ckpt_dir, exist_ok=True)
+
+    checkpoint_files = []
+    # The conversion assumes the transformer weights are stored as two safetensors shards.
+    for shard_index in (1, 2):
+        shard_filename = f"model-{shard_index:05d}-of-00002.safetensors"
+        try:
+            shard_path = hf_hub_download(
+                "sand-ai/MAGI-1", f"ckpt/magi/{repo_path}/inference_weight.distill/{shard_filename}"
+            )
+        except Exception as e:
+            print(f"No more shards found or error downloading shard {shard_index}: {e}")
+            break
+        checkpoint_files.append(shard_path)
+        print(f"Downloaded {shard_filename}")
+
+ if not checkpoint_files:
+ raise ValueError(f"No checkpoint files found for model type: {model_type}")
+
+ for i, shard_path in enumerate(checkpoint_files):
+ dest_path = os.path.join(transformer_ckpt_dir, f"model-{i + 1:05d}-of-{len(checkpoint_files):05d}.safetensors")
+ shutil.copy2(shard_path, dest_path)
+
+ transformer = convert_magi1_transformer_checkpoint(transformer_ckpt_dir)
+
+ return transformer
+
+
+def convert_magi1_vae():
+ vae_ckpt_path = hf_hub_download("sand-ai/MAGI-1", "ckpt/vae/diffusion_pytorch_model.safetensors")
+ checkpoint = load_file(vae_ckpt_path)
+
+ config = {
+ "patch_size": (4, 8, 8),
+ "num_attention_heads": 16,
+ "attention_head_dim": 64,
+ "z_dim": 16,
+ "height": 256,
+ "width": 256,
+ "num_frames": 16,
+ "ffn_dim": 4 * 1024,
+ "num_layers": 24,
+ "eps": 1e-6,
+ }
+
+ vae = AutoencoderKLMagi1(
+ patch_size=config["patch_size"],
+ num_attention_heads=config["num_attention_heads"],
+ attention_head_dim=config["attention_head_dim"],
+ z_dim=config["z_dim"],
+ height=config["height"],
+ width=config["width"],
+ num_frames=config["num_frames"],
+ ffn_dim=config["ffn_dim"],
+ num_layers=config["num_layers"],
+ eps=config["eps"],
+ )
+
+ converted_state_dict = convert_vae_state_dict(checkpoint)
+
+ vae.load_state_dict(converted_state_dict, strict=True)
+
+ return vae
+
+
+def convert_vae_state_dict(checkpoint):
+ """
+ Convert MAGI-1 VAE state dict to diffusers format.
+
+ Maps the keys from the MAGI-1 VAE state dict to the diffusers VAE state dict.
+ """
+ state_dict = {}
+
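+    # Encoder: patch embedding, positional embedding, and class token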
+ state_dict["encoder.patch_embedding.weight"] = checkpoint["encoder.patch_embed.proj.weight"]
+ state_dict["encoder.patch_embedding.bias"] = checkpoint["encoder.patch_embed.proj.bias"]
+
+ state_dict["encoder.pos_embed"] = checkpoint["encoder.pos_embed"]
+
+ state_dict["encoder.cls_token"] = checkpoint["encoder.cls_token"]
+
+ for i in range(24):
+ qkv_weight = checkpoint[f"encoder.blocks.{i}.attn.qkv.weight"]
+ qkv_bias = checkpoint[f"encoder.blocks.{i}.attn.qkv.bias"]
+
+ q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
+ q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+
+ state_dict[f"encoder.blocks.{i}.attn.to_q.weight"] = q_weight
+ state_dict[f"encoder.blocks.{i}.attn.to_q.bias"] = q_bias
+ state_dict[f"encoder.blocks.{i}.attn.to_k.weight"] = k_weight
+ state_dict[f"encoder.blocks.{i}.attn.to_k.bias"] = k_bias
+ state_dict[f"encoder.blocks.{i}.attn.to_v.weight"] = v_weight
+ state_dict[f"encoder.blocks.{i}.attn.to_v.bias"] = v_bias
+
+ state_dict[f"encoder.blocks.{i}.attn.to_out.0.weight"] = checkpoint[f"encoder.blocks.{i}.attn.proj.weight"]
+ state_dict[f"encoder.blocks.{i}.attn.to_out.0.bias"] = checkpoint[f"encoder.blocks.{i}.attn.proj.bias"]
+
+ state_dict[f"encoder.blocks.{i}.norm.weight"] = checkpoint[f"encoder.blocks.{i}.norm2.weight"]
+ state_dict[f"encoder.blocks.{i}.norm.bias"] = checkpoint[f"encoder.blocks.{i}.norm2.bias"]
+
+ state_dict[f"encoder.blocks.{i}.proj_out.net.0.proj.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"encoder.blocks.{i}.proj_out.net.0.proj.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"encoder.blocks.{i}.proj_out.net.2.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.weight"]
+
+ state_dict[f"encoder.blocks.{i}.proj_out.net.2.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.bias"]
+
+ state_dict["encoder.norm_out.weight"] = checkpoint["encoder.norm.weight"]
+ state_dict["encoder.norm_out.bias"] = checkpoint["encoder.norm.bias"]
+
+ state_dict["encoder.linear_out.weight"] = checkpoint["encoder.last_layer.weight"]
+ state_dict["encoder.linear_out.bias"] = checkpoint["encoder.last_layer.bias"]
+
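+    # Decoder: input projection, positional embedding, and class token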
+ state_dict["decoder.proj_in.weight"] = checkpoint["decoder.proj_in.weight"]
+ state_dict["decoder.proj_in.bias"] = checkpoint["decoder.proj_in.bias"]
+
+ state_dict["decoder.pos_embed"] = checkpoint["decoder.pos_embed"]
+
+ state_dict["decoder.cls_token"] = checkpoint["decoder.cls_token"]
+
+ for i in range(24):
+ qkv_weight = checkpoint[f"decoder.blocks.{i}.attn.qkv.weight"]
+ qkv_bias = checkpoint[f"decoder.blocks.{i}.attn.qkv.bias"]
+
+ q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
+ q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+
+ state_dict[f"decoder.blocks.{i}.attn.to_q.weight"] = q_weight
+ state_dict[f"decoder.blocks.{i}.attn.to_q.bias"] = q_bias
+ state_dict[f"decoder.blocks.{i}.attn.to_k.weight"] = k_weight
+ state_dict[f"decoder.blocks.{i}.attn.to_k.bias"] = k_bias
+ state_dict[f"decoder.blocks.{i}.attn.to_v.weight"] = v_weight
+ state_dict[f"decoder.blocks.{i}.attn.to_v.bias"] = v_bias
+
+ state_dict[f"decoder.blocks.{i}.attn.to_out.0.weight"] = checkpoint[f"decoder.blocks.{i}.attn.proj.weight"]
+ state_dict[f"decoder.blocks.{i}.attn.to_out.0.bias"] = checkpoint[f"decoder.blocks.{i}.attn.proj.bias"]
+
+ state_dict[f"decoder.blocks.{i}.norm.weight"] = checkpoint[f"decoder.blocks.{i}.norm2.weight"]
+ state_dict[f"decoder.blocks.{i}.norm.bias"] = checkpoint[f"decoder.blocks.{i}.norm2.bias"]
+
+ state_dict[f"decoder.blocks.{i}.proj_out.net.0.proj.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"decoder.blocks.{i}.proj_out.net.0.proj.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"decoder.blocks.{i}.proj_out.net.2.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.weight"]
+ state_dict[f"decoder.blocks.{i}.proj_out.net.2.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.bias"]
+
+ state_dict["decoder.norm_out.weight"] = checkpoint["decoder.norm.weight"]
+ state_dict["decoder.norm_out.bias"] = checkpoint["decoder.norm.bias"]
+
+ state_dict["decoder.conv_out.weight"] = checkpoint["decoder.last_layer.weight"]
+ state_dict["decoder.conv_out.bias"] = checkpoint["decoder.last_layer.bias"]
+
+ return state_dict
+
+
+def load_magi1_transformer_checkpoint(checkpoint_path):
+ """
+ Load a MAGI-1 transformer checkpoint.
+
+ Args:
+ checkpoint_path: Path to the MAGI-1 transformer checkpoint.
+
+ Returns:
+ The loaded checkpoint state dict.
+ """
+ if checkpoint_path.endswith(".safetensors"):
+ state_dict = load_file(checkpoint_path)
+ elif os.path.isdir(checkpoint_path):
+ safetensors_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".safetensors")]
+ if safetensors_files:
+ state_dict = {}
+ for safetensors_file in sorted(safetensors_files):
+ file_path = os.path.join(checkpoint_path, safetensors_file)
+ with safe_open(file_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ state_dict[key] = f.get_tensor(key)
+ else:
+ checkpoint_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".pt") or f.endswith(".pth")]
+ if not checkpoint_files:
+ raise ValueError(f"No checkpoint files found in {checkpoint_path}")
+
+ checkpoint_file = os.path.join(checkpoint_path, checkpoint_files[0])
+ checkpoint_data = torch.load(checkpoint_file, map_location="cpu")
+
+ if isinstance(checkpoint_data, dict):
+ if "model" in checkpoint_data:
+ state_dict = checkpoint_data["model"]
+ elif "state_dict" in checkpoint_data:
+ state_dict = checkpoint_data["state_dict"]
+ else:
+ state_dict = checkpoint_data
+ else:
+ state_dict = checkpoint_data
+ else:
+ checkpoint_data = torch.load(checkpoint_path, map_location="cpu")
+
+ if isinstance(checkpoint_data, dict):
+ if "model" in checkpoint_data:
+ state_dict = checkpoint_data["model"]
+ elif "state_dict" in checkpoint_data:
+ state_dict = checkpoint_data["state_dict"]
+ else:
+ state_dict = checkpoint_data
+ else:
+ state_dict = checkpoint_data
+
+ return state_dict
+
+
+def convert_magi1_transformer_checkpoint(checkpoint_path, transformer_config_file=None, dtype=None, allow_partial=False):
+ """
+ Convert a MAGI-1 transformer checkpoint to a diffusers Magi1Transformer3DModel.
+
+ Args:
+ checkpoint_path: Path to the MAGI-1 transformer checkpoint.
+ transformer_config_file: Optional path to a transformer config file.
+ dtype: Optional dtype for the model.
+
+ Returns:
+ A diffusers Magi1Transformer3DModel model.
+ """
+ if transformer_config_file is not None:
+ with open(transformer_config_file, "r") as f:
+ config = json.load(f)
+ else:
+ config = {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_layers": 34,
+ "num_attention_heads": 24,
+ "num_kv_heads": 8,
+ "attention_head_dim": 128,
+ "cross_attention_dim": 4096,
+ "freq_dim": 256,
+ "ffn_dim": 12288,
+ "patch_size": (1, 2, 2),
+ "eps": 1e-6,
+ }
+
+ transformer = Magi1Transformer3DModel(
+ in_channels=config["in_channels"],
+ out_channels=config["out_channels"],
+ num_layers=config["num_layers"],
+ num_attention_heads=config["num_attention_heads"],
+ num_kv_heads=config["num_kv_heads"],
+ attention_head_dim=config["attention_head_dim"],
+ cross_attention_dim=config["cross_attention_dim"],
+ freq_dim=config["freq_dim"],
+ ffn_dim=config["ffn_dim"],
+ patch_size=config["patch_size"],
+ eps=config["eps"],
+ )
+
+ checkpoint = load_magi1_transformer_checkpoint(checkpoint_path)
+
+ converted_state_dict, report = convert_transformer_state_dict(checkpoint, transformer, allow_partial=allow_partial)
+
+ # Verify mapping coverage & shapes
+ print("\n=== MAGI-1 -> Diffusers mapping report ===")
+ print(f"Source keys used: {report['used_src_keys']} / {report['total_src_keys']}")
+ if report["missing_src_keys"]:
+ print(f"Missing source keys referenced: {len(report['missing_src_keys'])}")
+ print("Examples:", report["missing_src_keys"][:20])
+
+ # Target verifications
+ expected = transformer.state_dict()
+ expected_keys = set(expected.keys())
+ got_keys = set(converted_state_dict.keys())
+ missing_target = sorted(list(expected_keys - got_keys))
+ unexpected_target = sorted(list(got_keys - expected_keys))
+
+ shape_mismatches = []
+ for k in sorted(list(expected_keys & got_keys)):
+ if tuple(expected[k].shape) != tuple(converted_state_dict[k].shape):
+ shape_mismatches.append((k, tuple(converted_state_dict[k].shape), tuple(expected[k].shape)))
+
+ if missing_target:
+ print(f"Missing target keys: {len(missing_target)}")
+ print("Examples:", missing_target[:20])
+ if unexpected_target:
+ print(f"Unexpected converted keys: {len(unexpected_target)}")
+ print("Examples:", unexpected_target[:20])
+ if shape_mismatches:
+ print(f"Shape mismatches: {len(shape_mismatches)}")
+ print("Examples:", shape_mismatches[:5])
+
+ if (report["missing_src_keys"] or missing_target or shape_mismatches):
+ raise ValueError("Conversion verification failed. See report above.")
+
+    # Load with strict=True so that any remaining mismatch surfaces as an error
+ transformer.load_state_dict(converted_state_dict, strict=True)
+
+ if dtype is not None:
+ transformer = transformer.to(dtype=dtype)
+
+ return transformer
+
+
+def convert_transformer_state_dict(checkpoint, transformer=None, allow_partial=False):
+ """
+ Convert MAGI-1 transformer state dict to diffusers format.
+
+ Maps the original MAGI-1 parameter names to diffusers' standard transformer naming.
+ Handles all shape mismatches and key mappings based on actual checkpoint analysis.
+ """
+ print("Converting MAGI-1 checkpoint keys...")
+
+ converted_state_dict = {}
+ used_src_keys = set()
+ missing_src_keys = []
+
+ def require(key: str) -> torch.Tensor:
+ if key not in checkpoint:
+ missing_src_keys.append(key)
+ if allow_partial:
+ return None # will be skipped by caller
+ raise KeyError(f"Missing source key: {key}")
+ used_src_keys.add(key)
+ return checkpoint[key]
+
+ def assign(src: str, dst: str):
+ val = require(src)
+ if val is not None:
+ converted_state_dict[dst] = val
+
+ def split_assign(src: str, dst_k: str, dst_v: str):
+ kv = require(src)
+ if kv is not None:
+ k, v = kv.chunk(2, dim=0)
+ converted_state_dict[dst_k] = k
+ converted_state_dict[dst_v] = v
+
+ # Simple top-level mappings
+ simple_maps = [
+ ("x_embedder.weight", "patch_embedding.weight"),
+ ("t_embedder.mlp.0.weight", "condition_embedder.time_embedder.linear_1.weight"),
+ ("t_embedder.mlp.0.bias", "condition_embedder.time_embedder.linear_1.bias"),
+ ("t_embedder.mlp.2.weight", "condition_embedder.time_embedder.linear_2.weight"),
+ ("t_embedder.mlp.2.bias", "condition_embedder.time_embedder.linear_2.bias"),
+ ("y_embedder.y_proj_xattn.0.weight", "condition_embedder.text_embedder.y_proj_xattn.0.weight"),
+ ("y_embedder.y_proj_xattn.0.bias", "condition_embedder.text_embedder.y_proj_xattn.0.bias"),
+ ("y_embedder.y_proj_adaln.0.weight", "condition_embedder.text_embedder.y_proj_adaln.weight"),
+ ("y_embedder.y_proj_adaln.0.bias", "condition_embedder.text_embedder.y_proj_adaln.bias"),
+ ("videodit_blocks.final_layernorm.weight", "norm_out.weight"),
+ ("videodit_blocks.final_layernorm.bias", "norm_out.bias"),
+ ("final_linear.linear.weight", "proj_out.weight"),
+ ("rope.bands", "rope.bands"),
+ ]
+
+ for src, dst in simple_maps:
+ try:
+ assign(src, dst)
+ except KeyError:
+ if not allow_partial:
+ raise
+
+ # Determine number of layers
+ if transformer is not None and hasattr(transformer, "config"):
+ num_layers = transformer.config.num_layers
+ else:
+ # Fallback: infer from checkpoint keys
+ num_layers = 0
+ for k in checkpoint.keys():
+ if k.startswith("videodit_blocks.layers."):
+ try:
+ idx = int(k.split(".")[3])
+ num_layers = max(num_layers, idx + 1)
+ except Exception:
+ pass
+
+ # Per-layer mappings
+ for i in range(num_layers):
+ layer_prefix = f"videodit_blocks.layers.{i}"
+ block_prefix = f"blocks.{i}"
+
+ layer_maps = [
+ (f"{layer_prefix}.self_attention.linear_qkv.layer_norm.weight", f"{block_prefix}.norm1.weight"),
+ (f"{layer_prefix}.self_attention.linear_qkv.layer_norm.bias", f"{block_prefix}.norm1.bias"),
+ (f"{layer_prefix}.self_attention.linear_qkv.q.weight", f"{block_prefix}.attn1.to_q.weight"),
+ (f"{layer_prefix}.self_attention.linear_qkv.k.weight", f"{block_prefix}.attn1.to_k.weight"),
+ (f"{layer_prefix}.self_attention.linear_qkv.v.weight", f"{block_prefix}.attn1.to_v.weight"),
+ (f"{layer_prefix}.self_attention.q_layernorm.weight", f"{block_prefix}.attn1.norm_q.weight"),
+ (f"{layer_prefix}.self_attention.q_layernorm.bias", f"{block_prefix}.attn1.norm_q.bias"),
+ (f"{layer_prefix}.self_attention.k_layernorm.weight", f"{block_prefix}.attn1.norm_k.weight"),
+ (f"{layer_prefix}.self_attention.k_layernorm.bias", f"{block_prefix}.attn1.norm_k.bias"),
+ (f"{layer_prefix}.self_attention.linear_qkv.qx.weight", f"{block_prefix}.attn2.to_q.weight"),
+ (f"{layer_prefix}.self_attention.q_layernorm_xattn.weight", f"{block_prefix}.attn2.norm_q.weight"),
+ (f"{layer_prefix}.self_attention.q_layernorm_xattn.bias", f"{block_prefix}.attn2.norm_q.bias"),
+ (f"{layer_prefix}.self_attention.k_layernorm_xattn.weight", f"{block_prefix}.attn2.norm_k.weight"),
+ (f"{layer_prefix}.self_attention.k_layernorm_xattn.bias", f"{block_prefix}.attn2.norm_k.bias"),
+ # Combined projection for concatenated [self_attn, cross_attn] outputs
+ (f"{layer_prefix}.self_attention.linear_proj.weight", f"{block_prefix}.attn_proj.weight"),
+ (f"{layer_prefix}.self_attn_post_norm.weight", f"{block_prefix}.norm2.weight"),
+ (f"{layer_prefix}.self_attn_post_norm.bias", f"{block_prefix}.norm2.bias"),
+ (f"{layer_prefix}.mlp.layer_norm.weight", f"{block_prefix}.norm3.weight"),
+ (f"{layer_prefix}.mlp.layer_norm.bias", f"{block_prefix}.norm3.bias"),
+ (f"{layer_prefix}.mlp.linear_fc1.weight", f"{block_prefix}.ffn.net.0.proj.weight"),
+ (f"{layer_prefix}.mlp.linear_fc2.weight", f"{block_prefix}.ffn.net.2.weight"),
+ (f"{layer_prefix}.mlp_post_norm.weight", f"{block_prefix}.norm4.weight"),
+ (f"{layer_prefix}.mlp_post_norm.bias", f"{block_prefix}.norm4.bias"),
+ (f"{layer_prefix}.ada_modulate_layer.proj.0.weight", f"{block_prefix}.ada_modulate_layer.1.weight"),
+ (f"{layer_prefix}.ada_modulate_layer.proj.0.bias", f"{block_prefix}.ada_modulate_layer.1.bias"),
+ ]
+
+ for src, dst in layer_maps:
+ try:
+ assign(src, dst)
+ except KeyError:
+ if not allow_partial:
+ raise
+
+ # special split for kv
+ try:
+ split_assign(
+ f"{layer_prefix}.self_attention.linear_kv_xattn.weight",
+ f"{block_prefix}.attn2.to_k.weight",
+ f"{block_prefix}.attn2.to_v.weight",
+ )
+ except KeyError:
+ if not allow_partial:
+ raise
+
+ print(f"Converted {len(converted_state_dict)} parameters")
+ report = {
+ "total_src_keys": len(checkpoint),
+ "used_src_keys": len(used_src_keys),
+ "missing_src_keys": missing_src_keys,
+ }
+ return converted_state_dict, report
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default=None)
+ parser.add_argument("--checkpoint_path", type=str, default=None, help="Local MAGI-1 transformer checkpoint path")
+ parser.add_argument("--config_path", type=str, default=None, help="Optional JSON config for transformer")
+ parser.add_argument("--output_path", type=str, required=True)
+ parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
+ parser.add_argument("--push_to_hub", action="store_true", help="If set, push to the Hub after conversion")
+ parser.add_argument("--repo_id", type=str, default=None, help="Repo ID to push to (when --push_to_hub is set)")
+ parser.add_argument("--allow_partial", action="store_true", help="Allow partial/loose state dict loading")
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+if __name__ == "__main__":
+ args = get_args()
+
+ if args.model_type is not None:
+ transformer = convert_magi1_transformer(args.model_type)
+ elif args.checkpoint_path is not None:
+ transformer = convert_magi1_transformer_checkpoint(
+ args.checkpoint_path, transformer_config_file=args.config_path, allow_partial=args.allow_partial
+ )
+ else:
+ raise ValueError("Provide either --model_type for HF download or --checkpoint_path for local conversion.")
+
+ # If user has specified "none", we keep the original dtypes of the state dict without any conversion
+ if args.dtype != "none":
+ dtype = DTYPE_MAPPING[args.dtype]
+ transformer.to(dtype)
+
+ # Save transformer directly to output path (subfolder 'transformer')
+ save_kwargs = {"safe_serialization": True, "max_shard_size": "5GB"}
+ save_dir = os.path.join(args.output_path, "transformer")
+ os.makedirs(save_dir, exist_ok=True)
+ if args.push_to_hub:
+ save_kwargs.update(
+ {
+ "push_to_hub": True,
+ "repo_id": (
+ args.repo_id
+ if args.repo_id is not None
+ else (f"tolgacangoz/{args.model_type}-Magi1Transformer" if args.model_type else "tolgacangoz/Magi1Transformer")
+ ),
+ }
+ )
+ transformer.save_pretrained(save_dir, **save_kwargs)
+
+ # Also write a minimal model_index.json for convenience when composing a pipeline later
+ index_path = os.path.join(args.output_path, "model_index.json")
+ index = {
+ "_class_name": "Magi1Pipeline",
+ "_diffusers_version": "0.0.0",
+ "transformer": ["transformer"],
+ "vae": None,
+ "text_encoder": None,
+ "tokenizer": None,
+ "scheduler": None,
+ }
+ with open(index_path, "w") as f:
+ json.dump(index, f, indent=2)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index aa500b149441..2906cbc6c9d2 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -186,6 +186,7 @@
"AutoencoderKLCosmos",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
+ "AutoencoderKLMagi1",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
"AutoencoderKLQwenImage",
@@ -225,6 +226,7 @@
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
+ "Magi1Transformer3DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
@@ -508,6 +510,9 @@
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
"LuminaText2ImgPipeline",
+ "Magi1ImageToVideoPipeline",
+ "Magi1Pipeline",
+ "Magi1VideoToVideoPipeline",
"MarigoldDepthPipeline",
"MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
@@ -880,6 +885,7 @@
AutoencoderKLCosmos,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
+ AutoencoderKLMagi1,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLQwenImage,
@@ -919,6 +925,7 @@
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
+ Magi1Transformer3DModel,
MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
@@ -1172,6 +1179,9 @@
Lumina2Text2ImgPipeline,
LuminaPipeline,
LuminaText2ImgPipeline,
+ Magi1ImageToVideoPipeline,
+ Magi1Pipeline,
+ Magi1VideoToVideoPipeline,
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 48507aae038c..8072db3b8073 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -72,6 +72,7 @@ def text_encoder_attn_modules(text_encoder):
"FluxLoraLoaderMixin",
"CogVideoXLoraLoaderMixin",
"CogView4LoraLoaderMixin",
+ "Magi1LoraLoaderMixin",
"Mochi1LoraLoaderMixin",
"HunyuanVideoLoraLoaderMixin",
"SanaLoraLoaderMixin",
@@ -120,6 +121,7 @@ def text_encoder_attn_modules(text_encoder):
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
+ Magi1LoraLoaderMixin,
Mochi1LoraLoaderMixin,
QwenImageLoraLoaderMixin,
SanaLoraLoaderMixin,
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 2bb6c0ea026e..97573024d664 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -3924,6 +3924,240 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)
+class Magi1LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`Magi1Transformer3DModel`]. Specific to [`Magi1Pipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Magi1Transformer3DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with StableDiffusion->Magi1
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an irreversible operation. If you need to unfuse, you'll need to reload the model.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
class WanLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 8d029bf5d31c..168f66f2d376 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -37,6 +37,7 @@
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+ _import_structure["autoencoders.autoencoder_kl_magi1"] = ["AutoencoderKLMagi1"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -94,6 +95,7 @@
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
+ _import_structure["transformers.transformer_magi1"] = ["Magi1Transformer3DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
@@ -134,6 +136,7 @@
AutoencoderKLCosmos,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
+ AutoencoderKLMagi1,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLQwenImage,
@@ -188,6 +191,7 @@
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
+ Magi1Transformer3DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e1694910997a..ff1c970ece65 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -31,6 +31,8 @@
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
+ is_magi_attn_available,
+ is_magi_attn_version,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
@@ -51,6 +53,7 @@
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"
+_REQUIRED_MAGI_VERSION = "1.0.3"
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
@@ -59,7 +62,7 @@
_CAN_USE_NPU_ATTN = is_torch_npu_available()
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
-
+_CAN_USE_MAGI_ATTN = is_magi_attn_available() and is_magi_attn_version(">=", _REQUIRED_MAGI_VERSION)
if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -132,6 +135,11 @@
else:
xops = None
+if _CAN_USE_MAGI_ATTN:
+ from magi_attention.functional import flex_flash_attn_func
+else:
+ flex_flash_attn_func = None
+
# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
if torch.__version__ >= "2.4.0":
_custom_op = torch.library.custom_op
@@ -202,6 +210,9 @@ class AttentionBackendName(str, Enum):
# `xformers`
XFORMERS = "xformers"
+ # `magi-attention`
+ MAGI = "magi"
+
class _AttentionBackendRegistry:
_backends = {}
@@ -450,6 +461,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
raise RuntimeError(
f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
)
+ elif backend == AttentionBackendName.MAGI:
+ if not _CAN_USE_MAGI_ATTN:
+ raise RuntimeError(
+ f"Magi Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `magi-attention>={_REQUIRED_MAGI_VERSION}`."
+ )
@functools.lru_cache(maxsize=128)
@@ -1952,3 +1968,176 @@ def _xformers_attention(
out = out.flatten(2, 3)
return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.MAGI,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _magi_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ q_ranges: torch.Tensor,
+ k_ranges: torch.Tensor,
+ max_seqlen_q: int,
+ max_seqlen_k: int,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    attn_mask: Optional[torch.Tensor] = None,
+    attn_type_map: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    sm_margin: int = 0,
+    disable_fwd_atomic_reduction: bool = False,
+    auto_range_merge: bool = False,
+    ref_block_size: Optional[Tuple[int, int]] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
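+    # Drop any padded tokens so key/value can be packed into a single variable-length layout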
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+    # flex_flash_attn_func returns (output, lse); keep only the attention output
+    out, _ = flex_flash_attn_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ q_ranges=q_ranges,
+ k_ranges=k_ranges,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ attn_type_map=attn_type_map,
+ softmax_scale=softmax_scale,
+ softcap=softcap,
+ deterministic=deterministic,
+ sm_margin=sm_margin,
+ disable_fwd_atomic_reduction=disable_fwd_atomic_reduction,
+ auto_range_merge=auto_range_merge,
+ ref_block_size=ref_block_size,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.MAGI,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _magi_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ return_lse: bool = False,
+ q_ranges: Optional[torch.Tensor] = None,
+ k_ranges: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ """
+ MAGI varlen attention backend using flex_flash_attn_func from magi-attention library.
+
+ This backend supports variable-length sequences with q_ranges and k_ranges for
+ MAGI-1's autoregressive video generation pattern.
+
+ Args:
+ query: Query tensor [B, S_q, H, D]
+ key: Key tensor [B, S_kv, H, D] or packed [total_tokens, H, D]
+ value: Value tensor [B, S_kv, H, D] or packed [total_tokens, H, D]
+ q_ranges: Tensor of shape [num_ranges, 2] specifying [start, end) for each query range
+ k_ranges: Tensor of shape [num_ranges, 2] specifying [start, end) for each key range
+ max_seqlen_q: Maximum sequence length for queries
+ max_seqlen_k: Maximum sequence length for keys
+ cu_seqlens_q: Cumulative sequence lengths for queries (alternative to q_ranges)
+ cu_seqlens_k: Cumulative sequence lengths for keys (alternative to k_ranges)
+ """
+ # If q_ranges/k_ranges are provided, use them directly (MAGI-1 style)
+ if q_ranges is not None and k_ranges is not None:
+        # Flatten query/key/value if they're not already packed
+        was_batched = query.ndim == 4  # [B, S, H, D]
+        if was_batched:
+            batch_size, seq_len_q, num_heads, head_dim = query.shape
+            query = query.flatten(0, 1)  # [B*S, H, D]
+            key = key.flatten(0, 1)
+            value = value.flatten(0, 1)
+
+        # Call flex_flash_attn_func with q_ranges/k_ranges
+        out, _ = flex_flash_attn_func(
+            query,
+            key,
+            value,
+            q_ranges,
+            k_ranges,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=scale,
+            deterministic=torch.are_deterministic_algorithms_enabled(),
+            disable_fwd_atomic_reduction=True,
+        )
+
+        # Unflatten output back to [B, S, H, D] only if the input was batched
+        if was_batched:
+            out = out.unflatten(0, (batch_size, seq_len_q))
+
+        return out
+
+ # Fallback to cu_seqlens if ranges not provided
+ elif cu_seqlens_q is not None and cu_seqlens_k is not None:
+ # Convert cu_seqlens to ranges
+ q_ranges = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1)
+ k_ranges = torch.cat([cu_seqlens_k[:-1].unsqueeze(1), cu_seqlens_k[1:].unsqueeze(1)], dim=1)
+
+ batch_size, seq_len_q, num_heads, head_dim = query.shape
+ query = query.flatten(0, 1)
+ key = key.flatten(0, 1)
+ value = value.flatten(0, 1)
+
+ out, _ = flex_flash_attn_func(
+ query,
+ key,
+ value,
+ q_ranges,
+ k_ranges,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ softmax_scale=scale,
+ deterministic=torch.are_deterministic_algorithms_enabled(),
+ disable_fwd_atomic_reduction=True,
+ )
+
+ out = out.unflatten(0, (batch_size, seq_len_q))
+ return out
+
+ else:
+ raise ValueError(
+ "MAGI attention backend requires either (q_ranges and k_ranges) or (cu_seqlens_q and cu_seqlens_k)"
+ )
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index c008a45298e8..02f6826f51d5 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -6,6 +6,7 @@
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
+from .autoencoder_kl_magi1 import AutoencoderKLMagi1
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py
new file mode 100644
index 000000000000..7b01c685e8a3
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py
@@ -0,0 +1,1024 @@
+# Copyright 2025 The Sand AI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def resize_pos_embed(posemb, src_shape, target_shape):
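+ # posemb: (1, T_src*H_src*W_src, C). Reshape to (1, C, T_src, H_src, W_src), trilinearly interpolate to the
+ # target grid, then flatten back to (1, T_tgt*H_tgt*W_tgt, C).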
+ posemb = posemb.reshape(1, src_shape[0], src_shape[1], src_shape[2], -1)
+ posemb = posemb.permute(0, 4, 1, 2, 3)
+ posemb = nn.functional.interpolate(posemb, size=target_shape, mode="trilinear", align_corners=False)
+ posemb = posemb.permute(0, 2, 3, 4, 1)
+ posemb = posemb.reshape(1, target_shape[0] * target_shape[1] * target_shape[2], -1)
+ return posemb
+
+
+# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections
+def _get_qkv_projections(attn: "Magi1VAEAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+class Magi1VAELayerNorm(nn.Module):
+ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+ super(Magi1VAELayerNorm, self).__init__()
+ self.normalized_shape = normalized_shape
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+
+ def forward(self, x):
+ mean = x.mean(dim=-1, keepdim=True)
+ std = x.std(dim=-1, keepdim=True, unbiased=False)
+
+ x_normalized = (x - mean) / (std + self.eps)
+
+ return x_normalized
+
+
+class Magi1VAEAttnProcessor:
+ _attention_backend = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("Magi1VAEAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: "Magi1VAEAttention",
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ batch_size, time_height_width, channels = hidden_states.size()
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, None)
+
+ qkv = torch.cat((query, key, value), dim=2)
+ qkv = qkv.reshape(batch_size, time_height_width, 3, attn.heads, channels // attn.heads)
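+ # qkv: (B, THW, 3, heads, head_dim); the norm below is applied over head_dim for q, k and v independently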
+ qkv = attn.qkv_norm(qkv)
+ query, key, value = qkv.chunk(3, dim=2)
+
+ # Remove the extra dimension from chunking
+ # Shape: (batch_size, time_height_width, num_heads, head_dim)
+ query = query.squeeze(2)
+ key = key.squeeze(2)
+ value = value.squeeze(2)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class Magi1VAEAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = Magi1VAEAttnProcessor
+ _available_processors = [Magi1VAEAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.qkv_norm = Magi1VAELayerNorm(dim // heads, elementwise_affine=False)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.set_processor(processor)
+
+ # Copied from diffusers.models.transformers.transformer_wan.WanAttention.fuse_projections
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ # Copied from diffusers.models.transformers.transformer_wan.WanAttention.unfuse_projections
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
+class Magi1VAETransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ ffn_dim: int = 4 * 1024,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.attn = Magi1VAEAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=None,
+ processor=Magi1VAEAttnProcessor(),
+ )
+
+ self.norm = nn.LayerNorm(dim)
+ self.proj_out = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu")
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor):
+ hidden_states = hidden_states + self.attn(hidden_states)
+ hidden_states = hidden_states + self.proj_out(self.norm(hidden_states))
+ return hidden_states
+
+
+class Magi1Encoder3d(nn.Module):
+ def __init__(
+ self,
+ inner_dim=128,
+ z_dim=4,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_frames: int = 16,
+ height: int = 256,
+ width: int = 256,
+ num_attention_heads: int = 40,
+ ffn_dim: int = 4 * 1024,
+ num_layers: int = 24,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.height = height
+ self.width = width
+ self.num_frames = num_frames
+
+ # 1. Patch & position embedding
+ self.patch_embedding = nn.Conv3d(3, inner_dim, kernel_size=patch_size, stride=patch_size)
+ self.patch_size = patch_size
+
+ self.cls_token_nums = 1
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, inner_dim))
+ # `generator` as a parameter?
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+ p_t, p_h, p_w = patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+ num_patches = post_patch_num_frames * post_patch_height * post_patch_width
+
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.cls_token_nums, inner_dim))
+ self.pos_drop = nn.Dropout(p=0.0)
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ Magi1VAETransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ ffn_dim,
+ eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # output blocks
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.linear_out = nn.Linear(inner_dim, z_dim * 2)
+
+ # `generator` as a parameter?
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x):
+ B = x.shape[0]
+ # B C T H W -> B C T/pT H/pH W/pW
+ x = self.patch_embedding(x)
+ latentT, latentH, latentW = x.shape[2], x.shape[3], x.shape[4]
+ # B C T/pT H/pH W/pW -> B (T/pT H/pH W/pW) C
+ x = x.flatten(2).transpose(1, 2)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # Resize the positional embedding when the latent grid differs from the grid used to build `self.pos_embed`
+ if latentT != self.num_frames // self.patch_size[0] or latentH != self.height // self.patch_size[1] or latentW != self.width // self.patch_size[2]:
+ pos_embed = resize_pos_embed(
+ self.pos_embed[:, 1:, :],
+ src_shape=(
+ self.num_frames // self.patch_size[0],
+ self.height // self.patch_size[1],
+ self.width // self.patch_size[2],
+ ),
+ target_shape=(latentT, latentH, latentW),
+ )
+ pos_embed = torch.cat((self.pos_embed[:, 0:1, :], pos_embed), dim=1)
+ else:
+ pos_embed = self.pos_embed
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ ## transformer blocks
+ for block in self.blocks:
+ x = block(x)
+
+ ## head
+ x = self.norm_out(x)
+ x = x[:, 1:] # remove cls_token
+ x = self.linear_out(x)
+
+ # B L C -> B, lT, lH, lW, zC (where zC is z_dim * 2)
+ x = x.reshape(B, latentT, latentH, latentW, self.z_dim * 2)
+
+ # B, lT, lH, lW, zC -> B, zC, lT, lH, lW
+ x = x.permute(0, 4, 1, 2, 3)
+
+ return x
+
+
+class Magi1Decoder3d(nn.Module):
+ def __init__(
+ self,
+ inner_dim=1024,
+ z_dim=16,
+ patch_size: Tuple[int] = (4, 8, 8),
+ num_frames: int = 16,
+ height: int = 256,
+ width: int = 256,
+ num_attention_heads: int = 16,
+ ffn_dim: int = 4 * 1024,
+ num_layers: int = 24,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.patch_size = patch_size
+ self.height = height
+ self.width = width
+ self.num_frames = num_frames
+
+ # init block
+ self.proj_in = nn.Linear(z_dim, inner_dim)
+
+ self.cls_token_nums = 1
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, inner_dim))
+ # `generator` as a parameter?
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+ p_t, p_h, p_w = patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+ num_patches = post_patch_num_frames * post_patch_height * post_patch_width
+
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.cls_token_nums, inner_dim))
+ self.pos_drop = nn.Dropout(p=0.0)
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ Magi1VAETransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ ffn_dim,
+ eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # output blocks
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.unpatch_channels = inner_dim // (patch_size[0] * patch_size[1] * patch_size[2])
+ self.conv_out = nn.Conv3d(self.unpatch_channels, 3, 3, padding=1)
+
+ # `generator` as a parameter?
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x):
+ B, C, latentT, latentH, latentW = x.shape
+ x = x.permute(0, 2, 3, 4, 1)
+
+ x = x.reshape(B, -1, C)
+
+ x = self.proj_in(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # Resize the positional embedding when the latent grid differs from the grid used to build `self.pos_embed`
+ if latentT != self.num_frames // self.patch_size[0] or latentH != self.height // self.patch_size[1] or latentW != self.width // self.patch_size[2]:
+ pos_embed = resize_pos_embed(
+ self.pos_embed[:, 1:, :],
+ src_shape=(
+ self.num_frames // self.patch_size[0],
+ self.height // self.patch_size[1],
+ self.width // self.patch_size[2],
+ ),
+ target_shape=(latentT, latentH, latentW),
+ )
+ pos_embed = torch.cat((self.pos_embed[:, 0:1, :], pos_embed), dim=1)
+ else:
+ pos_embed = self.pos_embed
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ ## transformer blocks
+ for block in self.blocks:
+ x = block(x)
+
+ ## head
+ x = self.norm_out(x)
+ x = x[:, 1:] # remove cls_token
+
+ x = x.reshape(
+ B,
+ latentT,
+ latentH,
+ latentW,
+ self.patch_size[0],
+ self.patch_size[1],
+ self.patch_size[2],
+ self.unpatch_channels,
+ )
+ # Rearrange from (B, lT, lH, lW, pT, pH, pW, C) to (B, C, lT*pT, lH*pH, lW*pW)
+ x = x.permute(0, 7, 1, 4, 2, 5, 3, 6) # (B, C, lT, pT, lH, pH, lW, pW)
+ x = x.reshape(
+ B,
+ self.unpatch_channels,
+ latentT * self.patch_size[0],
+ latentH * self.patch_size[1],
+ latentW * self.patch_size[2],
+ )
+
+ x = self.conv_out(x)
+ return x
+
+
+class AutoencoderKLMagi1(ModelMixin, ConfigMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+ Introduced in [MAGI-1](https://arxiv.org/abs/2505.13211).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = False
+ _skip_layerwise_casting_patterns = ["patch_embedding", "norm"]
+ _no_split_modules = ["Magi1VAETransformerBlock"]
+ # _keep_in_fp32_modules = ["qkv_norm", "norm1", "norm2"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (4, 8, 8),
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 64,
+ z_dim: int = 16,
+ height: int = 256,
+ width: int = 256,
+ num_frames: int = 16,
+ ffn_dim: int = 4 * 1024,
+ num_layers: int = 24,
+ eps: float = 1e-6,
+ latents_mean: List[float] = [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921,
+ ],
+ latents_std: List[float] = [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.9160,
+ ],
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ self.z_dim = z_dim
+
+ self.encoder = Magi1Encoder3d(
+ inner_dim,
+ z_dim,
+ patch_size,
+ num_frames,
+ height,
+ width,
+ num_attention_heads,
+ ffn_dim,
+ num_layers,
+ eps,
+ )
+
+ self.decoder = Magi1Decoder3d(
+ inner_dim,
+ z_dim,
+ patch_size,
+ num_frames,
+ height,
+ width,
+ num_attention_heads,
+ ffn_dim,
+ num_layers,
+ eps,
+ )
+
+ # Spatial patches are square, so the spatial compression ratio is given by the height patch size
+ self.spatial_compression_ratio = patch_size[1]
+ self.temporal_compression_ratio = patch_size[0]
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+
+ # The minimal tile length for temporal tiling to be used
+ self.tile_sample_min_length = 16
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 192
+ self.tile_sample_stride_width = 192
+
+ # The minimal distance between two temporal tiles
+ self.tile_sample_stride_length = 16
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_min_length: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ tile_sample_stride_length: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+ processing larger videos.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_min_length (`int`, *optional*):
+ The minimum number of frames required for a sample to be separated into tiles across the temporal dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The stride between two consecutive vertical tiles. The overlapping region (`tile_sample_min_height -
+ tile_sample_stride_height`) is blended to avoid tiling artifacts across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. The overlapping region is blended to avoid tiling
+ artifacts across the width dimension.
+ tile_sample_stride_length (`int`, *optional*):
+ The stride between two consecutive temporal tiles. The overlapping region is blended to avoid tiling
+ artifacts across the temporal dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_min_length = tile_sample_min_length or self.tile_sample_min_length
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+ self.tile_sample_stride_length = tile_sample_stride_length or self.tile_sample_stride_length
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor):
+ _, _, num_frames, height, width = x.shape
+
+ if self.use_tiling and (
+ width > self.tile_sample_min_width
+ or height > self.tile_sample_min_height
+ or num_frames > self.tile_sample_min_length
+ ):
+ return self.tiled_encode(x)
+
+ out = self.encoder(x)
+
+ return out
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
+ _, _, num_frames, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_min_length = self.tile_sample_min_length // self.temporal_compression_ratio
+
+ if self.use_tiling and (
+ width > tile_latent_min_width or height > tile_latent_min_height or num_frames > tile_latent_min_length
+ ):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ out = self.decoder(z)
+
+ if not return_dict:
+ return (out,)
+
+ return DecoderOutput(sample=out)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int, power: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ w_a, w_b = 1 - y / blend_extent, y / blend_extent
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (w_a**power) + b[:, :, :, y, :] * (w_b**power)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int, power: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ w_a, w_b = 1.0 - x / blend_extent, x / blend_extent
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (w_a**power) + b[:, :, :, :, x] * (w_b**power)
+ return b
+
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int, power: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+ for t in range(blend_extent):
+ w_a, w_b = 1.0 - t / blend_extent, t / blend_extent
+ b[:, :, t, :, :] = a[:, :, -blend_extent + t, :, :] * (w_a**power) + b[:, :, t, :, :] * (w_b**power)
+ return b
+
+ def _encode_tile(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Encode a single tile.
+ """
+ N, C, T, H, W = x.shape
+
+ if T == 1 and self.temporal_compression_ratio > 1:
+ x = x.expand(-1, -1, self.temporal_compression_ratio, -1, -1)
+ x = self.encoder(x)
+ # After temporal expansion and encoding, select only the first temporal frame.
+ x = x[:, :, :1]
+ return x
+ else:
+ x = self.encoder(x)
+ return x
+
+ def _decode_tile(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Decode a single tile.
+ """
+ N, C, T, H, W = x.shape
+
+ x = self.decoder(x)
+ return x[:, :, :1, :, :] if T == 1 else x
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ B, C, num_frames, height, width = x.shape
+
+ # Latent sizes after compression
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+ latent_length = num_frames // self.temporal_compression_ratio
+
+ # Tile latent sizes / strides
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_min_length = self.tile_sample_min_length // self.temporal_compression_ratio
+
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_latent_stride_length = self.tile_sample_stride_length // self.temporal_compression_ratio
+
+ # Overlap (blend) sizes in latent space
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+ blend_length = tile_latent_min_length - tile_latent_stride_length
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ times = []
+ for t in range(0, num_frames, self.tile_sample_stride_length):
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ tile = x[
+ :,
+ :,
+ t : t + self.tile_sample_min_length,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ h_tile = self._encode_tile(tile)
+ # The original implementation samples here and blends the sampled latents.
+ # Instead, we keep the moments (mu and logvar) and blend those.
+ row.append(h_tile)
+ rows.append(row)
+ times.append(rows)
+
+ # Calculate global blending order because blending is not commutative here.
+ nT = len(times)
+ nH = len(times[0]) if nT else 0
+ nW = len(times[0][0]) if nH else 0
+ idx_numel = []
+ for tt in range(nT):
+ for ii in range(nH):
+ for jj in range(nW):
+ idx_numel.append(((tt, ii, jj), times[tt][ii][jj].numel()))
+ global_order = [idx for (idx, _) in sorted(idx_numel, key=lambda kv: kv[1], reverse=True)]
+
+ result_grid = [[[None for _ in range(nW)] for _ in range(nH)] for _ in range(nT)]
+ for t_idx, i_idx, j_idx in global_order:
+ rows = times[t_idx]
+ row = rows[i_idx]
+ h = row[j_idx]
+
+ # Separate the mu and the logvar because mu needs to be blended linearly
+ # but logvar needs to be blended quadratically to obtain numerical equivalence
+ # so that the overall distribution is preserved
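+ # For independent Gaussians, Var(w_a * X + w_b * Y) = w_a^2 * Var(X) + w_b^2 * Var(Y), hence the power=1
+ # blend for mu and the power=2 blend for var below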
+ mu, logvar = h[:, : self.z_dim], h[:, self.z_dim :]
+ var = logvar.exp()
+
+ # Blend the prev tile, the above tile and the left tile
+ # to the current tile and add the current tile to the result grid
+ if t_idx > 0:
+ h_tile = times[t_idx - 1][i_idx][j_idx]
+ mu_prev, logvar_prev = h_tile[:, : self.z_dim], h_tile[:, self.z_dim :]
+ var_prev = logvar_prev.exp()
+ mu = self.blend_t(mu_prev, mu, blend_length, power=1)
+ var = self.blend_t(var_prev, var, blend_length, power=2)
+
+ if i_idx > 0:
+ h_tile = rows[i_idx - 1][j_idx]
+ mu_up, logvar_up = h_tile[:, : self.z_dim], h_tile[:, self.z_dim :]
+ var_up = logvar_up.exp()
+ mu = self.blend_v(mu_up, mu, blend_height, power=1)
+ var = self.blend_v(var_up, var, blend_height, power=2)
+
+ if j_idx > 0:
+ h_tile = row[j_idx - 1]
+ mu_left, logvar_left = h_tile[:, : self.z_dim], h_tile[:, self.z_dim :]
+ var_left = logvar_left.exp()
+ mu = self.blend_h(mu_left, mu, blend_width, power=1)
+ var = self.blend_h(var_left, var, blend_width, power=2)
+
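+ # Clamp avoids log(0) if a blended variance underflows to zero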
+ logvar = var.clamp_min(1e-12).log()
+ h_blended = torch.cat([mu, logvar], dim=1)
+ h_core = h_blended[
+ :,
+ :,
+ :tile_latent_stride_length,
+ :tile_latent_stride_height,
+ :tile_latent_stride_width,
+ ]
+ result_grid[t_idx][i_idx][j_idx] = h_core
+
+ # Stitch blended tiles to spatially correct places
+ result_times = []
+ for t_idx in range(nT):
+ result_rows = []
+ for i_idx in range(nH):
+ result_row = []
+ for j_idx in range(nW):
+ result_row.append(result_grid[t_idx][i_idx][j_idx])
+ result_rows.append(torch.cat(result_row, dim=4))
+ result_times.append(torch.cat(result_rows, dim=3))
+
+ h_full = torch.cat(result_times, dim=2)[:, :, :latent_length, :latent_height, :latent_width]
+ return h_full
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of video latents using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ B, C, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+ sample_length = num_frames * self.temporal_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_min_length = self.tile_sample_min_length // self.temporal_compression_ratio
+
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_latent_stride_length = self.tile_sample_stride_length // self.temporal_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+ blend_length = self.tile_sample_min_length - self.tile_sample_stride_length
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ times = []
+ for t in range(0, num_frames, tile_latent_stride_length):
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ tile = z[
+ :,
+ :,
+ t : t + tile_latent_min_length,
+ i : i + tile_latent_min_height,
+ j : j + tile_latent_min_width,
+ ]
+ decoded = self._decode_tile(tile)
+ row.append(decoded)
+ rows.append(row)
+ times.append(rows)
+
+ result_times = []
+ for t in range(len(times)):
+ result_rows = []
+ for i in range(len(times[t])):
+ result_row = []
+ for j in range(len(times[t][i])):
+ # Clone the current decoded tile to ensure blending uses an unmodified copy of the tile.
+ tile = times[t][i][j].clone()
+
+ if t > 0:
+ tile = self.blend_t(times[t - 1][i][j], tile, blend_length, power=1)
+ if i > 0:
+ tile = self.blend_v(times[t][i - 1][j], tile, blend_height, power=1)
+ if j > 0:
+ tile = self.blend_h(times[t][i][j - 1], tile, blend_width, power=1)
+
+ result_row.append(
+ tile[
+ :,
+ :,
+ : self.tile_sample_stride_length,
+ : self.tile_sample_stride_height,
+ : self.tile_sample_stride_width,
+ ]
+ )
+
+ result_rows.append(torch.cat(result_row, dim=4))
+ result_times.append(torch.cat(result_rows, dim=3))
+ dec = torch.cat(result_times, dim=2)[:, :, :sample_length, :sample_height, :sample_width]
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = True,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index b51f5d7aec25..c91d7004d3f2 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -830,6 +830,8 @@ def get_3d_rotary_pos_embed(
grid_type: str = "linspace",
max_size: Optional[Tuple[int, int]] = None,
device: Optional[torch.device] = None,
+ center_grid_hw_indices: bool = False,
+ equal_split_ratio: Optional[int] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
@@ -876,10 +878,19 @@ def get_3d_rotary_pos_embed(
else:
raise ValueError("Invalid value passed for `grid_type`.")
- # Compute dimensions for each axis
- dim_t = embed_dim // 4
- dim_h = embed_dim // 8 * 3
- dim_w = embed_dim // 8 * 3
+ if center_grid_hw_indices:
+ # Center the grid height and width indices around zero
+ grid_h = grid_h - grid_h.max() / 2
+ grid_w = grid_w - grid_w.max() / 2
+
+ if equal_split_ratio is None:
+ dim_t = embed_dim // 4
+ dim_h = embed_dim // 8 * 3
+ dim_w = embed_dim // 8 * 3
+ else:
+ dim_t = embed_dim // equal_split_ratio
+ dim_h = embed_dim // equal_split_ratio
+ dim_w = embed_dim // equal_split_ratio
# Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 6b80ea6c82a5..94a80ba8daac 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -30,6 +30,7 @@
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
+ from .transformer_magi1 import Magi1Transformer3DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
diff --git a/src/diffusers/models/transformers/transformer_magi1.py b/src/diffusers/models/transformers/transformer_magi1.py
new file mode 100644
index 000000000000..2040e986be83
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_magi1.py
@@ -0,0 +1,896 @@
+# Copyright 2025 The MAGI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import AttentionBackendName, dispatch_attention_fn
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..cache_utils import CacheMixin
+from ..embeddings import TimestepEmbedding, Timesteps
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin, get_parameter_dtype
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_qkv_projections(attn: "Magi1Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+def range_mod_pytorch(x, c_mapping, gatings):
+ """
+ PyTorch implementation of range_mod_triton. # TODO: Ensure that this implementation is correct and matches the
+ range_mod_triton implementation.
+
+ Inputs:
+ x: (s, b, h). Tensor of input embeddings (images or latent representations of images).
+ c_mapping: (s, b). Tensor mapping each token to its condition index.
+ gatings: (b, denoising_range_num, h). Tensor of condition (gating) embeddings.
+ """
+ s, b, h = x.shape
+
+ # Flatten x and c_mapping to 2D for easier indexing
+ x_flat = x.transpose(0, 1).flatten(0, 1) # (s*b, h)
+ c_mapping_flat = c_mapping.transpose(0, 1).flatten(0, 1) # (s*b,)
+ gatings_flat = gatings.flatten(0, 1) # (b*denoising_range_num, h)
+
+ # Use advanced indexing to select the appropriate gating for each row
+ # c_mapping_flat contains indices into gatings_flat
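+ # i.e. every token is assigned the gating vector of the (sample, denoising range) it belongs to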
+ selected_gatings = gatings_flat[c_mapping_flat] # (s*b, h)
+
+ # Element-wise multiplication
+ y_flat = x_flat * selected_gatings # (s*b, h)
+
+ # Reshape back to original dimensions
+ y = y_flat.reshape(b, s, h).transpose(0, 1) # (s, b, h)
+
+ return y
+
+
+if is_kernels_available():
+ from kernels import use_kernel_forward_from_hub
+
+ range_mod_pytorch = use_kernel_forward_from_hub("range_mod_triton")(range_mod_pytorch)
+
+
+class Magi1AttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("Magi1AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: "Magi1Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+ attention_kwargs = attention_kwargs or {}
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ kv_h = attn.kv_heads if attn.kv_heads is not None else attn.heads
+ key = key.unflatten(2, (kv_h, -1))
+ value = value.unflatten(2, (kv_h, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if rotary_emb is not None and attn.cross_attention_dim_head is None:
+
+ def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
+ # x: [B, S, H, D]
+ x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1) # [B, S, H, D/2]
+ # Broadcast cos/sin to [1, S, 1, D/2]
+ cos = freqs_cos.view(1, -1, 1, freqs_cos.shape[-1])[..., 0::2]
+ sin = freqs_sin.view(1, -1, 1, freqs_sin.shape[-1])[..., 1::2]
+ out = torch.empty_like(x)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(x)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
+
+ kv_heads = kv_h
+ n_rep = attn.heads // kv_heads
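+ # Grouped-query attention: repeat each key/value head n_rep times so K/V match the number of query heads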
+ if n_rep > 1:
+ key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+ value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+ # Use MAGI backend if varlen parameters are provided
+ backend = self._attention_backend
+ if attention_kwargs.get("q_ranges") is not None:
+ backend = AttentionBackendName.MAGI
+
+ out = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ enable_gqa=True,
+ backend=backend,
+ parallel_config=self._parallel_config,
+ attention_kwargs=attention_kwargs,
+ )
+ out = out.flatten(2, 3).type_as(hidden_states)
+ return out
+
+
+class Magi1Attention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = Magi1AttnProcessor
+ _available_processors = [Magi1AttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 24,
+ kv_heads: Optional[int] = None,
+ dim_head: int = 128,
+ eps: float = 1e-6,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ is_cross_attention=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.kv_heads = kv_heads if kv_heads is not None else heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = dim_head * self.kv_heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=False)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=False)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=False)
+ # Note: Output projection is handled in Magi1TransformerBlock to match original architecture
+ # where [self_attn, cross_attn] outputs are concatenated, rearranged, then projected together
+ self.norm_q = FP32LayerNorm(dim_head, eps)
+ self.norm_k = FP32LayerNorm(dim_head, eps)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.is_cross_attention = cross_attention_dim_head is not None
+
+ self.set_processor(processor)
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ out_features, in_features = concatenated_weights.shape
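+ # Build the fused linear on the meta device (no memory allocation) and materialize it directly from the
+ # concatenated weights via assign=True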
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=False)
+ self.to_qkv.load_state_dict({"weight": concatenated_weights}, strict=True, assign=True)
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=False)
+ self.to_kv.load_state_dict({"weight": concatenated_weights}, strict=True, assign=True)
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
+class Magi1TextProjection(nn.Module):
+ """
+ Projects caption embeddings.
+ """
+
+ def __init__(self, in_features, hidden_size, adaln_dim):
+ super().__init__()
+ self.y_proj_xattn = nn.Sequential(nn.Linear(in_features, hidden_size), nn.SiLU())
+ self.y_proj_adaln = nn.Linear(in_features, adaln_dim)
+
+ def forward(self, caption):
+ caption_xattn = self.y_proj_xattn(caption)
+ caption_adaln = self.y_proj_adaln(caption)
+ return caption_xattn, caption_adaln
+
+
+class Magi1TimeTextEmbedding(nn.Module):
+ """
+ Combined time, text embedding module for the MAGI-1 model.
+
+ This module handles the encoding of two types of conditioning inputs:
+ 1. Timestep embeddings for diffusion process control
+ 2. Text embeddings for text-to-video generation
+
+ Args:
+ dim (`int`): Hidden dimension of the transformer model.
+ time_freq_dim (`int`): Dimension for sinusoidal time embeddings.
+ text_embed_dim (`int`): Input dimension of text embeddings.
+ enable_distillation (`bool`, optional): Enable distillation timestep adjustments.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ text_embed_dim: int,
+ enable_distillation: bool = False,
+ ):
+ super().__init__()
+
+ # NOTE: timestep_rescale_factor=1000 to match original implementation (dit_module.py:71)
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=int(dim * 0.25))
+ self.text_embedder = Magi1TextProjection(text_embed_dim, dim, adaln_dim=int(dim * 0.25))
+
+ self.enable_distillation = enable_distillation
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ num_steps: Optional[int] = None,
+ distill_interval: Optional[int] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+
+ time_embedder_dtype = get_parameter_dtype(self.time_embedder)
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ y_xattn, y_adaln = self.text_embedder(encoder_hidden_states)
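+ # y_xattn conditions the cross-attention layers; y_adaln feeds the per-block adaLN gating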
+
+ # Apply distillation logic if enabled
+ if self.enable_distillation and num_steps is not None:
+ distill_dt_scalar = 2
+ if num_steps == 12 and distill_interval is not None:
+ base_chunk_step = 4
+ distill_dt_factor = base_chunk_step / distill_interval * distill_dt_scalar
+ else:
+ distill_dt_factor = num_steps / 4 * distill_dt_scalar
+
+ distill_dt = torch.ones_like(timestep) * distill_dt_factor
+ distill_dt_embed = self.time_embedder(distill_dt)
+ temb = temb + distill_dt_embed
+
+ return temb, y_xattn, y_adaln
+
+
+class Magi1RotaryPosEmbed(nn.Module):
+ """
+ Rotary Position Embedding for MAGI-1 model.
+
+ Args:
+ dim (`int`): The embedding dimension.
+ theta (`float`, *optional*, defaults to 10000.0): Base for the geometric progression.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+
+ num_bands = dim // 8
+ exp = torch.arange(0, num_bands, dtype=torch.float32) / num_bands
+ bands = 1.0 / (theta**exp)
+ self.bands = nn.Parameter(bands)
+
+ def forward(self, hidden_states: torch.Tensor, T_total: int) -> torch.Tensor:
+ # The rotary embeddings are rebuilt on every call so that changes in the target shape are handled
+ device = hidden_states.device
+ batch_size, num_channels, num_frames, post_patch_height, post_patch_width = hidden_states.shape
+ feat_shape = [T_total, post_patch_height, post_patch_width]
+
+ # Calculate the rescale factor for multi-resolution and multi-aspect-ratio training.
+ # The base size [16 * 16] is a predefined size derived from 256x256 data, a VAE compression of (8, 8, 4) and a
+ # patch size of (1, 1, 2); it has no relationship with the actual input, model, or settings.
+ # ref_feat_shape is only used to compute the inner rescale factor, so it can be fractional.
+ rescale_factor = math.sqrt((post_patch_height * post_patch_width) / (16 * 16))
+ ref_feat_shape = [T_total, post_patch_height / rescale_factor, post_patch_width / rescale_factor]
+
+ f = torch.arange(num_frames, device=device, dtype=torch.float32)
+ h = torch.arange(post_patch_height, device=device, dtype=torch.float32)
+ w = torch.arange(post_patch_width, device=device, dtype=torch.float32)
+
+ # Align spatial center (H/2, W/2) to (0,0)
+ h = h - (post_patch_height - 1) / 2
+ w = w - (post_patch_width - 1) / 2
+
+ if ref_feat_shape is not None:
+ # EVA's scheme for resizing rope embeddings (ref shape = pretrain),
+ # aligning to the endpoints, e.g. [0, 1, 2, 3, 4, 5] -> [0, 0.4, 0.8, 1.2, 1.6, 2.0] for a reference length of 3
+ fhw_rescaled = []
+ fhw = [f, h, w]
+ for x, shape, ref_shape in zip(fhw, feat_shape, ref_feat_shape):
+ if shape == 1: # Deal with image input
+ if ref_shape != 1:
+ raise ValueError("ref_feat_shape must be 1 when feat_shape is 1")
+ fhw_rescaled.append(x)
+ else:
+ fhw_rescaled.append(x / (shape - 1) * (ref_shape - 1))
+ f, h, w = fhw_rescaled
+
+ # Create 3D meshgrid & stack into grid tensor: [T, H, W, 3]
+ grid = torch.stack(torch.meshgrid(f, h, w, indexing="ij"), dim=-1)
+ grid = grid.unsqueeze(-1) # [T, H, W, 3, 1]
+
+ # Apply frequency bands
+ freqs = grid * self.bands # [T, H, W, 3, num_bands]
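+ # Per-axis rotary angles for the temporal, height and width components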
+
+ freqs_cos = freqs.cos()
+ freqs_sin = freqs.sin()
+
+ # This would be much nicer as a .numel() call on torch.Size(), but torchscript does not support it
+ num_spatial_dim = 1
+ for x in feat_shape:
+ num_spatial_dim *= x
+
+ freqs_cos = freqs_cos.reshape(num_spatial_dim, -1)
+ freqs_sin = freqs_sin.reshape(num_spatial_dim, -1)
+
+ return freqs_cos, freqs_sin
+
+
+class Magi1TransformerBlock(nn.Module):
+ """
+ A transformer block used in the MAGI-1 model.
+
+ Args:
+ dim (`int`): The number of channels in the input and output.
+ ffn_dim (`int`): The number of channels in the feed-forward layer.
+ num_heads (`int`): The number of attention heads.
+ num_kv_heads (`int`): The number of key-value attention heads.
+ eps (`float`): The epsilon value for layer normalization.
+ gated_linear_unit (`bool`, defaults to `False`):
+ Whether to use gated linear units (SwiGLU) in the feed-forward network.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ num_kv_heads: int,
+ eps: float = 1e-6,
+ gated_linear_unit: bool = False,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps)
+ self.attn1 = Magi1Attention(
+ dim=dim,
+ heads=num_heads,
+ kv_heads=num_kv_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=None,
+ processor=Magi1AttnProcessor(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = Magi1Attention(
+ dim=dim,
+ heads=num_heads,
+ kv_heads=num_kv_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=dim // num_kv_heads,
+ processor=Magi1AttnProcessor(),
+ )
+
+ # Combined output projection for concatenated [self_attn, cross_attn] outputs
+ # Matches original architecture: concat -> rearrange -> project
+ self.attn_proj = nn.Linear(2 * dim, dim, bias=False)
+
+ self.ada_modulate_layer = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ int(dim * 0.25),
+ int(dim * 2),
+ ),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps)
+ self.norm3 = FP32LayerNorm(dim, eps)
+
+ # 3. Feed-forward
+ # Use SwiGLU activation for gated linear units, GELU otherwise
+ activation_fn = "swiglu" if gated_linear_unit else "gelu"
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn=activation_fn, bias=False)
+
+ self.norm4 = FP32LayerNorm(dim, eps)
+ with torch.no_grad():
+ self.norm2.weight.add_(1.0)
+ self.norm4.weight.add_(1.0)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ temb_map: torch.Tensor,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ self_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states.float()
+
+ mixed_qqkv = self.norm1(hidden_states)
+
+ self_attn_output = self.attn1(mixed_qqkv, None, None, rotary_emb, attention_kwargs=self_attention_kwargs)
+
+ cross_attn_output = self.attn2(
+ mixed_qqkv, encoder_hidden_states, encoder_attention_mask, None, attention_kwargs=encoder_attention_kwargs
+ )
+
+ # 3. Concatenate attention outputs
+ # Shape: [b, sq, num_heads * head_dim + num_heads * head_dim] = [b, sq, 2 * dim]
+ hidden_states = torch.concat([self_attn_output, cross_attn_output], dim=2)
+
+ # 4. Rearrange to interleave query groups from self and cross attention
+ # This matches the original: rearrange(attn_out, "sq b (n hn hd) -> sq b (hn n hd)", n=2, hn=num_query_groups)
+ # The interleaving is done at the query group level (not per-head level)
+ # For 48 heads with 8 query groups: each group has 6 heads = 768 dims
+ # Interleaving pattern: [self_g0, cross_g0, self_g1, cross_g1, ..., self_g7, cross_g7]
+ batch_size, seq_len, _ = hidden_states.shape
+ num_query_groups = self.attn1.kv_heads if self.attn1.kv_heads is not None else self.attn1.heads
+ group_dim = self_attn_output.shape[2] // num_query_groups
+ hidden_states = hidden_states.reshape(batch_size, seq_len, 2, num_query_groups, group_dim)
+ hidden_states = hidden_states.permute(0, 1, 3, 2, 4)
+ hidden_states = hidden_states.reshape(batch_size, seq_len, -1)
+
+ # 5. Apply combined projection
+ hidden_states = self.attn_proj(hidden_states)
+
+ gate_output = self.ada_modulate_layer(temb)
+ # Soft-cap the gate values to (-1, 1) with tanh (cap value 1.0)
+ gate_output = torch.tanh(gate_output.float()).to(gate_output.dtype)
+ gate_msa, gate_mlp = gate_output.chunk(2, dim=-1)
+
+ # Residual connection for self-attention
+ original_dtype = hidden_states.dtype
+ hidden_states = range_mod_pytorch(hidden_states.float().transpose(0, 1), temb_map, gate_msa).transpose(0, 1)
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = hidden_states + residual
+ residual = hidden_states
+ hidden_states = hidden_states.to(original_dtype)
+
+ hidden_states = self.norm3(hidden_states)
+ hidden_states = self.ffn(hidden_states)
+
+ # Residual connection for MLP
+ original_dtype = hidden_states.dtype
+ hidden_states = range_mod_pytorch(hidden_states.float().transpose(0, 1), temb_map, gate_mlp).transpose(0, 1)
+ hidden_states = self.norm4(hidden_states)
+ hidden_states = hidden_states + residual
+ hidden_states = hidden_states.to(original_dtype)
+ return hidden_states
+
+
+class Magi1Transformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the Magi1 model.
+
+ This model implements a 3D transformer architecture for video generation with support for text conditioning and
+ optional image conditioning. The model uses rotary position embeddings and adaptive layer normalization for
+ temporal and spatial modeling.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `24`):
+ The number of attention heads in each transformer block.
+ attention_head_dim (`int`, defaults to `128`):
+ The dimension of each attention head.
+ num_kv_heads (`int`, defaults to `8`):
+ The number of key-value heads used for grouped-query attention.
+ in_channels (`int`, defaults to `16`):
+ The number of input channels (from VAE latent space).
+ out_channels (`int`, defaults to `16`):
+ The number of output channels (to VAE latent space).
+ cross_attention_dim (`int`, defaults to `4096`):
+ The dimension of cross-attention (text encoder hidden size).
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `12288`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `34`):
+ The number of transformer layers to use.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ gated_linear_unit (`bool`, defaults to `False`):
+ Whether to use gated linear units (SwiGLU activation) in the feed-forward network. If True, uses SwiGLU
+ activation; if False, uses GELU activation.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "rope"]
+ _no_split_modules = ["Magi1TransformerBlock"]
+ _keep_in_fp32_modules = [
+ "condition_embedder",
+ "scale_shift_table",
+ "norm_out",
+ "norm_q",
+ "norm_k",
+ "patch_embedding",
+ "rope",
+ ]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["Magi1TransformerBlock"]
+ _cp_plan = {
+ "rope": {
+ 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ },
+ "blocks.0": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "blocks.*": {
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 24,
+ attention_head_dim: int = 128,
+ num_kv_heads: int = 8,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ cross_attention_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 12288,
+ num_layers: int = 34,
+ eps: float = 1e-6,
+ x_rescale_factor: int = 1,
+ half_channel_vae: bool = False,
+ enable_distillation: bool = False,
+ gated_linear_unit: bool = False,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size, bias=False)
+ self.rope = Magi1RotaryPosEmbed(inner_dim // num_attention_heads)
+
+ # 2. Condition embeddings
+ self.condition_embedder = Magi1TimeTextEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ text_embed_dim=cross_attention_dim,
+ enable_distillation=enable_distillation,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ Magi1TransformerBlock(
+ inner_dim,
+ ffn_dim,
+ num_attention_heads,
+ num_kv_heads,
+ eps,
+ gated_linear_unit,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size), bias=False)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ denoising_range_num: Optional[int] = None,
+ range_num: Optional[int] = None,
+ slice_point: Optional[int] = 0,
+ kv_range: Optional[Tuple[int, int]] = None,
+ num_steps: Optional[int] = None,
+ distill_interval: Optional[int] = None,
+ extract_prefix_video_feature: Optional[bool] = False,
+ fwd_extra_1st_chunk: Optional[bool] = False,
+ distill_nearly_clean_chunk: Optional[bool] = False,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ """
+ Forward pass of the MAGI-1 transformer.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input tensor of shape `(batch_size, num_channels, num_frames, height, width)`.
+ timestep (`torch.LongTensor`):
+ Timesteps for diffusion process. Shape: `(batch_size, denoising_range_num)`.
+ encoder_hidden_states (`torch.Tensor`):
+ Text embeddings from the text encoder for cross-attention conditioning.
+ encoder_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for text embeddings to handle variable-length sequences.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a dictionary or a tuple.
+ attention_kwargs (`dict`, *optional*):
+ Additional keyword arguments for attention processors (e.g., LoRA scale).
+ denoising_range_num (`int`, *optional*):
+ Number of denoising ranges for autoregressive video generation. Each range represents a chunk of video
+ frames being denoised in parallel.
+ range_num (`int`, *optional*):
+ Total number of ranges in the video generation process.
+ slice_point (`int`, *optional*, defaults to 0):
+ Index indicating how many clean (already generated) frames precede the current denoising chunks. Used
+ for autoregressive context.
+ kv_range (`Tuple[int, int]`, *optional*):
+ Key-value attention ranges for each denoising chunk, defining which frames each chunk can attend to.
+ Required for MAGI-1's autoregressive attention pattern.
+ num_steps (`int`, *optional*):
+ Number of diffusion sampling steps. Used for distillation timestep adjustments.
+ distill_interval (`int`, *optional*):
+ Interval for distillation when using distilled models. Used with `num_steps`.
+ extract_prefix_video_feature (`bool`, *optional*, defaults to `False`):
+ Whether to extract features from prefix (clean) video frames.
+ fwd_extra_1st_chunk (`bool`, *optional*, defaults to `False`):
+ Whether to forward an extra first chunk in the current iteration.
+ distill_nearly_clean_chunk (`bool`, *optional*, defaults to `False`):
+ Whether to apply distillation to nearly clean chunks during generation.
+
+ Returns:
+ `Transformer2DModelOutput` or `tuple`:
+ If `return_dict` is True, returns a `Transformer2DModelOutput` containing the sample. Otherwise,
+ returns a tuple with the sample as the first element.
+
+ Note:
+ MAGI-1 uses an autoregressive video generation approach where video frames are generated in chunks. The
+ `denoising_range_num`, `kv_range`, and related parameters control this autoregressive pattern, allowing
+ each chunk to attend to previously generated (clean) frames while maintaining causal constraints.
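+
+ As an illustration (the values are hypothetical), with a single sample and two chunks of `N` tokens
+ each, `kv_range = torch.tensor([[0, N], [0, 2 * N]])` lets the first chunk attend only to itself
+ while the second chunk also attends to the first.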
+ """
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # `kv_range` is only required by the MAGI varlen attention backend; it may be `None` otherwise.
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+ frame_in_range = post_patch_num_frames // denoising_range_num
+ prev_clean_T = frame_in_range * slice_point
+ T_total = post_patch_num_frames + prev_clean_T
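+ # `T_total` counts the clean prefix frames from previously generated chunks plus the frames in the
+ # current window, so the rotary embeddings are computed over absolute temporal positions and then
+ # cropped back to the current window below.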
+
+ hidden_states = hidden_states * self.config.x_rescale_factor
+
+ if self.config.half_channel_vae:
+ if hidden_states.shape[1] != 16:
+ raise ValueError(
+ "When `config.half_channel_vae` is True, the input `hidden_states` must have 16 channels."
+ )
+ hidden_states = torch.cat([hidden_states, hidden_states], dim=1)
+
+ # Patch & position embedding
+ hidden_states = self.patch_embedding(hidden_states)
+ freqs_cos, freqs_sin = self.rope(hidden_states, T_total)
+ # The shape of freqs_* is (total_seq_length, head_dim). Keep only the last tokens corresponding to current window.
+ keep = post_patch_num_frames * post_patch_height * post_patch_width
+ freqs_cos = freqs_cos[-keep:]
+ freqs_sin = freqs_sin[-keep:]
+ rotary_emb = (freqs_cos, freqs_sin)
+
+ hidden_states = hidden_states.flatten(2).transpose(
+ 1, 2
+ ) # (B, post_patch_num_frames * post_patch_height * post_patch_width, C)
+
+ temb, y_encoder_attention, y_adaln = self.condition_embedder(
+ timestep.flatten(),
+ encoder_hidden_states,
+ num_steps,
+ distill_interval,
+ )
+
+ # Pool AdaLN text conditioning over valid tokens (mask) to get per-batch vector, then broadcast per range
+ if encoder_attention_mask is not None:
+ mask_2d = encoder_attention_mask.squeeze(1).squeeze(1).to(torch.bool) # [B, L]
+ denom = mask_2d.sum(dim=1, keepdim=True).clamp(min=1)
+ y_adaln_pooled = (y_adaln * mask_2d.unsqueeze(-1)).sum(dim=1) / denom
+ else:
+ mask_2d = None
+ y_adaln_pooled = y_adaln.mean(dim=1)
+
+ temb = temb.reshape(batch_size, denoising_range_num, -1) + y_adaln_pooled.unsqueeze(1).expand(
+ -1, denoising_range_num, -1
+ )
+
+ seqlen_per_chunk = (post_patch_num_frames * post_patch_height * post_patch_width) // denoising_range_num
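+ # `temb_map` assigns every token its per-(batch, chunk) modulation index: tokens are laid out
+ # chunk-by-chunk, so repeating each chunk index `seqlen_per_chunk` times yields a (seq_len, batch)
+ # lookup table consumed by the blocks' range modulation.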
+ temb_map = torch.arange(batch_size * denoising_range_num, device=hidden_states.device)
+ temb_map = torch.repeat_interleave(temb_map, seqlen_per_chunk)
+ temb_map = temb_map.reshape(batch_size, -1).transpose(0, 1).contiguous()
+
+ # Build varlen metadata for MAGI backend
+ clip_token_nums = post_patch_height * post_patch_width * frame_in_range
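+ # `clip_token_nums` is the number of tokens in one chunk: frames-per-chunk × patched height × patched width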
+
+ self_attention_kwargs = None
+ if kv_range is not None:
+ cu_seqlens_q = torch.tensor(
+ [0] + ([clip_token_nums] * denoising_range_num * batch_size), dtype=torch.int64, device=hidden_states.device
+ ).cumsum(-1).to(torch.int32)
+ # q_ranges pairs from cu_seqlens_q
+ q_ranges = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1)
+ flat_kv = torch.unique(kv_range, sorted=True)
+ max_seqlen_k = int((flat_kv[-1] - flat_kv[0]).item())
+ self_attention_kwargs = {
+ "q_ranges": q_ranges,
+ "k_ranges": kv_range,
+ "max_seqlen_q": clip_token_nums,
+ "max_seqlen_k": max_seqlen_k,
+ }
+
+ encoder_attention_kwargs = None
+ if mask_2d is not None:
+ y_index = mask_2d.sum(dim=-1).to(torch.int32)
+ cu_seqlens_q = torch.tensor(
+ [0] + ([clip_token_nums] * denoising_range_num * batch_size), dtype=torch.int64, device=hidden_states.device
+ ).cumsum(-1).to(torch.int32)
+ cu_seqlens_k = torch.cat([y_index.new_zeros(1, dtype=torch.int32), y_index.to(torch.int32)]).cumsum(-1)
+ q_ranges = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1)
+ k_ranges = torch.cat([cu_seqlens_k[:-1].unsqueeze(1), cu_seqlens_k[1:].unsqueeze(1)], dim=1)
+ max_seqlen_kv = int(y_index.max().item()) if y_index.numel() > 0 else 0
+ encoder_attention_kwargs = {
+ "q_ranges": q_ranges,
+ "k_ranges": k_ranges,
+ "max_seqlen_q": clip_token_nums,
+ "max_seqlen_k": max_seqlen_kv,
+ "cu_seqlens_q": cu_seqlens_q,
+ "cu_seqlens_k": cu_seqlens_k,
+ }
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ y_encoder_attention,
+ temb,
+ rotary_emb,
+ temb_map,
+ mask_2d,
+ self_attention_kwargs,
+ encoder_attention_kwargs,
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(
+ hidden_states,
+ y_encoder_attention,
+ temb,
+ rotary_emb,
+ temb_map,
+ mask_2d,
+ self_attention_kwargs,
+ encoder_attention_kwargs,
+ )
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index c438caed571f..a3b8936f8fb4 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -28,6 +28,7 @@
"deprecated": [],
"latent_diffusion": [],
"ledits_pp": [],
+ "magi1": [],
"marigold": [],
"pag": [],
"stable_diffusion": [],
@@ -293,6 +294,7 @@
"MarigoldNormalsPipeline",
]
)
+ _import_structure["magi1"] = ["Magi1Pipeline", "Magi1ImageToVideoPipeline", "Magi1VideoToVideoPipeline"]
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["omnigen"] = ["OmniGenPipeline"]
@@ -689,6 +691,7 @@
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
+ from .magi1 import Magi1ImageToVideoPipeline, Magi1Pipeline, Magi1VideoToVideoPipeline
from .marigold import (
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
diff --git a/src/diffusers/pipelines/magi1/__init__.py b/src/diffusers/pipelines/magi1/__init__.py
new file mode 100644
index 000000000000..5fc6735f6357
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/__init__.py
@@ -0,0 +1,53 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_magi1"] = ["Magi1Pipeline"]
+ _import_structure["pipeline_magi1_i2v"] = ["Magi1ImageToVideoPipeline"]
+ _import_structure["pipeline_magi1_v2v"] = ["Magi1VideoToVideoPipeline"]
+ _import_structure["pipeline_output"] = ["Magi1PipelineOutput"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_magi1 import Magi1Pipeline
+ from .pipeline_magi1_i2v import Magi1ImageToVideoPipeline
+ from .pipeline_magi1_v2v import Magi1VideoToVideoPipeline
+ from .pipeline_output import Magi1PipelineOutput
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1.py b/src/diffusers/pipelines/magi1/pipeline_magi1.py
new file mode 100644
index 000000000000..1948c9c7ca39
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/pipeline_magi1.py
@@ -0,0 +1,1169 @@
+# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MAGI-1 T2V Pipeline with Autoregressive Chunked Generation
+#
+# ✅ IMPLEMENTED:
+# - Autoregressive chunked generation (always enabled, matching original MAGI-1)
+# - Window-based scheduling: chunk_width=6, window_size=4
+# - Progressive denoising across overlapping temporal windows
+# - Proper CFG with separate forward passes (diffusers style)
+#
+# ⚠️ CURRENT LIMITATION:
+# - No KV caching: attention is recomputed for previous chunks
+# - This is less efficient than the original but fully functional
+#
+# ⏳ FUTURE OPTIMIZATIONS (when diffusers adds generic KV caching):
+# 1. **KV Cache Management**:
+# - Cache attention keys/values for previously denoised chunks
+# - Reuse cached computations instead of recomputing
+# - Will significantly speed up generation (2-3x faster expected)
+#
+# 2. **Special Token Support** (optional enhancement):
+# - Duration tokens: indicate how many chunks remain to generate
+# - Quality tokens: HQ_TOKEN for high-quality generation
+# - Style tokens: THREE_D_MODEL_TOKEN, TWO_D_ANIME_TOKEN
+# - Motion tokens: STATIC_FIRST_FRAMES_TOKEN, DYNAMIC_FIRST_FRAMES_TOKEN
+#
+# 3. **Streaming Generation**:
+# - Yield clean chunks as they complete (generator pattern)
+# - Enable real-time preview during generation
+#
+# Reference: https://github.com/SandAI/MAGI-1/blob/main/inference/pipeline/video_generate.py
+
+import html
+import re
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import Magi1LoraLoaderMixin
+from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import Magi1PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+def generate_chunk_sequences(chunk_num: int, window_size: int, chunk_offset: int = 0):
+ """
+ Generate chunk scheduling sequences for autoregressive video generation.
+
+ Args:
+ chunk_num: Total number of chunks to generate
+ window_size: Number of chunks to process in each window
+ chunk_offset: Number of clean prefix chunks (for I2V/V2V)
+
+ Returns:
+ A tuple of four lists, each with one entry per stage:
+ ```
+ clip_start: Start index of the chunks to process
+ clip_end: End index of the chunks to process
+ t_start: Start index in the time dimension
+ t_end: End index in the time dimension
+ ```
+
+ Examples:
+ ```
+ chunk_num=8, window_size=4, chunk_offset=0
+ Stage 0: Process chunks [0:1], denoise chunk 0
+ Stage 1: Process chunks [0:2], denoise chunk 1
+ Stage 2: Process chunks [0:3], denoise chunk 2
+ Stage 3: Process chunks [0:4], denoise chunk 3
+ Stage 4: Process chunks [1:5], denoise chunk 4
+ ...
+ ```
+ """
+ start_index = chunk_offset
+ end_index = chunk_num + window_size - 1
+
+ clip_start = [max(chunk_offset, i - window_size + 1) for i in range(start_index, end_index)]
+ clip_end = [min(chunk_num, i + 1) for i in range(start_index, end_index)]
+
+ t_start = [max(0, i - chunk_num + 1) for i in range(start_index, end_index)]
+ t_end = [
+ min(window_size, i - chunk_offset + 1) if i - chunk_offset < window_size else window_size
+ for i in range(start_index, end_index)
+ ]
+
+ return clip_start, clip_end, t_start, t_end
+
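+# Illustrative schedule (comment only, nothing is executed at import time): with chunk_num=3 and
+# window_size=2, `generate_chunk_sequences` returns clip_start=[0, 0, 1, 2] and clip_end=[1, 2, 3, 3],
+# so each chunk enters the sliding window one stage after the previous one and is denoised for
+# `window_size` stages before it leaves the window.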
+
+def load_special_tokens(special_tokens_path: Optional[str] = None) -> Optional[Dict[str, torch.Tensor]]:
+ """
+ Load special conditioning tokens from a numpy (`.npz`) file.
+
+ Args:
+ special_tokens_path: Path to special_tokens.npz file. If None, returns None (no special tokens).
+
+ Returns:
+ Dictionary mapping token names to embeddings, or None if path not provided or file doesn't exist.
+ """
+ if special_tokens_path is None:
+ return None
+
+ try:
+ import os
+
+ import numpy as np
+
+ if not os.path.exists(special_tokens_path):
+ logger.warning(f"Special tokens file not found at {special_tokens_path}, skipping special token loading.")
+ return None
+
+ special_token_data = np.load(special_tokens_path)
+ caption_token = torch.tensor(special_token_data["caption_token"].astype(np.float16))
+ logo_token = torch.tensor(special_token_data["logo_token"].astype(np.float16))
+ other_tokens = special_token_data["other_tokens"]
+
+ tokens = {
+ "CAPTION_TOKEN": caption_token,
+ "LOGO_TOKEN": logo_token,
+ "TRANS_TOKEN": torch.tensor(other_tokens[:1].astype(np.float16)),
+ "HQ_TOKEN": torch.tensor(other_tokens[1:2].astype(np.float16)),
+ "STATIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[2:3].astype(np.float16)),
+ "DYNAMIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[3:4].astype(np.float16)),
+ "BORDERNESS_TOKEN": torch.tensor(other_tokens[4:5].astype(np.float16)),
+ "THREE_D_MODEL_TOKEN": torch.tensor(other_tokens[15:16].astype(np.float16)),
+ "TWO_D_ANIME_TOKEN": torch.tensor(other_tokens[16:17].astype(np.float16)),
+ }
+
+ # Duration tokens (8 total, representing 1-8 chunks remaining)
+ for i in range(8):
+ tokens[f"DURATION_TOKEN_{i + 1}"] = torch.tensor(other_tokens[i + 7 : i + 8].astype(np.float16))
+
+ logger.info(f"Loaded {len(tokens)} special tokens from {special_tokens_path}")
+ return tokens
+ except Exception as e:
+ logger.warning(f"Failed to load special tokens: {e}")
+ return None
+
+
+def prepend_special_tokens(
+ prompt_embeds: torch.Tensor,
+ special_tokens: Optional[Dict[str, torch.Tensor]],
+ use_hq_token: bool = False,
+ use_3d_style: bool = False,
+ use_2d_anime_style: bool = False,
+ use_static_first_frames: bool = False,
+ use_dynamic_first_frames: bool = False,
+ max_sequence_length: int = 800,
+) -> torch.Tensor:
+ """
+ Prepend special conditioning tokens to text embeddings.
+
+ Args:
+ prompt_embeds: Text embeddings [batch, seq_len, hidden_dim]
+ special_tokens: Dictionary of special token embeddings
+ use_hq_token: Whether to add high-quality token
+ use_3d_style: Whether to add 3D model style token
+ use_2d_anime_style: Whether to add 2D anime style token
+ use_static_first_frames: Whether to add static motion token
+ use_dynamic_first_frames: Whether to add dynamic motion token
+ max_sequence_length: Maximum sequence length after prepending
+
+ Returns:
+ Text embeddings with special tokens prepended
+ """
+ if special_tokens is None:
+ return prompt_embeds
+
+ device = prompt_embeds.device
+ dtype = prompt_embeds.dtype
+ batch_size, seq_len, hidden_dim = prompt_embeds.shape
+
+ # Collect tokens to prepend (in order: motion, quality, style)
+ tokens_to_add = []
+ if use_static_first_frames and "STATIC_FIRST_FRAMES_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["STATIC_FIRST_FRAMES_TOKEN"])
+ if use_dynamic_first_frames and "DYNAMIC_FIRST_FRAMES_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["DYNAMIC_FIRST_FRAMES_TOKEN"])
+ if use_hq_token and "HQ_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["HQ_TOKEN"])
+ if use_3d_style and "THREE_D_MODEL_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["THREE_D_MODEL_TOKEN"])
+ if use_2d_anime_style and "TWO_D_ANIME_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["TWO_D_ANIME_TOKEN"])
+
+ # Prepend tokens
+ for token in tokens_to_add:
+ token = token.to(device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1)
+ prompt_embeds = torch.cat([token, prompt_embeds], dim=1)
+
+ # Truncate to max length
+ if prompt_embeds.shape[1] > max_sequence_length:
+ prompt_embeds = prompt_embeds[:, :max_sequence_length, :]
+
+ return prompt_embeds
+
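+# Illustrative usage sketch (the path below is a placeholder; it assumes a `special_tokens.npz` file from
+# a MAGI-1 checkpoint is available locally):
+#
+#     special_tokens = load_special_tokens("path/to/special_tokens.npz")
+#     prompt_embeds = prepend_special_tokens(prompt_embeds, special_tokens, use_hq_token=True)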
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import Magi1Pipeline, AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler
+ >>> from diffusers.utils import export_to_video
+
+ >>> model_id = "SandAI/Magi1-T2V-14B-480P-Diffusers"
+ >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+
+ >>> # IMPORTANT: MAGI-1 requires shift=3.0 for the scheduler (SD3-style time resolution transform)
+ >>> scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3.0)
+
+ >>> pipe = Magi1Pipeline.from_pretrained(model_id, vae=vae, scheduler=scheduler, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen."
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality"
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=720,
+ ... width=1280,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+class Magi1Pipeline(DiffusionPipeline, Magi1LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using Magi1.
+
+ MAGI-1 is a DiT-based video generation model that supports autoregressive chunked generation for long videos.
+
+ **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the
+ original MAGI-1 paper, with support for special conditioning tokens for quality, style, and motion control.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`Magi1Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A flow matching scheduler with Euler discretization, using SD3-style time resolution transform.
+ vae ([`AutoencoderKLMagi1`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: Magi1Transformer3DModel,
+ vae: AutoencoderKLMagi1,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ )
+ self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Special tokens for conditioning (optional)
+ self.special_tokens = None
+
+ def load_special_tokens_from_file(self, special_tokens_path: str):
+ """
+ Load special conditioning tokens from a numpy file.
+
+ Args:
+ special_tokens_path: Path to special_tokens.npz file
+ """
+ self.special_tokens = load_special_tokens(special_tokens_path)
+ if self.special_tokens is not None:
+ logger.info("Special tokens loaded successfully. You can now use quality, style, and motion control.")
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 800,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ # Repeat mask the same way as embeddings and keep size [B*num, L]
+ mask = mask.repeat(1, num_videos_per_prompt)
+ mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device)
+
+ return prompt_embeds, mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 800,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ prompt_mask = None
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ negative_mask = None
+ return prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
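+ # e.g. with the default temporal compression of 4, 81 input frames map to (81 - 1) // 4 + 1 = 21 latent frames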
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 800,
+ use_hq_token: bool = False,
+ use_3d_style: bool = False,
+ use_2d_anime_style: bool = False,
+ use_static_first_frames: bool = False,
+ use_dynamic_first_frames: bool = False,
+ enable_distillation: bool = False,
+ distill_nearly_clean_chunk_threshold: float = 0.3,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the
+ original MAGI-1 paper. The implementation currently works without KV caching (attention is recomputed for
+ previous chunks), which is less efficient than the original but still functional. KV caching optimization will
+ be added when diffusers implements generic caching support for transformers.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated video.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated video.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate videos that are closely
+ linked to the text `prompt`, usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, negative_prompt_embeds will be generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between `"latent"`, `"pt"`, or `"np"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`Magi1PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `800`):
+ The maximum sequence length for the text encoder. Sequences longer than this will be truncated. MAGI-1
+ uses a max length of 800 tokens.
+ use_hq_token (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the high-quality control token to the text embeddings. This token conditions the
+ model to generate higher quality outputs. Requires special tokens to be loaded via
+ `load_special_tokens_from_file`.
+ use_3d_style (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the 3D model style token to the text embeddings. This token conditions the model to
+ generate outputs with 3D modeling aesthetics. Requires special tokens to be loaded.
+ use_2d_anime_style (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the 2D anime style token to the text embeddings. This token conditions the model to
+ generate outputs with 2D anime aesthetics. Requires special tokens to be loaded.
+ use_static_first_frames (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the static first frames token to the text embeddings. This token conditions the
+ model to start the video with minimal motion in the first few frames. Requires special tokens to be
+ loaded.
+ use_dynamic_first_frames (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the dynamic first frames token to the text embeddings. This token conditions the
+ model to start the video with significant motion in the first few frames. Requires special tokens to be
+ loaded.
+ enable_distillation (`bool`, *optional*, defaults to `False`):
+ Whether to enable distillation mode. In distillation mode, the model uses modified timestep embeddings
+ to support distilled (faster) inference. This requires a distilled model checkpoint.
+ distill_nearly_clean_chunk_threshold (`float`, *optional*, defaults to `0.3`):
+ Threshold for identifying nearly-clean chunks in distillation mode. Chunks whose normalized timestep
+ (timestep / num_train_timesteps) falls below this threshold are treated as nearly clean and processed
+ differently. Only used when `enable_distillation=True`.
+
+ Examples:
+
+ Returns:
+ [`~Magi1PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated videos.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=self.text_encoder.dtype,
+ )
+
+ # 3.5. Prepend special tokens if requested
+ if self.special_tokens is not None and any(
+ [use_hq_token, use_3d_style, use_2d_anime_style, use_static_first_frames, use_dynamic_first_frames]
+ ):
+ prompt_embeds = prepend_special_tokens(
+ prompt_embeds=prompt_embeds,
+ special_tokens=self.special_tokens,
+ use_hq_token=use_hq_token,
+ use_3d_style=use_3d_style,
+ use_2d_anime_style=use_2d_anime_style,
+ use_static_first_frames=use_static_first_frames,
+ use_dynamic_first_frames=use_dynamic_first_frames,
+ max_sequence_length=max_sequence_length,
+ )
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = prepend_special_tokens(
+ prompt_embeds=negative_prompt_embeds,
+ special_tokens=self.special_tokens,
+ use_hq_token=use_hq_token,
+ use_3d_style=use_3d_style,
+ use_2d_anime_style=use_2d_anime_style,
+ use_static_first_frames=use_static_first_frames,
+ use_dynamic_first_frames=use_dynamic_first_frames,
+ max_sequence_length=max_sequence_length,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop (autoregressive chunked generation)
+ # MAGI-1 always uses autoregressive generation with chunk_width=6 and window_size=4
+ # Note: num_warmup_steps is calculated for compatibility but not used in progress bar logic
+ # because autoregressive generation has a different iteration structure (stages × steps)
+ # For FlowMatchEulerDiscreteScheduler (order=1), this doesn't affect the results
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ # Autoregressive chunked generation parameters
+ chunk_width = 6 # Original MAGI-1 default
+ window_size = 4 # Original MAGI-1 default
+
+ num_latent_frames = latents.shape[2]
+ num_chunks = (num_latent_frames + chunk_width - 1) // chunk_width
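+ # e.g. 21 latent frames with chunk_width=6 give (21 + 5) // 6 = 4 chunks; the final chunk may be shorter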
+
+ # Calculate chunk scheduling: which chunks to process at each stage
+ clip_start, clip_end, t_start, t_end = generate_chunk_sequences(num_chunks, window_size, chunk_offset=0)
+ num_stages = len(clip_start)
+
+ # Number of denoising steps per stage
+ denoise_step_per_stage = len(timesteps) // window_size
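+ # Each chunk stays in the sliding window for `window_size` stages, so running
+ # `len(timesteps) // window_size` steps per stage gives every chunk (approximately) the full
+ # `num_inference_steps` worth of denoising by the time it leaves the window.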
+
+ # Track how many times each chunk has been denoised
+ chunk_denoise_count = {i: 0 for i in range(num_chunks)}
+
+ with self.progress_bar(total=num_stages * denoise_step_per_stage) as progress_bar:
+ for stage_idx in range(num_stages):
+ # Determine which chunks to process in this stage
+ chunk_start_idx = clip_start[stage_idx]
+ chunk_end_idx = clip_end[stage_idx]
+ t_start_idx = t_start[stage_idx]
+ t_end_idx = t_end[stage_idx]
+
+ # Extract chunk range in latent space
+ latent_start = chunk_start_idx * chunk_width
+ latent_end = min(chunk_end_idx * chunk_width, num_latent_frames)
+
+ # Number of chunks in current window
+ num_chunks_in_window = chunk_end_idx - chunk_start_idx
+
+ # Prepare per-chunk conditioning with duration/borderness tokens
+ # Duration tokens indicate how many chunks remain in the video
+ # Borderness tokens condition on chunk boundaries
+ chunk_prompt_embeds_list = []
+ chunk_negative_prompt_embeds_list = []
+ chunk_prompt_masks_list = []
+ chunk_negative_masks_list = []
+
+ if self.special_tokens is not None:
+ # Prepare per-chunk embeddings with duration tokens
+ # Each chunk gets a different duration token based on chunks remaining
+ for i, chunk_idx in enumerate(range(chunk_start_idx, chunk_end_idx)):
+ chunks_remaining = num_chunks - chunk_idx - 1
+ # Duration token ranges from 1-8 chunks
+ duration_idx = min(chunks_remaining, 7) + 1
+
+ # Add duration and borderness tokens for this chunk
+ token_embeds = prompt_embeds.clone()
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"]
+ duration_token = duration_token.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+ duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1)
+ token_embeds = torch.cat([duration_token, token_embeds], dim=1)
+
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ borderness_token = self.special_tokens["BORDERNESS_TOKEN"]
+ borderness_token = borderness_token.to(
+ device=prompt_embeds.device, dtype=prompt_embeds.dtype
+ )
+ borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1)
+ token_embeds = torch.cat([borderness_token, token_embeds], dim=1)
+
+ # Build per-chunk mask by prepending ones for each added token and truncating
+ token_mask = prompt_mask
+ add_count = 0
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ add_count += 1
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ add_count += 1
+ if add_count > 0:
+ prepend = torch.ones(
+ token_mask.shape[0], add_count, dtype=token_mask.dtype, device=token_mask.device
+ )
+ token_mask = torch.cat([prepend, token_mask], dim=1)
+ if token_embeds.shape[1] > max_sequence_length:
+ token_embeds = token_embeds[:, :max_sequence_length, :]
+ token_mask = token_mask[:, :max_sequence_length]
+
+ chunk_prompt_embeds_list.append(token_embeds)
+ chunk_prompt_masks_list.append(token_mask)
+
+ # Same for negative prompts
+ if self.do_classifier_free_guidance:
+ neg_token_embeds = negative_prompt_embeds.clone()
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"]
+ duration_token = duration_token.to(
+ device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
+ )
+ duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1)
+ neg_token_embeds = torch.cat([duration_token, neg_token_embeds], dim=1)
+
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ borderness_token = self.special_tokens["BORDERNESS_TOKEN"]
+ borderness_token = borderness_token.to(
+ device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
+ )
+ borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1)
+ neg_token_embeds = torch.cat([borderness_token, neg_token_embeds], dim=1)
+
+ # Build negative per-chunk mask
+ neg_mask = negative_mask
+ add_count = 0
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ add_count += 1
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ add_count += 1
+ if add_count > 0:
+ prepend = torch.ones(
+ neg_mask.shape[0], add_count, dtype=neg_mask.dtype, device=neg_mask.device
+ )
+ neg_mask = torch.cat([prepend, neg_mask], dim=1)
+ if neg_token_embeds.shape[1] > max_sequence_length:
+ neg_token_embeds = neg_token_embeds[:, :max_sequence_length, :]
+ neg_mask = neg_mask[:, :max_sequence_length]
+
+ chunk_negative_prompt_embeds_list.append(neg_token_embeds)
+ chunk_negative_masks_list.append(neg_mask)
+
+ # Denoise this chunk range for denoise_step_per_stage steps
+ for denoise_idx in range(denoise_step_per_stage):
+ if self.interrupt:
+ break
+
+ # Calculate timestep index for each chunk in the current window
+ # Chunks at different stages get different timesteps based on their denoise progress
+ # Original MAGI-1: get_timestep() and get_denoise_step_of_each_chunk()
+ timestep_indices = []
+ for offset in range(num_chunks_in_window):
+ # Map offset within window to time index
+ t_idx_within_window = t_start_idx + offset
+ t_idx = t_idx_within_window * denoise_step_per_stage + denoise_idx
+ timestep_indices.append(t_idx)
+
+ # Reverse so the ordering matches the latent chunks: earlier chunks in the window are more denoised
+ # (larger index into `timesteps`), later chunks are noisier (smaller index, larger timestep value)
+ timestep_indices.reverse()
+
+ # Clamp indices to valid range
+ timestep_indices = [min(idx, len(timesteps) - 1) for idx in timestep_indices]
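+ # e.g. with denoise_step_per_stage=12 and a full window at denoise_idx=0, the indices are
+ # [36, 24, 12, 0]: the oldest chunk in the window (first) is the most denoised, while the newest
+ # chunk (last) is still at the noisiest timesteps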
+
+ # Get actual timesteps
+ current_timesteps = timesteps[timestep_indices]
+
+ # Create per-chunk timestep tensor: [batch_size, num_chunks_in_window]
+ # Each chunk gets its own timestep based on how many times it's been denoised
+ timestep_per_chunk = current_timesteps.unsqueeze(0).expand(batch_size, -1)
+
+ # Store first timestep for progress tracking
+ self._current_timestep = current_timesteps[0]
+
+ # Extract chunk
+ latent_chunk = latents[:, :, latent_start:latent_end].to(transformer_dtype)
+
+ # Prepare distillation parameters if enabled
+ num_steps = None
+ distill_interval = None
+ distill_nearly_clean_chunk = None
+
+ if enable_distillation:
+ # Distillation mode uses modified timestep embeddings for faster inference
+ # The interval represents the step size in the distilled schedule
+ num_steps = num_inference_steps
+ distill_interval = len(timesteps) / num_inference_steps
+
+ # Determine if chunks are nearly clean (low noise) based on their timesteps
+ # Check the first chunk's timestep (after reversing, this is the most-denoised chunk in the window)
+ # Original: checks t[0, int(fwd_extra_1st_chunk)].item()
+ nearly_clean_chunk_t = current_timesteps[0].item() / self.scheduler.config.num_train_timesteps
+ distill_nearly_clean_chunk = nearly_clean_chunk_t < distill_nearly_clean_chunk_threshold
+
+ # Prepare per-chunk embeddings
+ # The transformer expects embeddings in shape [batch_size * num_chunks_in_window, seq_len, hidden_dim]
+ # Each chunk gets its own embedding with appropriate duration/borderness tokens
+ if chunk_prompt_embeds_list:
+ # Stack per-chunk embeddings: [num_chunks_in_window, batch_size, seq_len, hidden_dim]
+ chunk_prompt_embeds = torch.stack(chunk_prompt_embeds_list, dim=0)
+ # Reshape to [batch_size * num_chunks_in_window, seq_len, hidden_dim]
+ chunk_prompt_embeds = chunk_prompt_embeds.transpose(0, 1).flatten(0, 1)
+
+ if chunk_negative_prompt_embeds_list:
+ chunk_negative_prompt_embeds = torch.stack(chunk_negative_prompt_embeds_list, dim=0)
+ chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.transpose(0, 1).flatten(0, 1)
+ else:
+ chunk_negative_prompt_embeds = None
+ else:
+ # Fallback: repeat shared embeddings for each chunk
+ chunk_prompt_embeds = prompt_embeds.unsqueeze(1).repeat(1, num_chunks_in_window, 1, 1)
+ chunk_prompt_embeds = chunk_prompt_embeds.flatten(0, 1)
+ prompt_mask_rep = prompt_mask.unsqueeze(1).repeat(1, num_chunks_in_window, 1)
+ prompt_mask_rep = prompt_mask_rep.flatten(0, 1)
+
+ if negative_prompt_embeds is not None:
+ chunk_negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1).repeat(
+ 1, num_chunks_in_window, 1, 1
+ )
+ chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.flatten(0, 1)
+ negative_mask_rep = negative_mask.unsqueeze(1).repeat(1, num_chunks_in_window, 1)
+ negative_mask_rep = negative_mask_rep.flatten(0, 1)
+ else:
+ chunk_negative_prompt_embeds = None
+ negative_mask_rep = None
+
+ # Create encoder attention mask(s) for per-chunk embeddings from tokenizer masks
+ if chunk_prompt_masks_list:
+ chunk_prompt_masks = torch.stack(chunk_prompt_masks_list, dim=0).transpose(0, 1).flatten(0, 1)
+ else:
+ chunk_prompt_masks = prompt_mask_rep
+ encoder_attention_mask = chunk_prompt_masks.to(dtype=chunk_prompt_embeds.dtype).view(
+ batch_size * num_chunks_in_window, 1, 1, -1
+ )
+ if chunk_negative_masks_list:
+ chunk_negative_masks = torch.stack(chunk_negative_masks_list, dim=0).transpose(0, 1).flatten(0, 1)
+ encoder_attention_mask_neg = chunk_negative_masks.to(dtype=chunk_prompt_embeds.dtype).view(
+ batch_size * num_chunks_in_window, 1, 1, -1
+ )
+ else:
+ encoder_attention_mask_neg = (
+ negative_mask_rep.to(dtype=chunk_prompt_embeds.dtype).view(
+ batch_size * num_chunks_in_window, 1, 1, -1
+ )
+ if negative_mask_rep is not None
+ else encoder_attention_mask
+ )
+
+ # Generate KV range for autoregressive attention
+ # Each chunk can attend to itself and all previous chunks in the sequence
+ # Shape: [batch_size * num_chunks_in_window, 2] where each row is [start_token_idx, end_token_idx]
+ chunk_token_nums = (
+ (latent_chunk.shape[2] // num_chunks_in_window) # frames per chunk
+ * (latent_chunk.shape[3] // self.transformer.config.patch_size[1]) # height tokens
+ * (latent_chunk.shape[4] // self.transformer.config.patch_size[2]) # width tokens
+ )
+ kv_range = []
+ for b in range(batch_size):
+ # Calculate proper batch offset based on total number of chunks in video
+ batch_offset = b * num_chunks
+ for c in range(num_chunks_in_window):
+ # This chunk can attend from the start of its batch up to its own end
+ chunk_global_idx = chunk_start_idx + c
+ k_start = batch_offset * chunk_token_nums
+ k_end = (batch_offset + chunk_global_idx + 1) * chunk_token_nums
+ kv_range.append([k_start, k_end])
+ kv_range = torch.tensor(kv_range, dtype=torch.int32, device=device)
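+ # e.g. for a single sample with chunks [2, 3, 4] in the window and N = chunk_token_nums, the rows are
+ # [0, 3*N], [0, 4*N], [0, 5*N]: every chunk attends to all chunks up to and including itself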
+
+ # Predict noise (conditional)
+ noise_pred = self.transformer(
+ hidden_states=latent_chunk,
+ timestep=timestep_per_chunk,
+ encoder_hidden_states=chunk_prompt_embeds,
+ encoder_attention_mask=encoder_attention_mask,
+ attention_kwargs=attention_kwargs,
+ denoising_range_num=num_chunks_in_window,
+ range_num=chunk_end_idx,
+ slice_point=chunk_start_idx,
+ kv_range=kv_range,
+ num_steps=num_steps,
+ distill_interval=distill_interval,
+ distill_nearly_clean_chunk=distill_nearly_clean_chunk,
+ return_dict=False,
+ )[0]
+
+ # Classifier-free guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_chunk,
+ timestep=timestep_per_chunk,
+ encoder_hidden_states=chunk_negative_prompt_embeds,
+ encoder_attention_mask=encoder_attention_mask_neg,
+ attention_kwargs=attention_kwargs,
+ denoising_range_num=num_chunks_in_window,
+ range_num=chunk_end_idx,
+ slice_point=chunk_start_idx,
+ kv_range=kv_range,
+ num_steps=num_steps,
+ distill_interval=distill_interval,
+ distill_nearly_clean_chunk=distill_nearly_clean_chunk,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+ # Update latent chunk using scheduler step
+ # FlowMatchEulerDiscreteScheduler implements: x_new = x_old + dt * velocity
+ # For per-chunk timesteps, we use manual integration since scheduler doesn't support
+ # different timesteps for different chunks in a single call
+
+ # Calculate dt for each chunk
+ # Get next timesteps for integration
+ next_timestep_indices = [min(idx + 1, len(timesteps) - 1) for idx in timestep_indices]
+ next_timesteps = timesteps[next_timestep_indices]
+
+ # Convert to sigmas (time in [0, 1] range)
+ current_sigmas = current_timesteps / self.scheduler.config.num_train_timesteps
+ next_sigmas = next_timesteps / self.scheduler.config.num_train_timesteps
+
+ # Reshape for per-chunk application: [batch_size, num_chunks_in_window, 1, 1, 1]
+ dt = (next_sigmas - current_sigmas).view(batch_size, num_chunks_in_window, 1, 1, 1)
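+                    # dt is negative (or zero at the last index) because the scheduler's timesteps decrease,
+                    # so this Euler step moves each chunk from noise toward data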
+
+ # Reshape latents and velocity to separate chunks: [batch_size, channels, chunks, frames_per_chunk, h, w]
+ B, C, T, H, W = latent_chunk.shape
+ frames_per_chunk = T // num_chunks_in_window
+ latent_chunk = latent_chunk.view(B, C, num_chunks_in_window, frames_per_chunk, H, W)
+ noise_pred = noise_pred.view(B, C, num_chunks_in_window, frames_per_chunk, H, W)
+
+ # Apply Euler integration per chunk: x_new = x_old + dt * velocity
+ latent_chunk = latent_chunk + dt * noise_pred
+
+ # Reshape back to original shape
+ latent_chunk = latent_chunk.view(B, C, T, H, W)
+
+ # Write back to full latents
+ latents[:, :, latent_start:latent_end] = latent_chunk
+
+ # Update chunk denoise counts
+ for chunk_idx in range(chunk_start_idx, chunk_end_idx):
+ chunk_denoise_count[chunk_idx] += 1
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(
+ self,
+ stage_idx * denoise_step_per_stage + denoise_idx,
+ current_timesteps[0],
+ callback_kwargs,
+ )
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
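+            # latents_std holds the reciprocal of the per-channel std, so dividing by it rescales the latents
+            # back to the VAE's native range before adding the mean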
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return Magi1PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_i2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_i2v.py
new file mode 100644
index 000000000000..e875d15a4948
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/pipeline_magi1_i2v.py
@@ -0,0 +1,1530 @@
+# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MAGI-1 I2V Pipeline with Autoregressive Chunked Generation
+#
+# ✅ IMPLEMENTED:
+# - Image-to-Video generation using prefix video conditioning
+# - Autoregressive chunked generation (always enabled, matching original MAGI-1)
+# - Window-based scheduling: chunk_width=6, window_size=4
+# - Progressive denoising across overlapping temporal windows
+# - Proper CFG with separate forward passes (diffusers style)
+# - Input image encoding to VAE latent as clean prefix chunk
+#
+# ⚠️ CURRENT LIMITATION:
+# - No KV caching: attention is recomputed for previous chunks
+# - This is less efficient than the original but fully functional
+#
+# ⏳ FUTURE OPTIMIZATIONS (when diffusers adds generic KV caching):
+# 1. **KV Cache Management**:
+# - Cache attention keys/values for previously denoised chunks
+# - Reuse cached computations instead of recomputing
+# - Will significantly speed up generation (2-3x faster expected)
+#
+# 2. **Special Token Support** (optional enhancement):
+# - Duration tokens: indicate how many chunks remain to generate
+# - Quality tokens: HQ_TOKEN for high-quality generation
+# - Style tokens: THREE_D_MODEL_TOKEN, TWO_D_ANIME_TOKEN
+# - Motion tokens: STATIC_FIRST_FRAMES_TOKEN, DYNAMIC_FIRST_FRAMES_TOKEN
+#
+# 3. **Streaming Generation**:
+# - Yield clean chunks as they complete (generator pattern)
+# - Enable real-time preview during generation
+#
+# Reference: https://github.com/SandAI/MAGI-1/blob/main/inference/pipeline/video_generate.py
+
+import html
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL.Image
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import Magi1LoraLoaderMixin
+from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import Magi1PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def generate_chunk_sequences(chunk_num: int, window_size: int, chunk_offset: int = 0):
+ """
+ Generate chunk scheduling sequences for autoregressive video generation.
+
+ Args:
+ chunk_num: Total number of chunks to generate
+ window_size: Number of chunks to process in each window
+ chunk_offset: Number of clean prefix chunks (for I2V/V2V)
+
+    Returns:
+        A tuple of four lists, one entry per stage:
+
+        ```
+        clip_start: Start index of the chunks to process
+        clip_end: End index of the chunks to process
+        t_start: Start index in the time dimension
+        t_end: End index in the time dimension
+        ```
+
+ Examples:
+ ```
+ chunk_num=8, window_size=4, chunk_offset=0
+ Stage 0: Process chunks [0:1], denoise chunk 0
+ Stage 1: Process chunks [0:2], denoise chunk 1
+ Stage 2: Process chunks [0:3], denoise chunk 2
+ Stage 3: Process chunks [0:4], denoise chunk 3
+ Stage 4: Process chunks [1:5], denoise chunk 4
+ ...
+ ```
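+
+    For these inputs the function returns:
+
+    ```
+    clip_start = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7]
+    clip_end   = [1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8]
+    t_start    = [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3]
+    t_end      = [1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4]
+    ```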
+ """
+ start_index = chunk_offset
+ end_index = chunk_num + window_size - 1
+
+ clip_start = [max(chunk_offset, i - window_size + 1) for i in range(start_index, end_index)]
+ clip_end = [min(chunk_num, i + 1) for i in range(start_index, end_index)]
+
+ t_start = [max(0, i - chunk_num + 1) for i in range(start_index, end_index)]
+ t_end = [
+ min(window_size, i - chunk_offset + 1) if i - chunk_offset < window_size else window_size
+ for i in range(start_index, end_index)
+ ]
+
+ return clip_start, clip_end, t_start, t_end
+
+
+def load_special_tokens(special_tokens_path: Optional[str] = None) -> Optional[Dict[str, torch.Tensor]]:
+ """
+ Load special conditioning tokens from numpy file.
+
+ Args:
+ special_tokens_path: Path to special_tokens.npz file. If None, returns None (no special tokens).
+
+ Returns:
+ Dictionary mapping token names to embeddings, or None if path not provided or file doesn't exist.
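+
+    Example (the path below is illustrative):
+
+    ```
+    tokens = load_special_tokens("path/to/special_tokens.npz")
+    if tokens is not None:
+        hq_token = tokens["HQ_TOKEN"]  # embedding tensor that can be prepended to the prompt embeddings
+    ```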
+ """
+ if special_tokens_path is None:
+ return None
+
+ try:
+ import os
+
+ import numpy as np
+
+ if not os.path.exists(special_tokens_path):
+ logger.warning(f"Special tokens file not found at {special_tokens_path}, skipping special token loading.")
+ return None
+
+ special_token_data = np.load(special_tokens_path)
+ caption_token = torch.tensor(special_token_data["caption_token"].astype(np.float16))
+ logo_token = torch.tensor(special_token_data["logo_token"].astype(np.float16))
+ other_tokens = special_token_data["other_tokens"]
+
+ tokens = {
+ "CAPTION_TOKEN": caption_token,
+ "LOGO_TOKEN": logo_token,
+ "TRANS_TOKEN": torch.tensor(other_tokens[:1].astype(np.float16)),
+ "HQ_TOKEN": torch.tensor(other_tokens[1:2].astype(np.float16)),
+ "STATIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[2:3].astype(np.float16)),
+ "DYNAMIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[3:4].astype(np.float16)),
+ "BORDERNESS_TOKEN": torch.tensor(other_tokens[4:5].astype(np.float16)),
+ "THREE_D_MODEL_TOKEN": torch.tensor(other_tokens[15:16].astype(np.float16)),
+ "TWO_D_ANIME_TOKEN": torch.tensor(other_tokens[16:17].astype(np.float16)),
+ }
+
+ # Duration tokens (8 total, representing 1-8 chunks remaining)
+ for i in range(8):
+ tokens[f"DURATION_TOKEN_{i + 1}"] = torch.tensor(other_tokens[i + 7 : i + 8].astype(np.float16))
+
+ logger.info(f"Loaded {len(tokens)} special tokens from {special_tokens_path}")
+ return tokens
+ except Exception as e:
+ logger.warning(f"Failed to load special tokens: {e}")
+ return None
+
+
+def prepare_i2v_embeddings(
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: Optional[torch.Tensor],
+ num_chunks: int,
+ clean_chunk_num: int,
+ max_sequence_length: int = 800,
+ prompt_mask: Optional[torch.Tensor] = None,
+ negative_mask: Optional[torch.Tensor] = None,
+) -> Tuple[
+ torch.Tensor,
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+]:
+ """
+ Prepare per-chunk text embeddings for I2V generation.
+
+ In I2V, clean prefix chunks (from the input image) use null embeddings, while chunks to be denoised use the actual
+ text embeddings.
+
+    Args:
+        prompt_embeds: Text embeddings [batch_size, seq_len, hidden_dim]
+        negative_prompt_embeds: Negative text embeddings (optional)
+        num_chunks: Total number of chunks
+        clean_chunk_num: Number of clean prefix chunks (typically 1 for I2V single image)
+        max_sequence_length: Maximum sequence length
+        prompt_mask: Attention mask for the prompt embeddings [batch_size, seq_len] (optional)
+        negative_mask: Attention mask for the negative prompt embeddings [batch_size, seq_len] (optional)
+
+    Returns:
+        - prompt_embeds_per_chunk: [B, num_chunks, L, D]
+        - negative_prompt_embeds_per_chunk: [B, num_chunks, L, D] or None
+        - prompt_masks_per_chunk: [B, num_chunks, L] or None
+        - negative_masks_per_chunk: [B, num_chunks, L] or None
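+
+        For example, with num_chunks=4 and clean_chunk_num=1, chunk 0 receives all-zero (null) embeddings and
+        chunks 1-3 each receive a copy of `prompt_embeds`, so `prompt_embeds_per_chunk` has shape [B, 4, L, D].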
+ """
+ batch_size = prompt_embeds.shape[0]
+ seq_len = prompt_embeds.shape[1]
+ hidden_dim = prompt_embeds.shape[2]
+ device = prompt_embeds.device
+ dtype = prompt_embeds.dtype
+
+ # Number of chunks that need denoising
+ denoise_chunk_num = num_chunks - clean_chunk_num
+
+ # Create null embeddings (zeros) for clean chunks
+ null_embeds = torch.zeros(batch_size, 1, seq_len, hidden_dim, device=device, dtype=dtype)
+
+ # Expand prompt embeddings for denoise chunks
+ # Shape: [batch_size, denoise_chunk_num, seq_len, hidden_dim]
+ denoise_embeds = prompt_embeds.unsqueeze(1).repeat(1, denoise_chunk_num, 1, 1)
+
+ # Concatenate: [null_embeds for clean chunks] + [text_embeds for denoise chunks]
+ # Shape: [batch_size, num_chunks, seq_len, hidden_dim]
+ if clean_chunk_num > 0:
+ null_embeds_expanded = null_embeds.repeat(1, clean_chunk_num, 1, 1)
+ prompt_embeds_per_chunk = torch.cat([null_embeds_expanded, denoise_embeds], dim=1)
+ else:
+ prompt_embeds_per_chunk = denoise_embeds
+
+ # Build masks per chunk (zeros for clean chunks, prompt_mask for denoise chunks)
+ prompt_masks_per_chunk = None
+ negative_masks_per_chunk = None
+
+ if prompt_mask is not None:
+ denoise_masks = prompt_mask.unsqueeze(1).repeat(1, denoise_chunk_num, 1)
+ if clean_chunk_num > 0:
+            null_masks = torch.zeros(
+                prompt_mask.shape[0], clean_chunk_num, prompt_mask.shape[1], device=device, dtype=prompt_mask.dtype
+            )
+ prompt_masks_per_chunk = torch.cat([null_masks, denoise_masks], dim=1)
+ else:
+ prompt_masks_per_chunk = denoise_masks
+
+ # Same for negative embeddings
+ if negative_prompt_embeds is not None:
+ denoise_neg_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, denoise_chunk_num, 1, 1)
+ if clean_chunk_num > 0:
+ null_neg_embeds_expanded = null_embeds.repeat(1, clean_chunk_num, 1, 1)
+ negative_prompt_embeds_per_chunk = torch.cat([null_neg_embeds_expanded, denoise_neg_embeds], dim=1)
+ else:
+ negative_prompt_embeds_per_chunk = denoise_neg_embeds
+ else:
+ negative_prompt_embeds_per_chunk = None
+
+ if negative_mask is not None:
+ denoise_neg_masks = negative_mask.unsqueeze(1).repeat(1, denoise_chunk_num, 1)
+ if clean_chunk_num > 0:
+            null_neg_masks = torch.zeros(
+                negative_mask.shape[0],
+                clean_chunk_num,
+                negative_mask.shape[1],
+                device=device,
+                dtype=negative_mask.dtype,
+            )
+ negative_masks_per_chunk = torch.cat([null_neg_masks, denoise_neg_masks], dim=1)
+ else:
+ negative_masks_per_chunk = denoise_neg_masks
+
+ return (
+ prompt_embeds_per_chunk,
+ negative_prompt_embeds_per_chunk,
+ prompt_masks_per_chunk,
+ negative_masks_per_chunk,
+ )
+
+
+def prepend_special_tokens(
+ prompt_embeds: torch.Tensor,
+ special_tokens: Optional[Dict[str, torch.Tensor]],
+ use_hq_token: bool = False,
+ use_3d_style: bool = False,
+ use_2d_anime_style: bool = False,
+ use_static_first_frames: bool = False,
+ use_dynamic_first_frames: bool = False,
+ max_sequence_length: int = 800,
+) -> torch.Tensor:
+ """
+ Prepend special conditioning tokens to text embeddings.
+
+ Args:
+ prompt_embeds: Text embeddings [batch, seq_len, hidden_dim]
+ special_tokens: Dictionary of special token embeddings
+ use_hq_token: Whether to add high-quality token
+ use_3d_style: Whether to add 3D model style token
+ use_2d_anime_style: Whether to add 2D anime style token
+ use_static_first_frames: Whether to add static motion token
+ use_dynamic_first_frames: Whether to add dynamic motion token
+ max_sequence_length: Maximum sequence length after prepending
+
+ Returns:
+ Text embeddings with special tokens prepended
+ """
+ if special_tokens is None:
+ return prompt_embeds
+
+ device = prompt_embeds.device
+ dtype = prompt_embeds.dtype
+ batch_size, seq_len, hidden_dim = prompt_embeds.shape
+
+ # Collect tokens to prepend (in order: motion, quality, style)
+ tokens_to_add = []
+ if use_static_first_frames and "STATIC_FIRST_FRAMES_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["STATIC_FIRST_FRAMES_TOKEN"])
+ if use_dynamic_first_frames and "DYNAMIC_FIRST_FRAMES_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["DYNAMIC_FIRST_FRAMES_TOKEN"])
+ if use_hq_token and "HQ_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["HQ_TOKEN"])
+ if use_3d_style and "THREE_D_MODEL_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["THREE_D_MODEL_TOKEN"])
+ if use_2d_anime_style and "TWO_D_ANIME_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["TWO_D_ANIME_TOKEN"])
+
+ # Prepend tokens
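+    # Each torch.cat call places the current token at position 0, so the final front-to-back order is the reverse
+    # of tokens_to_add (style tokens first, then quality, then motion)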
+ for token in tokens_to_add:
+ token = token.to(device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1)
+ prompt_embeds = torch.cat([token, prompt_embeds], dim=1)
+
+ # Truncate to max length
+ if prompt_embeds.shape[1] > max_sequence_length:
+ prompt_embeds = prompt_embeds[:, :max_sequence_length, :]
+
+ return prompt_embeds
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import Magi1ImageToVideoPipeline, AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> model_id = "SandAI/Magi1-I2V-14B-480P-Diffusers"
+ >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+
+ >>> # IMPORTANT: MAGI-1 requires shift=3.0 for the scheduler (SD3-style time resolution transform)
+ >>> scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3.0)
+
+        >>> pipe = Magi1ImageToVideoPipeline.from_pretrained(
+        ...     model_id, vae=vae, scheduler=scheduler, torch_dtype=torch.bfloat16
+        ... )
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... )
+ >>> prompt = (
+ ... "An astronaut walking on the moon's surface, with the Earth visible in the background. "
+ ... "The astronaut moves slowly in a low-gravity environment."
+ ... )
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality"
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=720,
+ ... width=1280,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+class Magi1ImageToVideoPipeline(DiffusionPipeline, Magi1LoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation using Magi1.
+
+ MAGI-1 is a DiT-based video generation model that supports autoregressive chunked generation for long videos. This
+ I2V pipeline takes an input image and generates a video animation starting from that image.
+
+ **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the
+ original MAGI-1 paper. The input image is encoded to a latent representation and used as a clean prefix chunk to
+ condition the generation. Text prompts provide additional semantic guidance for the animation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`Magi1Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A flow matching scheduler with Euler discretization, using SD3-style time resolution transform.
+ vae ([`AutoencoderKLMagi1`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: Magi1Transformer3DModel,
+ vae: AutoencoderKLMagi1,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ )
+ self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Special tokens for conditioning (optional)
+ self.special_tokens = None
+
+ def load_special_tokens_from_file(self, special_tokens_path: str):
+ """
+ Load special conditioning tokens from a numpy file.
+
+ Args:
+ special_tokens_path: Path to special_tokens.npz file
+ """
+ self.special_tokens = load_special_tokens(special_tokens_path)
+ if self.special_tokens is not None:
+ logger.info("Special tokens loaded successfully. You can now use quality, style, and motion control.")
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 800,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ # Repeat mask similarly and keep [B*num, L]
+ mask = mask.repeat(1, num_videos_per_prompt)
+ mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device)
+
+ return prompt_embeds, mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 800,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ prompt_mask = None
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ negative_mask = None
+
+ return prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if image is None:
+ raise ValueError(
+ "Provide `image` for image-to-video generation. Cannot leave `image` undefined for I2V pipeline."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ image: Optional[PipelineImageInput],
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Prepare latents for I2V generation, including encoding the input image as prefix_video.
+
+ Args:
+ image: Input image for I2V generation
+ batch_size: Batch size
+ num_channels_latents: Number of latent channels
+ height: Video height
+ width: Video width
+ num_frames: Total number of frames to generate
+ dtype: Data type
+ device: Device
+ generator: Random generator
+ latents: Pre-generated latents (optional)
+
+ Returns:
+ Tuple of (latents, prefix_video) where:
+ - latents: Random noise tensor for generation [batch, channels, num_latent_frames, H, W]
+ - prefix_video: Encoded image as clean latent [batch, channels, 1, H, W] (or None if no image)
+ """
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # Prepare random latents for generation
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ # Encode input image to latent as prefix_video
+ prefix_video = None
+ if image is not None:
+ # Preprocess image to target size
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+
+ # Add temporal dimension: [batch, channels, height, width] -> [batch, channels, 1, height, width]
+ if image.ndim == 4:
+ image = image.unsqueeze(2)
+
+ # Encode to latent space using VAE
+ # VAE expects [batch, channels, frames, height, width]
+ if isinstance(generator, list):
+ prefix_video = [
+ retrieve_latents(self.vae.encode(image), sample_mode="sample", generator=g) for g in generator
+ ]
+ prefix_video = torch.cat(prefix_video)
+ else:
+ prefix_video = retrieve_latents(self.vae.encode(image), sample_mode="sample", generator=generator)
+ prefix_video = prefix_video.repeat(batch_size, 1, 1, 1, 1)
+
+ # Normalize latent using VAE's latent statistics
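+            # latents_std below stores the reciprocal of the per-channel std, so multiplying by it divides
+            # the centered latents by the true std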
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(prefix_video.device, dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ prefix_video.device, dtype
+ )
+ prefix_video = prefix_video.to(dtype)
+ prefix_video = (prefix_video - latents_mean) * latents_std
+
+ return latents, prefix_video
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 800,
+ use_hq_token: bool = False,
+ use_3d_style: bool = False,
+ use_2d_anime_style: bool = False,
+ use_static_first_frames: bool = False,
+ use_dynamic_first_frames: bool = False,
+ enable_distillation: bool = False,
+ distill_nearly_clean_chunk_threshold: float = 0.3,
+ ):
+ r"""
+ The call function to the pipeline for image-to-video generation.
+
+ **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the
+ original MAGI-1 paper. The input image is encoded to a VAE latent and used as a clean prefix chunk to condition
+ the video generation. The implementation currently works without KV caching (attention is recomputed for
+ previous chunks), which is less efficient than the original but still functional. KV caching optimization will
+ be added when diffusers implements generic caching support for transformers.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the video generation on. Must be an image, a list of images, or a
+ `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+            height (`int`, defaults to `480`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `832`):
+                The width in pixels of the generated video.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+                of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate videos closely linked
+                to the text `prompt`, usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, negative_prompt_embeds will be generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between `"latent"`, `"pt"`, or `"np"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`Magi1PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during inference with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `800`):
+ The maximum sequence length for the text encoder. Sequences longer than this will be truncated. MAGI-1
+ uses a max length of 800 tokens.
+ use_hq_token (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the high-quality control token to the text embeddings. This token conditions the
+ model to generate higher quality outputs. Requires special tokens to be loaded via
+ `load_special_tokens_from_file`.
+ use_3d_style (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the 3D model style token to the text embeddings. This token conditions the model to
+ generate outputs with 3D modeling aesthetics. Requires special tokens to be loaded.
+ use_2d_anime_style (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the 2D anime style token to the text embeddings. This token conditions the model to
+ generate outputs with 2D anime aesthetics. Requires special tokens to be loaded.
+ use_static_first_frames (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the static first frames token to the text embeddings. This token conditions the
+ model to start the video with minimal motion in the first few frames. Requires special tokens to be
+ loaded.
+ use_dynamic_first_frames (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the dynamic first frames token to the text embeddings. This token conditions the
+ model to start the video with significant motion in the first few frames. Requires special tokens to be
+ loaded.
+ enable_distillation (`bool`, *optional*, defaults to `False`):
+ Whether to enable distillation mode. In distillation mode, the model uses modified timestep embeddings
+ to support distilled (faster) inference. This requires a distilled model checkpoint.
+ distill_nearly_clean_chunk_threshold (`float`, *optional*, defaults to `0.3`):
+                Threshold for identifying nearly-clean chunks in distillation mode. Chunks whose normalized timestep
+                falls below this threshold are considered nearly clean and processed differently. Only used when
+                `enable_distillation=True`.
+
+ Examples:
+
+ Returns:
+ [`~Magi1PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated videos.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=self.text_encoder.dtype,
+ )
+
+ # 3.5. Prepend special tokens if requested
+ if self.special_tokens is not None and any(
+ [use_hq_token, use_3d_style, use_2d_anime_style, use_static_first_frames, use_dynamic_first_frames]
+ ):
+ prompt_embeds = prepend_special_tokens(
+ prompt_embeds=prompt_embeds,
+ special_tokens=self.special_tokens,
+ use_hq_token=use_hq_token,
+ use_3d_style=use_3d_style,
+ use_2d_anime_style=use_2d_anime_style,
+ use_static_first_frames=use_static_first_frames,
+ use_dynamic_first_frames=use_dynamic_first_frames,
+ max_sequence_length=max_sequence_length,
+ )
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = prepend_special_tokens(
+ prompt_embeds=negative_prompt_embeds,
+ special_tokens=self.special_tokens,
+ use_hq_token=use_hq_token,
+ use_3d_style=use_3d_style,
+ use_2d_anime_style=use_2d_anime_style,
+ use_static_first_frames=use_static_first_frames,
+ use_dynamic_first_frames=use_dynamic_first_frames,
+ max_sequence_length=max_sequence_length,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents, prefix_video = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop (autoregressive chunked generation with I2V prefix conditioning)
+ # MAGI-1 I2V uses autoregressive generation with chunk_width=6 and window_size=4
+ # The input image is encoded as a clean prefix chunk and used to condition the generation
+ # Note: num_warmup_steps is calculated for compatibility but not used in progress bar logic
+        # because autoregressive generation has a different iteration structure (stages × steps)
+ # For FlowMatchEulerDiscreteScheduler (order=1), this doesn't affect the results
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ # Autoregressive chunked generation parameters
+ chunk_width = 6 # Original MAGI-1 default
+ window_size = 4 # Original MAGI-1 default
+
+ num_latent_frames = latents.shape[2]
+ num_chunks = (num_latent_frames + chunk_width - 1) // chunk_width
+
+ # Calculate chunk_offset from prefix_video (for I2V, this is the clean image chunk)
+ chunk_offset = 0
+ if prefix_video is not None:
+ # prefix_video has shape [batch, channels, 1, height, width] for I2V
+ # Calculate how many chunks are covered by the prefix
+ prefix_latent_frames = prefix_video.shape[2]
+ chunk_offset = prefix_latent_frames // chunk_width
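+            # For a single input image the prefix spans only one latent frame, so chunk_offset stays 0 and the
+            # partial prefix is instead re-padded into the active window at each denoising step below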
+
+ # Pad prefix_video into latents at the beginning
+ # The prefix frames are already clean and don't need denoising
+ if prefix_latent_frames > 0:
+ prefix_video = prefix_video.to(latents.dtype)
+ latents[:, :, :prefix_latent_frames] = prefix_video
+
+ # Calculate chunk scheduling: which chunks to process at each stage
+ # chunk_offset skips the clean prefix chunks
+ clip_start, clip_end, t_start, t_end = generate_chunk_sequences(num_chunks, window_size, chunk_offset)
+ num_stages = len(clip_start)
+
+ # Prepare per-chunk text embeddings for I2V
+ # Clean chunks (from input image) use null embeddings, denoise chunks use text embeddings
+ (
+ prompt_embeds_per_chunk,
+ negative_prompt_embeds_per_chunk,
+ prompt_masks_per_chunk,
+ negative_masks_per_chunk,
+ ) = prepare_i2v_embeddings(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ num_chunks=num_chunks,
+ clean_chunk_num=chunk_offset,
+ max_sequence_length=max_sequence_length,
+ prompt_mask=prompt_mask,
+ negative_mask=negative_mask,
+ )
+
+ # Number of denoising steps per stage
+ denoise_step_per_stage = len(timesteps) // window_size
+
+ # Track how many times each chunk has been denoised
+ chunk_denoise_count = {i: 0 for i in range(num_chunks)}
+
+ with self.progress_bar(total=num_stages * denoise_step_per_stage) as progress_bar:
+ for stage_idx in range(num_stages):
+ # Determine which chunks to process in this stage
+ chunk_start_idx = clip_start[stage_idx]
+ chunk_end_idx = clip_end[stage_idx]
+ t_start_idx = t_start[stage_idx]
+ t_end_idx = t_end[stage_idx]
+
+ # Extract chunk range in latent space
+ latent_start = chunk_start_idx * chunk_width
+ latent_end = min(chunk_end_idx * chunk_width, num_latent_frames)
+
+ # Number of chunks in current window
+ num_chunks_in_window = chunk_end_idx - chunk_start_idx
+
+ # Prepare per-chunk conditioning with duration/borderness tokens
+ # Duration tokens indicate how many chunks remain in the video
+ # Borderness tokens condition on chunk boundaries
+ chunk_prompt_embeds_list = []
+ chunk_negative_prompt_embeds_list = []
+ chunk_prompt_masks_list = []
+ chunk_negative_masks_list = []
+
+ if self.special_tokens is not None:
+ # Prepare per-chunk embeddings with duration tokens
+ # Each chunk gets a different duration token based on chunks remaining
+ for i, chunk_idx in enumerate(range(chunk_start_idx, chunk_end_idx)):
+ chunks_remaining = num_chunks - chunk_idx - 1
+ # Duration token ranges from 1-8 chunks
+ duration_idx = min(chunks_remaining, 7) + 1
+
+ # Get base embeddings for this chunk (clean chunks have null embeds, denoise chunks have text embeds)
+ token_embeds = prompt_embeds_per_chunk[:, chunk_idx].clone()
+
+ # Add duration and borderness tokens for this chunk
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"]
+ duration_token = duration_token.to(device=token_embeds.device, dtype=token_embeds.dtype)
+ duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1)
+ token_embeds = torch.cat([duration_token, token_embeds], dim=1)
+
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ borderness_token = self.special_tokens["BORDERNESS_TOKEN"]
+ borderness_token = borderness_token.to(
+ device=token_embeds.device, dtype=token_embeds.dtype
+ )
+ borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1)
+ token_embeds = torch.cat([borderness_token, token_embeds], dim=1)
+
+ # Build mask for this chunk from base per-chunk mask
+                        token_mask = (
+                            prompt_masks_per_chunk[:, chunk_idx]
+                            if prompt_masks_per_chunk is not None
+                            else torch.ones(
+                                batch_size, token_embeds.shape[1], device=token_embeds.device, dtype=torch.int64
+                            )
+                        )
+ add_count = 0
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ add_count += 1
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ add_count += 1
+ if add_count > 0:
+ prepend = torch.ones(batch_size, add_count, dtype=token_mask.dtype, device=token_mask.device)
+ token_mask = torch.cat([prepend, token_mask], dim=1)
+ # Truncate to max length
+ if token_embeds.shape[1] > max_sequence_length:
+ token_embeds = token_embeds[:, :max_sequence_length, :]
+ token_mask = token_mask[:, :max_sequence_length]
+
+ chunk_prompt_embeds_list.append(token_embeds)
+ chunk_prompt_masks_list.append(token_mask)
+
+ # Same for negative prompts
+ if self.do_classifier_free_guidance and negative_prompt_embeds_per_chunk is not None:
+ neg_token_embeds = negative_prompt_embeds_per_chunk[:, chunk_idx].clone()
+
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"]
+ duration_token = duration_token.to(
+ device=neg_token_embeds.device, dtype=neg_token_embeds.dtype
+ )
+ duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1)
+ neg_token_embeds = torch.cat([duration_token, neg_token_embeds], dim=1)
+
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ borderness_token = self.special_tokens["BORDERNESS_TOKEN"]
+ borderness_token = borderness_token.to(
+ device=neg_token_embeds.device, dtype=neg_token_embeds.dtype
+ )
+ borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1)
+ neg_token_embeds = torch.cat([borderness_token, neg_token_embeds], dim=1)
+
+ # Build negative mask for this chunk
+                            neg_mask = (
+                                negative_masks_per_chunk[:, chunk_idx]
+                                if negative_masks_per_chunk is not None
+                                else torch.ones(
+                                    batch_size,
+                                    neg_token_embeds.shape[1],
+                                    device=neg_token_embeds.device,
+                                    dtype=torch.int64,
+                                )
+                            )
+ add_count = 0
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ add_count += 1
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ add_count += 1
+ if add_count > 0:
+ prepend = torch.ones(batch_size, add_count, dtype=neg_mask.dtype, device=neg_mask.device)
+ neg_mask = torch.cat([prepend, neg_mask], dim=1)
+ if neg_token_embeds.shape[1] > max_sequence_length:
+ neg_token_embeds = neg_token_embeds[:, :max_sequence_length, :]
+ neg_mask = neg_mask[:, :max_sequence_length]
+
+ chunk_negative_prompt_embeds_list.append(neg_token_embeds)
+ chunk_negative_masks_list.append(neg_mask)
+
+ # Denoise this chunk range for denoise_step_per_stage steps
+ for denoise_idx in range(denoise_step_per_stage):
+ if self.interrupt:
+ break
+
+ # Calculate timestep index for each chunk in the current window
+ # Chunks at different stages get different timesteps based on their denoise progress
+ timestep_indices = []
+ for offset in range(num_chunks_in_window):
+ # Map offset within window to time index
+ t_idx_within_window = t_start_idx + offset
+ if t_idx_within_window < t_end_idx:
+ # This chunk is actively being denoised in this window
+ t_idx = t_idx_within_window * denoise_step_per_stage + denoise_idx
+ else:
+ # This chunk is beyond the active window, use max timestep (it's already cleaner)
+ t_idx = min((window_size - 1) * denoise_step_per_stage + denoise_idx, len(timesteps) - 1)
+ timestep_indices.append(t_idx)
+
+ # Reverse order: chunks further from start are noisier
+ timestep_indices.reverse()
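+                    # Illustrative example: with 48 timesteps and window_size=4, denoise_step_per_stage is 12; at
+                    # denoise_idx=0 with all four chunks active, the indices are [0, 12, 24, 36] before the reverse
+                    # and [36, 24, 12, 0] after it, so the oldest chunk in the window gets the least-noisy timestep
+                    # and the newest chunk gets the noisiest one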
+
+ # Get actual timesteps (reversed order: high noise to low noise)
+ current_timesteps = timesteps[timestep_indices]
+
+ # Create per-chunk timestep tensor: [batch_size, num_chunks_in_window]
+ # Each chunk gets its own timestep based on how many times it's been denoised
+ timestep_per_chunk = current_timesteps.unsqueeze(0).expand(batch_size, -1)
+
+ # Store first timestep for progress tracking
+ self._current_timestep = current_timesteps[0]
+
+ # Extract chunk
+ latent_chunk = latents[:, :, latent_start:latent_end].to(transformer_dtype)
+
+ # Prepare distillation parameters if enabled
+ num_steps = None
+ distill_interval = None
+ distill_nearly_clean_chunk = None
+
+ if enable_distillation:
+ # distill_interval represents the time interval between denoising steps
+ distill_interval = len(timesteps) / num_inference_steps
+
+                        # Determine if chunks are nearly clean (low noise) based on their timesteps
+                        # Check the first chunk's timestep: after the reverse above, index 0 holds the oldest chunk
+                        # in the window, i.e. the chunk that is closest to being clean
+                        # Normalize timestep to [0, 1] range where 0=clean, 1=noise
+ nearly_clean_chunk_t = current_timesteps[0].item() / self.scheduler.config.num_train_timesteps
+ distill_nearly_clean_chunk = nearly_clean_chunk_t < distill_nearly_clean_chunk_threshold
+
+ num_steps = num_inference_steps
+
+ # Prepare per-chunk embeddings
+ # The transformer expects embeddings in shape [batch_size * num_chunks_in_window, seq_len, hidden_dim]
+ # Each chunk gets its own embedding with appropriate duration/borderness tokens
+ if chunk_prompt_embeds_list:
+ # Stack per-chunk embeddings: [num_chunks_in_window, batch_size, seq_len, hidden_dim]
+ chunk_prompt_embeds = torch.stack(chunk_prompt_embeds_list, dim=0)
+ # Reshape to [batch_size * num_chunks_in_window, seq_len, hidden_dim]
+ chunk_prompt_embeds = chunk_prompt_embeds.transpose(0, 1).flatten(0, 1)
+
+ if chunk_negative_prompt_embeds_list:
+ chunk_negative_prompt_embeds = torch.stack(chunk_negative_prompt_embeds_list, dim=0)
+ chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.transpose(0, 1).flatten(0, 1)
+ else:
+ chunk_negative_prompt_embeds = None
+ else:
+ # Fallback: use per-chunk embeddings without special tokens
+ # Extract embeddings for the current chunk range
+ chunk_prompt_embeds = prompt_embeds_per_chunk[:, chunk_start_idx:chunk_end_idx]
+ chunk_prompt_embeds = chunk_prompt_embeds.flatten(0, 1)
+ chunk_prompt_masks = (
+ prompt_masks_per_chunk[:, chunk_start_idx:chunk_end_idx].flatten(0, 1)
+ if prompt_masks_per_chunk is not None
+ else None
+ )
+
+ if negative_prompt_embeds_per_chunk is not None:
+ chunk_negative_prompt_embeds = negative_prompt_embeds_per_chunk[
+ :, chunk_start_idx:chunk_end_idx
+ ]
+ chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.flatten(0, 1)
+ chunk_negative_masks = (
+ negative_masks_per_chunk[:, chunk_start_idx:chunk_end_idx].flatten(0, 1)
+ if negative_masks_per_chunk is not None
+ else None
+ )
+ else:
+ chunk_negative_prompt_embeds = None
+ chunk_negative_masks = None
+
+ # Create encoder attention mask(s) from tokenizer masks
+ if chunk_prompt_embeds_list:
+ prompt_masks_stacked = torch.stack(chunk_prompt_masks_list, dim=0).transpose(0, 1).flatten(0, 1)
+ else:
+ prompt_masks_stacked = chunk_prompt_masks
+ if prompt_masks_stacked is None:
+ prompt_masks_stacked = torch.ones(
+ batch_size * num_chunks_in_window,
+ chunk_prompt_embeds.shape[1],
+ dtype=torch.int64,
+ device=chunk_prompt_embeds.device,
+ )
+ encoder_attention_mask = prompt_masks_stacked.to(dtype=chunk_prompt_embeds.dtype).view(
+ batch_size * num_chunks_in_window, 1, 1, -1
+ )
+
+ if self.do_classifier_free_guidance:
+ if chunk_negative_prompt_embeds_list:
+ negative_masks_stacked = (
+ torch.stack(chunk_negative_masks_list, dim=0).transpose(0, 1).flatten(0, 1)
+ )
+ else:
+ negative_masks_stacked = chunk_negative_masks
+ if negative_masks_stacked is None and chunk_negative_prompt_embeds is not None:
+ negative_masks_stacked = torch.ones(
+ batch_size * num_chunks_in_window,
+ chunk_negative_prompt_embeds.shape[1],
+ dtype=torch.int64,
+ device=chunk_negative_prompt_embeds.device,
+ )
+ encoder_attention_mask_neg = (
+ negative_masks_stacked.to(dtype=chunk_prompt_embeds.dtype).view(
+ batch_size * num_chunks_in_window, 1, 1, -1
+ )
+ if negative_masks_stacked is not None
+ else encoder_attention_mask
+ )
+
+ # Pad prefix video into latent_chunk if applicable (I2V)
+ # This ensures clean prefix frames are maintained during denoising
+ if prefix_video is not None:
+ prefix_length = prefix_video.shape[2]
+ prefix_video_start = chunk_start_idx * chunk_width
+
+ if prefix_length > prefix_video_start:
+ # Calculate how many frames to pad
+ padding_length = min(prefix_length - prefix_video_start, latent_chunk.shape[2])
+ prefix_video_end = prefix_video_start + padding_length
+
+ # Pad clean prefix frames into latent_chunk
+ latent_chunk = latent_chunk.clone()
+ latent_chunk[:, :, :padding_length] = prefix_video[
+ :, :, prefix_video_start:prefix_video_end
+ ]
+
+ # Set timesteps for clean prefix chunks to maximum (indicates "already clean")
+ # This matches original MAGI-1's try_pad_prefix_video logic
+ num_clean_chunks_in_window = padding_length // chunk_width
+ if num_clean_chunks_in_window > 0:
+ # Get max timestep from scheduler
+ max_timestep = timesteps[0]
+ timestep_per_chunk[:, :num_clean_chunks_in_window] = max_timestep
+
+ # Generate KV range for autoregressive attention
+ # Each chunk can attend to itself and all previous chunks in the sequence
+ # Shape: [batch_size * num_chunks_in_window, 2] where each row is [start_token_idx, end_token_idx]
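+                    # e.g. b=0, chunk_start_idx=2, c=1 (global chunk 3) -> kv_range [0, 4 * chunk_token_nums)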
+ chunk_token_nums = (
+ (latent_chunk.shape[2] // num_chunks_in_window) # frames per chunk
+ * (latent_chunk.shape[3] // self.transformer.config.patch_size[1]) # height tokens
+ * (latent_chunk.shape[4] // self.transformer.config.patch_size[2]) # width tokens
+ )
+ kv_range = []
+ for b in range(batch_size):
+ # batch_offset should be based on total chunks in the video, not chunk_end_idx
+ batch_offset = b * num_chunks
+ for c in range(num_chunks_in_window):
+ # This chunk can attend from the start of the video up to its own end
+ chunk_global_idx = chunk_start_idx + c
+ k_start = batch_offset * chunk_token_nums
+ k_end = (batch_offset + chunk_global_idx + 1) * chunk_token_nums
+ kv_range.append([k_start, k_end])
+ kv_range = torch.tensor(kv_range, dtype=torch.int32, device=device)
+
+ # Predict noise (conditional)
+ # Note: MAGI-1 uses velocity field (flow matching), but following diffusers convention
+ # we use noise_pred naming for consistency across all pipelines
+ noise_pred = self.transformer(
+ hidden_states=latent_chunk,
+ timestep=timestep_per_chunk,
+ encoder_hidden_states=chunk_prompt_embeds,
+ encoder_attention_mask=encoder_attention_mask,
+ attention_kwargs=attention_kwargs,
+ denoising_range_num=num_chunks_in_window,
+ range_num=chunk_end_idx,
+ slice_point=chunk_start_idx,
+ kv_range=kv_range,
+ num_steps=num_steps,
+ distill_interval=distill_interval,
+ distill_nearly_clean_chunk=distill_nearly_clean_chunk,
+ return_dict=False,
+ )[0]
+
+ # Classifier-free guidance: separate forward pass for unconditional
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_chunk,
+ timestep=timestep_per_chunk,
+ encoder_hidden_states=chunk_negative_prompt_embeds,
+ encoder_attention_mask=encoder_attention_mask_neg,
+ attention_kwargs=attention_kwargs,
+ denoising_range_num=num_chunks_in_window,
+ range_num=chunk_end_idx,
+ slice_point=chunk_start_idx,
+ kv_range=kv_range,
+ num_steps=num_steps,
+ distill_interval=distill_interval,
+ distill_nearly_clean_chunk=distill_nearly_clean_chunk,
+ return_dict=False,
+ )[0]
+ # Apply classifier-free guidance
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+ # CRITICAL: Apply per-chunk Euler integration with different delta_t for each chunk
+ # This matches the original MAGI-1's integrate() function
+ # Each chunk in the window is at a different noise level and needs its own time step
+
+ # Calculate per-chunk timesteps for integration
+ # Get current timesteps (t_before)
+ current_timesteps = timesteps[timestep_indices]
+
+ # Get next timesteps (t_after) - one step forward for each chunk
+ next_timestep_indices = [min(idx + 1, len(timesteps) - 1) for idx in timestep_indices]
+ next_timesteps = timesteps[next_timestep_indices]
+
+ # Convert timesteps to sigmas (matching FlowMatchEulerDiscreteScheduler)
+ current_sigmas = current_timesteps / self.scheduler.config.num_train_timesteps
+ next_sigmas = next_timesteps / self.scheduler.config.num_train_timesteps
+
+ # Calculate delta_t for each chunk: [num_chunks_in_window]
+ delta_t = next_sigmas - current_sigmas
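+                    # Note: sigmas decrease across the schedule, so each delta_t is non-positive and the
+                    # Euler update below steps every chunk toward lower noise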
+
+ # Reshape latent_chunk and noise_pred to separate chunks
+ # From: [batch, channels, frames, height, width]
+ # To: [batch, channels, num_chunks, chunk_width, height, width]
+ batch_size_actual, num_channels, total_frames, height_latent, width_latent = latent_chunk.shape
+
+ # Ensure total_frames is divisible by chunk_width for reshaping
+ # (it should be, but let's handle edge cases)
+ num_complete_chunks = total_frames // chunk_width
+ remainder_frames = total_frames % chunk_width
+
+ if remainder_frames == 0:
+ # Perfect division: reshape and apply per-chunk delta_t
+ latent_chunk = latent_chunk.reshape(
+ batch_size_actual,
+ num_channels,
+ num_complete_chunks,
+ chunk_width,
+ height_latent,
+ width_latent,
+ )
+ noise_pred = noise_pred.reshape(
+ batch_size_actual,
+ num_channels,
+ num_complete_chunks,
+ chunk_width,
+ height_latent,
+ width_latent,
+ )
+
+ # Apply Euler integration: x_chunk = x_chunk + velocity * delta_t
+ # delta_t shape: [num_chunks] -> broadcast to [1, 1, num_chunks, 1, 1, 1]
+ delta_t_broadcast = delta_t.reshape(1, 1, -1, 1, 1, 1).to(
+ latent_chunk.device, latent_chunk.dtype
+ )
+ latent_chunk = latent_chunk + noise_pred * delta_t_broadcast
+
+ # Reshape back to original dimensions
+ latent_chunk = latent_chunk.reshape(
+ batch_size_actual, num_channels, total_frames, height_latent, width_latent
+ )
+ else:
+ # Handle remainder frames separately (edge case for last incomplete chunk)
+ complete_frames = num_complete_chunks * chunk_width
+
+ # Process complete chunks
+ latent_chunk_complete = latent_chunk[:, :, :complete_frames]
+ noise_pred_complete = noise_pred[:, :, :complete_frames]
+
+ latent_chunk_complete = latent_chunk_complete.reshape(
+ batch_size_actual,
+ num_channels,
+ num_complete_chunks,
+ chunk_width,
+ height_latent,
+ width_latent,
+ )
+ noise_pred_complete = noise_pred_complete.reshape(
+ batch_size_actual,
+ num_channels,
+ num_complete_chunks,
+ chunk_width,
+ height_latent,
+ width_latent,
+ )
+
+ # Apply per-chunk delta_t to complete chunks
+ delta_t_broadcast = (
+ delta_t[:num_complete_chunks]
+ .reshape(1, 1, -1, 1, 1, 1)
+ .to(latent_chunk.device, latent_chunk.dtype)
+ )
+ latent_chunk_complete = latent_chunk_complete + noise_pred_complete * delta_t_broadcast
+ latent_chunk_complete = latent_chunk_complete.reshape(
+ batch_size_actual, num_channels, complete_frames, height_latent, width_latent
+ )
+
+ # Process remainder frames with last delta_t
+ if remainder_frames > 0:
+ latent_chunk_remainder = latent_chunk[:, :, complete_frames:]
+ noise_pred_remainder = noise_pred[:, :, complete_frames:]
+ delta_t_remainder = delta_t[-1].to(latent_chunk.device, latent_chunk.dtype)
+ latent_chunk_remainder = latent_chunk_remainder + noise_pred_remainder * delta_t_remainder
+
+ # Concatenate complete and remainder
+ latent_chunk = torch.cat([latent_chunk_complete, latent_chunk_remainder], dim=2)
+ else:
+ latent_chunk = latent_chunk_complete
+
+ # Write back to full latents
+ latents[:, :, latent_start:latent_end] = latent_chunk
+
+ # Update chunk denoise counts
+ for chunk_idx in range(chunk_start_idx, chunk_end_idx):
+ chunk_denoise_count[chunk_idx] += 1
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ # Use first timestep for callback (most representative)
+ callback_timestep = current_timesteps[0]
+ callback_outputs = callback_on_step_end(
+ self, stage_idx * denoise_step_per_stage + denoise_idx, callback_timestep, callback_kwargs
+ )
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return Magi1PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_v2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_v2v.py
new file mode 100644
index 000000000000..99d647e74ea3
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/pipeline_magi1_v2v.py
@@ -0,0 +1,1527 @@
+# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MAGI-1 V2V Pipeline with Autoregressive Chunked Generation
+#
+# ✅ IMPLEMENTED:
+# - Video-to-Video generation using prefix video conditioning
+# - Autoregressive chunked generation (always enabled, matching original MAGI-1)
+# - Window-based scheduling: chunk_width=6, window_size=4
+# - Progressive denoising across overlapping temporal windows
+# - Proper CFG with separate forward passes (diffusers style)
+# - Input video frames encoding to VAE latent as clean prefix chunks
+#
+# ⚠️ CURRENT LIMITATION:
+# - No KV caching: attention is recomputed for previous chunks
+# - This is less efficient than the original but fully functional
+#
+# ⏳ FUTURE OPTIMIZATIONS (when diffusers adds generic KV caching):
+# 1. **KV Cache Management**:
+# - Cache attention keys/values for previously denoised chunks
+# - Reuse cached computations instead of recomputing
+# - Will significantly speed up generation (2-3x faster expected)
+#
+# 2. **Special Token Support** (optional enhancement):
+# - Duration tokens: indicate how many chunks remain to generate
+# - Quality tokens: HQ_TOKEN for high-quality generation
+# - Style tokens: THREE_D_MODEL_TOKEN, TWO_D_ANIME_TOKEN
+# - Motion tokens: STATIC_FIRST_FRAMES_TOKEN, DYNAMIC_FIRST_FRAMES_TOKEN
+#
+# 3. **Streaming Generation**:
+# - Yield clean chunks as they complete (generator pattern)
+# - Enable real-time preview during generation
+#
+# Reference: https://github.com/SandAI/MAGI-1/blob/main/inference/pipeline/video_generate.py
+
+import html
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL.Image
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import Magi1LoraLoaderMixin
+from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import Magi1PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def generate_chunk_sequences(chunk_num: int, window_size: int, chunk_offset: int = 0):
+ """
+ Generate chunk scheduling sequences for autoregressive video generation.
+
+ Args:
+ chunk_num: Total number of chunks to generate
+ window_size: Number of chunks to process in each window
+ chunk_offset: Number of clean prefix chunks (for I2V/V2V)
+
+ Returns:
+ ```
+ clip_start: Start index of chunks to process
+ clip_end: End index of chunks to process
+ t_start: Start index in time dimension
+ t_end: End index in time dimension
+ ```
+
+ Examples:
+ ```
+ chunk_num=8, window_size=4, chunk_offset=0
+ Stage 0: Process chunks [0:1], denoise chunk 0
+ Stage 1: Process chunks [0:2], denoise chunk 1
+ Stage 2: Process chunks [0:3], denoise chunk 2
+ Stage 3: Process chunks [0:4], denoise chunk 3
+ Stage 4: Process chunks [1:5], denoise chunk 4
+ ...
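+
+        chunk_num=8, window_size=4, chunk_offset=2  (V2V: the two clean prefix chunks are skipped)
+        Stage 0: Process chunks [2:3], denoise chunk 2
+        Stage 1: Process chunks [2:4], denoise chunk 3
+        ...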
+ ```
+ """
+ start_index = chunk_offset
+ end_index = chunk_num + window_size - 1
+
+ clip_start = [max(chunk_offset, i - window_size + 1) for i in range(start_index, end_index)]
+ clip_end = [min(chunk_num, i + 1) for i in range(start_index, end_index)]
+
+ t_start = [max(0, i - chunk_num + 1) for i in range(start_index, end_index)]
+ t_end = [
+ min(window_size, i - chunk_offset + 1) if i - chunk_offset < window_size else window_size
+ for i in range(start_index, end_index)
+ ]
+
+ return clip_start, clip_end, t_start, t_end
+
+
+def load_special_tokens(special_tokens_path: Optional[str] = None) -> Optional[Dict[str, torch.Tensor]]:
+ """
+ Load special conditioning tokens from numpy file.
+
+ Args:
+ special_tokens_path: Path to special_tokens.npz file. If None, returns None (no special tokens).
+
+ Returns:
+ Dictionary mapping token names to embeddings, or None if path not provided or file doesn't exist.
+ """
+ if special_tokens_path is None:
+ return None
+
+ try:
+ import os
+
+ import numpy as np
+
+ if not os.path.exists(special_tokens_path):
+ logger.warning(f"Special tokens file not found at {special_tokens_path}, skipping special token loading.")
+ return None
+
+ special_token_data = np.load(special_tokens_path)
+ caption_token = torch.tensor(special_token_data["caption_token"].astype(np.float16))
+ logo_token = torch.tensor(special_token_data["logo_token"].astype(np.float16))
+ other_tokens = special_token_data["other_tokens"]
+
+ tokens = {
+ "CAPTION_TOKEN": caption_token,
+ "LOGO_TOKEN": logo_token,
+ "TRANS_TOKEN": torch.tensor(other_tokens[:1].astype(np.float16)),
+ "HQ_TOKEN": torch.tensor(other_tokens[1:2].astype(np.float16)),
+ "STATIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[2:3].astype(np.float16)),
+ "DYNAMIC_FIRST_FRAMES_TOKEN": torch.tensor(other_tokens[3:4].astype(np.float16)),
+ "BORDERNESS_TOKEN": torch.tensor(other_tokens[4:5].astype(np.float16)),
+ "THREE_D_MODEL_TOKEN": torch.tensor(other_tokens[15:16].astype(np.float16)),
+ "TWO_D_ANIME_TOKEN": torch.tensor(other_tokens[16:17].astype(np.float16)),
+ }
+
+ # Duration tokens (8 total, representing 1-8 chunks remaining)
+ for i in range(8):
+ tokens[f"DURATION_TOKEN_{i + 1}"] = torch.tensor(other_tokens[i + 7 : i + 8].astype(np.float16))
+
+ logger.info(f"Loaded {len(tokens)} special tokens from {special_tokens_path}")
+ return tokens
+ except Exception as e:
+ logger.warning(f"Failed to load special tokens: {e}")
+ return None
+
+
+def prepend_special_tokens(
+ prompt_embeds: torch.Tensor,
+ special_tokens: Optional[Dict[str, torch.Tensor]],
+ use_hq_token: bool = False,
+ use_3d_style: bool = False,
+ use_2d_anime_style: bool = False,
+ use_static_first_frames: bool = False,
+ use_dynamic_first_frames: bool = False,
+ max_sequence_length: int = 800,
+) -> torch.Tensor:
+ """
+ Prepend special conditioning tokens to text embeddings.
+
+ Args:
+ prompt_embeds: Text embeddings [batch, seq_len, hidden_dim]
+ special_tokens: Dictionary of special token embeddings
+ use_hq_token: Whether to add high-quality token
+ use_3d_style: Whether to add 3D model style token
+ use_2d_anime_style: Whether to add 2D anime style token
+ use_static_first_frames: Whether to add static motion token
+ use_dynamic_first_frames: Whether to add dynamic motion token
+ max_sequence_length: Maximum sequence length after prepending
+
+ Returns:
+ Text embeddings with special tokens prepended
+ """
+ if special_tokens is None:
+ return prompt_embeds
+
+ device = prompt_embeds.device
+ dtype = prompt_embeds.dtype
+ batch_size, seq_len, hidden_dim = prompt_embeds.shape
+
+ # Collect tokens to prepend (in order: motion, quality, style)
+ tokens_to_add = []
+ if use_static_first_frames and "STATIC_FIRST_FRAMES_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["STATIC_FIRST_FRAMES_TOKEN"])
+ if use_dynamic_first_frames and "DYNAMIC_FIRST_FRAMES_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["DYNAMIC_FIRST_FRAMES_TOKEN"])
+ if use_hq_token and "HQ_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["HQ_TOKEN"])
+ if use_3d_style and "THREE_D_MODEL_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["THREE_D_MODEL_TOKEN"])
+ if use_2d_anime_style and "TWO_D_ANIME_TOKEN" in special_tokens:
+ tokens_to_add.append(special_tokens["TWO_D_ANIME_TOKEN"])
+
+ # Prepend tokens
+ for token in tokens_to_add:
+ token = token.to(device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1)
+ prompt_embeds = torch.cat([token, prompt_embeds], dim=1)
+
+ # Truncate to max length
+ if prompt_embeds.shape[1] > max_sequence_length:
+ prompt_embeds = prompt_embeds[:, :max_sequence_length, :]
+
+ return prompt_embeds
+
+
+def prepare_v2v_embeddings(
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: Optional[torch.Tensor],
+ num_chunks: int,
+ clean_chunk_num: int,
+ max_sequence_length: int = 800,
+ prompt_mask: Optional[torch.Tensor] = None,
+ negative_mask: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Prepare per-chunk text embeddings for V2V generation.
+
+ In V2V, clean prefix chunks (from the input video) use null embeddings, while chunks to be denoised use the actual
+ text embeddings.
+
+ Args:
+ prompt_embeds: Text embeddings [batch_size, seq_len, hidden_dim]
+ negative_prompt_embeds: Negative text embeddings (optional)
+ num_chunks: Total number of chunks
+ clean_chunk_num: Number of clean prefix chunks
+        max_sequence_length: Maximum sequence length
+        prompt_mask: Attention mask for `prompt_embeds` (optional)
+        negative_mask: Attention mask for `negative_prompt_embeds` (optional)
+
+ Returns:
+ - prompt_embeds_per_chunk: [B, num_chunks, L, D]
+ - negative_prompt_embeds_per_chunk: [B, num_chunks, L, D] or None
+ - prompt_masks_per_chunk: [B, num_chunks, L] or None
+ - negative_masks_per_chunk: [B, num_chunks, L] or None
+ """
+ batch_size = prompt_embeds.shape[0]
+ seq_len = prompt_embeds.shape[1]
+ hidden_dim = prompt_embeds.shape[2]
+ device = prompt_embeds.device
+ dtype = prompt_embeds.dtype
+
+ # Number of chunks that need denoising
+ denoise_chunk_num = num_chunks - clean_chunk_num
+
+ # Create null embeddings (zeros) for clean chunks
+ null_embeds = torch.zeros(batch_size, 1, seq_len, hidden_dim, device=device, dtype=dtype)
+
+ # Expand prompt embeddings for denoise chunks
+ # Shape: [batch_size, denoise_chunk_num, seq_len, hidden_dim]
+ denoise_embeds = prompt_embeds.unsqueeze(1).repeat(1, denoise_chunk_num, 1, 1)
+
+ # Concatenate: [null_embeds for clean chunks] + [text_embeds for denoise chunks]
+ # Shape: [batch_size, num_chunks, seq_len, hidden_dim]
+ if clean_chunk_num > 0:
+ null_embeds_expanded = null_embeds.repeat(1, clean_chunk_num, 1, 1)
+ prompt_embeds_per_chunk = torch.cat([null_embeds_expanded, denoise_embeds], dim=1)
+ else:
+ prompt_embeds_per_chunk = denoise_embeds
+
+ # Build prompt masks per chunk
+ prompt_masks_per_chunk = None
+ negative_masks_per_chunk = None
+ if prompt_mask is not None:
+ denoise_masks = prompt_mask.unsqueeze(1).repeat(1, denoise_chunk_num, 1)
+ if clean_chunk_num > 0:
+            null_masks = torch.zeros(
+                prompt_mask.shape[0], clean_chunk_num, prompt_mask.shape[1], device=device, dtype=prompt_mask.dtype
+            )
+ prompt_masks_per_chunk = torch.cat([null_masks, denoise_masks], dim=1)
+ else:
+ prompt_masks_per_chunk = denoise_masks
+
+ # Same for negative embeddings
+ if negative_prompt_embeds is not None:
+ denoise_neg_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, denoise_chunk_num, 1, 1)
+ if clean_chunk_num > 0:
+ null_neg_embeds_expanded = null_embeds.repeat(1, clean_chunk_num, 1, 1)
+ negative_prompt_embeds_per_chunk = torch.cat([null_neg_embeds_expanded, denoise_neg_embeds], dim=1)
+ else:
+ negative_prompt_embeds_per_chunk = denoise_neg_embeds
+ else:
+ negative_prompt_embeds_per_chunk = None
+    if negative_mask is not None:
+        denoise_neg_masks = negative_mask.unsqueeze(1).repeat(1, denoise_chunk_num, 1)
+        if clean_chunk_num > 0:
+            null_neg_masks = torch.zeros(
+                negative_mask.shape[0],
+                clean_chunk_num,
+                negative_mask.shape[1],
+                device=device,
+                dtype=negative_mask.dtype,
+            )
+            negative_masks_per_chunk = torch.cat([null_neg_masks, denoise_neg_masks], dim=1)
+        else:
+            negative_masks_per_chunk = denoise_neg_masks
+
+ return (
+ prompt_embeds_per_chunk,
+ negative_prompt_embeds_per_chunk,
+ prompt_masks_per_chunk,
+ negative_masks_per_chunk,
+ )
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import Magi1VideoToVideoPipeline, AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler
+ >>> from diffusers.utils import export_to_video, load_video
+
+ >>> model_id = "SandAI/Magi1-V2V-14B-480P-Diffusers"
+ >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+
+ >>> # IMPORTANT: MAGI-1 requires shift=3.0 for the scheduler (SD3-style time resolution transform)
+ >>> scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3.0)
+
+ >>> pipe = Magi1VideoToVideoPipeline.from_pretrained(model_id, vae=vae, scheduler=scheduler, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # Load prefix video (e.g., 24 frames)
+    >>> video = load_video("path/to/input_video.mp4")[:24]
+ >>> prompt = (
+ ... "Continue this video with smooth camera motion and consistent style. "
+ ... "The scene evolves naturally with coherent motion."
+ ... )
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, worst quality, low quality"
+
+ >>> output = pipe(
+ ... video=video,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... num_frames=81, # Total frames including prefix
+ ... guidance_scale=5.0,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+class Magi1VideoToVideoPipeline(DiffusionPipeline, Magi1LoraLoaderMixin):
+ r"""
+ Pipeline for video-to-video generation using Magi1.
+
+ MAGI-1 is a DiT-based video generation model that supports autoregressive chunked generation for long videos. This
+ V2V pipeline takes an input video and generates a continuation or extension of that video.
+
+ **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the
+ original MAGI-1 paper. The input video frames are encoded to latent representations and used as clean prefix chunks
+ to condition the generation. Text prompts provide additional semantic guidance for the video continuation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`Magi1Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A flow matching scheduler with Euler discretization, using SD3-style time resolution transform.
+ vae ([`AutoencoderKLMagi1`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: Magi1Transformer3DModel,
+ vae: AutoencoderKLMagi1,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ )
+ self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Special tokens for conditioning (optional)
+ self.special_tokens = None
+
+ def load_special_tokens_from_file(self, special_tokens_path: str):
+ """
+ Load special conditioning tokens from a numpy file.
+
+ Args:
+ special_tokens_path: Path to special_tokens.npz file
+ """
+ self.special_tokens = load_special_tokens(special_tokens_path)
+ if self.special_tokens is not None:
+ logger.info("Special tokens loaded successfully. You can now use quality, style, and motion control.")
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 800,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
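+        # Trim each sequence to its true token length, then zero-pad back to max_sequence_length
+        # so the padded positions are exact zeros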
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+ mask = mask.repeat(1, num_videos_per_prompt)
+ mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device)
+ return prompt_embeds, mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 800,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ prompt_mask = None
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ negative_mask = None
+
+ return prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ video,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if video is None:
+ raise ValueError(
+ "Provide `video` for video-to-video generation. Cannot leave `video` undefined for V2V pipeline."
+ )
+ if video is not None and not isinstance(video, list):
+ raise ValueError(f"`video` has to be of type `list` (list of PIL Images or tensors) but is {type(video)}")
+
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ video: Optional[List[PIL.Image.Image]],
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Prepare latents for V2V generation, including encoding the input video frames as prefix_video.
+
+ Args:
+ video: Input video frames for V2V generation (list of PIL Images)
+ batch_size: Batch size
+ num_channels_latents: Number of latent channels
+ height: Video height
+ width: Video width
+ num_frames: Total number of frames to generate (including prefix)
+ dtype: Data type
+ device: Device
+ generator: Random generator
+ latents: Pre-generated latents (optional)
+
+ Returns:
+ Tuple of (latents, prefix_video) where:
+ - latents: Random noise tensor for generation [batch, channels, num_latent_frames, H, W]
+ - prefix_video: Encoded video frames as clean latent [batch, channels, prefix_frames, H, W]
+ """
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
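+        # e.g. num_frames=81 gives (81 - 1) // 4 + 1 = 21 latent frames with 4x temporal compression,
+        # and 480x832 pixels give 60x104 latents with 8x spatial compression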
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # Prepare random latents for generation
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ # Encode input video frames to latent as prefix_video
+ prefix_video = None
+ if video is not None and len(video) > 0:
+ # Preprocess video frames to target size
+ # video_processor.preprocess_video expects list of PIL Images
+ video_tensor = self.video_processor.preprocess_video(video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ # video_tensor shape: [batch, channels, num_frames, height, width]
+ # For single batch, expand if needed
+ if video_tensor.ndim == 4:
+ # [channels, num_frames, height, width] -> [1, channels, num_frames, height, width]
+ video_tensor = video_tensor.unsqueeze(0)
+
+ # Encode to latent space using VAE
+ # VAE expects [batch, channels, frames, height, width]
+ if isinstance(generator, list):
+ prefix_video = [
+ retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="sample", generator=g)
+ for g, vid in zip(generator, video_tensor)
+ ]
+ prefix_video = torch.cat(prefix_video)
+ else:
+ prefix_video = retrieve_latents(
+ self.vae.encode(video_tensor), sample_mode="sample", generator=generator
+ )
+ if prefix_video.shape[0] < batch_size:
+ prefix_video = prefix_video.repeat(batch_size, 1, 1, 1, 1)
+
+ # Normalize latent using VAE's latent statistics
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(prefix_video.device, dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ prefix_video.device, dtype
+ )
+ prefix_video = prefix_video.to(dtype)
+ prefix_video = (prefix_video - latents_mean) * latents_std
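+            # The decode step later inverts this normalization (latents / latents_std + latents_mean)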
+
+ return latents, prefix_video
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ video: List[PIL.Image.Image],
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 800,
+ use_hq_token: bool = False,
+ use_3d_style: bool = False,
+ use_2d_anime_style: bool = False,
+ use_static_first_frames: bool = False,
+ use_dynamic_first_frames: bool = False,
+ enable_distillation: bool = False,
+ distill_nearly_clean_chunk_threshold: float = 0.3,
+ ):
+ r"""
+ The call function to the pipeline for video-to-video generation.
+
+ **Note**: This implementation uses autoregressive chunked generation (chunk_width=6, window_size=4) as in the
+ original MAGI-1 paper. The input video frames are encoded to VAE latents and used as clean prefix chunks to
+ condition the video generation. The implementation currently works without KV caching (attention is recomputed
+ for previous chunks), which is less efficient than the original but still functional. KV caching optimization
+ will be added when diffusers implements generic caching support for transformers.
+
+ Args:
+ video (`List[PIL.Image.Image]`):
+ The input video frames to condition the video generation on. Must be a list of PIL Images representing
+ the prefix video (e.g., first 24 frames of a video).
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+            height (`int`, defaults to `480`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `832`):
+                The width in pixels of the generated video.
+            num_frames (`int`, defaults to `81`):
+                The number of frames in the generated video.
+            num_inference_steps (`int`, defaults to `50`):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            guidance_scale (`float`, defaults to `5.0`):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages videos that are closely linked to the text
+                `prompt`, usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, negative_prompt_embeds will be generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between `"latent"`, `"pt"`, or `"np"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`Magi1PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during inference with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `800`):
+ The maximum sequence length for the text encoder. Sequences longer than this will be truncated. MAGI-1
+ uses a max length of 800 tokens.
+ use_hq_token (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the high-quality control token to the text embeddings. This token conditions the
+ model to generate higher quality outputs. Requires special tokens to be loaded via
+ `load_special_tokens_from_file`.
+ use_3d_style (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the 3D model style token to the text embeddings. This token conditions the model to
+ generate outputs with 3D modeling aesthetics. Requires special tokens to be loaded.
+ use_2d_anime_style (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the 2D anime style token to the text embeddings. This token conditions the model to
+ generate outputs with 2D anime aesthetics. Requires special tokens to be loaded.
+ use_static_first_frames (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the static first frames token to the text embeddings. This token conditions the
+ model to start the video with minimal motion in the first few frames. Requires special tokens to be
+ loaded.
+ use_dynamic_first_frames (`bool`, *optional*, defaults to `False`):
+ Whether to prepend the dynamic first frames token to the text embeddings. This token conditions the
+ model to start the video with significant motion in the first few frames. Requires special tokens to be
+ loaded.
+ enable_distillation (`bool`, *optional*, defaults to `False`):
+ Whether to enable distillation mode. In distillation mode, the model uses modified timestep embeddings
+ to support distilled (faster) inference. This requires a distilled model checkpoint.
+ distill_nearly_clean_chunk_threshold (`float`, *optional*, defaults to `0.3`):
+                Threshold for identifying nearly-clean chunks in distillation mode. Chunks whose normalized timestep
+                falls below this threshold are considered nearly clean and processed differently. Only used when
+                `enable_distillation=True`.
+
+ Examples:
+
+ Returns:
+ [`~Magi1PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated videos.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ video,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=self.text_encoder.dtype,
+ )
+
+ # 3.5. Prepend special tokens if requested
+ if self.special_tokens is not None and any(
+ [use_hq_token, use_3d_style, use_2d_anime_style, use_static_first_frames, use_dynamic_first_frames]
+ ):
+ prompt_embeds = prepend_special_tokens(
+ prompt_embeds=prompt_embeds,
+ special_tokens=self.special_tokens,
+ use_hq_token=use_hq_token,
+ use_3d_style=use_3d_style,
+ use_2d_anime_style=use_2d_anime_style,
+ use_static_first_frames=use_static_first_frames,
+ use_dynamic_first_frames=use_dynamic_first_frames,
+ max_sequence_length=max_sequence_length,
+ )
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = prepend_special_tokens(
+ prompt_embeds=negative_prompt_embeds,
+ special_tokens=self.special_tokens,
+ use_hq_token=use_hq_token,
+ use_3d_style=use_3d_style,
+ use_2d_anime_style=use_2d_anime_style,
+ use_static_first_frames=use_static_first_frames,
+ use_dynamic_first_frames=use_dynamic_first_frames,
+ max_sequence_length=max_sequence_length,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents, prefix_video = self.prepare_latents(
+ video,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop (autoregressive chunked generation with V2V prefix conditioning)
+ # MAGI-1 V2V uses autoregressive generation with chunk_width=6 and window_size=4
+ # The input video frames are encoded as clean prefix chunks and used to condition the generation
+ # Note: num_warmup_steps is calculated for compatibility but not used in progress bar logic
+        # because autoregressive generation has a different iteration structure (stages × steps)
+ # For FlowMatchEulerDiscreteScheduler (order=1), this doesn't affect the results
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ # Autoregressive chunked generation parameters
+ chunk_width = 6 # Original MAGI-1 default
+ window_size = 4 # Original MAGI-1 default
+
+ num_latent_frames = latents.shape[2]
+ num_chunks = (num_latent_frames + chunk_width - 1) // chunk_width
+
+ # Calculate chunk_offset from prefix_video (for V2V, these are the clean video frame chunks)
+ chunk_offset = 0
+ if prefix_video is not None:
+ # prefix_video has shape [batch, channels, num_prefix_frames, height, width] for V2V
+ # Calculate how many chunks are covered by the prefix
+ prefix_latent_frames = prefix_video.shape[2]
+ chunk_offset = prefix_latent_frames // chunk_width
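+            # e.g. a 24-frame prefix typically encodes to 6 latent frames with 4x temporal compression,
+            # giving chunk_offset = 6 // 6 = 1 clean chunk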
+
+ # Pad prefix_video into latents at the beginning
+ # The prefix frames are already clean and don't need denoising
+ if prefix_latent_frames > 0:
+ prefix_video = prefix_video.to(latents.dtype)
+ latents[:, :, :prefix_latent_frames] = prefix_video
+
+ # Calculate chunk scheduling: which chunks to process at each stage
+ # chunk_offset skips the clean prefix chunks
+ clip_start, clip_end, t_start, t_end = generate_chunk_sequences(num_chunks, window_size, chunk_offset)
+ num_stages = len(clip_start)
+
+ # Prepare per-chunk text embeddings for V2V
+ # Clean chunks (from input video) use null embeddings, denoise chunks use text embeddings
+ (
+ prompt_embeds_per_chunk,
+ negative_prompt_embeds_per_chunk,
+ prompt_masks_per_chunk,
+ negative_masks_per_chunk,
+ ) = prepare_v2v_embeddings(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ num_chunks=num_chunks,
+ clean_chunk_num=chunk_offset,
+ max_sequence_length=max_sequence_length,
+ prompt_mask=prompt_mask,
+ negative_mask=negative_mask,
+ )
+
+ # Number of denoising steps per stage
+ denoise_step_per_stage = len(timesteps) // window_size
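+        # e.g. num_inference_steps=50 and window_size=4 give 12 steps per stage; the trailing 2 timesteps go unused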
+
+ # Track how many times each chunk has been denoised
+ chunk_denoise_count = {i: 0 for i in range(num_chunks)}
+
+ with self.progress_bar(total=num_stages * denoise_step_per_stage) as progress_bar:
+ for stage_idx in range(num_stages):
+ # Determine which chunks to process in this stage
+ chunk_start_idx = clip_start[stage_idx]
+ chunk_end_idx = clip_end[stage_idx]
+ t_start_idx = t_start[stage_idx]
+ t_end_idx = t_end[stage_idx]
+
+ # Extract chunk range in latent space
+ latent_start = chunk_start_idx * chunk_width
+ latent_end = min(chunk_end_idx * chunk_width, num_latent_frames)
+
+ # Number of chunks in current window
+ num_chunks_in_window = chunk_end_idx - chunk_start_idx
+
+ # Prepare per-chunk conditioning with duration/borderness tokens
+ # Duration tokens indicate how many chunks remain in the video
+ # Borderness tokens condition on chunk boundaries
+ chunk_prompt_embeds_list = []
+ chunk_negative_prompt_embeds_list = []
+ chunk_prompt_masks_list = []
+ chunk_negative_masks_list = []
+
+ if self.special_tokens is not None:
+ # Prepare per-chunk embeddings with duration tokens
+ # Each chunk gets a different duration token based on chunks remaining
+ for i, chunk_idx in enumerate(range(chunk_start_idx, chunk_end_idx)):
+ chunks_remaining = num_chunks - chunk_idx - 1
+ # Duration token ranges from 1-8 chunks
+ duration_idx = min(chunks_remaining, 7) + 1
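+                        # e.g. with num_chunks=8 and chunk_idx=5, chunks_remaining=2 -> DURATION_TOKEN_3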
+
+ # Get base embeddings for this chunk (clean chunks have null embeds, denoise chunks have text embeds)
+ token_embeds = prompt_embeds_per_chunk[:, chunk_idx].clone()
+
+ # Add duration and borderness tokens for this chunk
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"]
+ duration_token = duration_token.to(device=token_embeds.device, dtype=token_embeds.dtype)
+ duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1)
+ token_embeds = torch.cat([duration_token, token_embeds], dim=1)
+
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ borderness_token = self.special_tokens["BORDERNESS_TOKEN"]
+ borderness_token = borderness_token.to(
+ device=token_embeds.device, dtype=token_embeds.dtype
+ )
+ borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1)
+ token_embeds = torch.cat([borderness_token, token_embeds], dim=1)
+
+ # Build mask for this chunk from base per-chunk mask
+ token_mask = (
+ prompt_masks_per_chunk[:, chunk_idx]
+ if prompt_masks_per_chunk is not None
+                            else torch.ones(
+                                batch_size, token_embeds.shape[1], device=token_embeds.device, dtype=torch.int64
+                            )
+ )
+ add_count = 0
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ add_count += 1
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ add_count += 1
+ if add_count > 0:
+ prepend = torch.ones(batch_size, add_count, dtype=token_mask.dtype, device=token_mask.device)
+ token_mask = torch.cat([prepend, token_mask], dim=1)
+ # Truncate to max length
+ if token_embeds.shape[1] > max_sequence_length:
+ token_embeds = token_embeds[:, :max_sequence_length, :]
+ token_mask = token_mask[:, :max_sequence_length]
+
+ chunk_prompt_embeds_list.append(token_embeds)
+ chunk_prompt_masks_list.append(token_mask)
+
+ # Same for negative prompts
+ if self.do_classifier_free_guidance and negative_prompt_embeds_per_chunk is not None:
+ neg_token_embeds = negative_prompt_embeds_per_chunk[:, chunk_idx].clone()
+
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ duration_token = self.special_tokens[f"DURATION_TOKEN_{duration_idx}"]
+ duration_token = duration_token.to(
+ device=neg_token_embeds.device, dtype=neg_token_embeds.dtype
+ )
+ duration_token = duration_token.unsqueeze(0).expand(batch_size, -1, -1)
+ neg_token_embeds = torch.cat([duration_token, neg_token_embeds], dim=1)
+
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ borderness_token = self.special_tokens["BORDERNESS_TOKEN"]
+ borderness_token = borderness_token.to(
+ device=neg_token_embeds.device, dtype=neg_token_embeds.dtype
+ )
+ borderness_token = borderness_token.unsqueeze(0).expand(batch_size, -1, -1)
+ neg_token_embeds = torch.cat([borderness_token, neg_token_embeds], dim=1)
+
+ # Build negative mask for this chunk
+ neg_mask = (
+ negative_masks_per_chunk[:, chunk_idx]
+ if negative_masks_per_chunk is not None
+                                else torch.ones(
+                                    batch_size,
+                                    neg_token_embeds.shape[1],
+                                    device=neg_token_embeds.device,
+                                    dtype=torch.int64,
+                                )
+ )
+ add_count = 0
+ if f"DURATION_TOKEN_{duration_idx}" in self.special_tokens:
+ add_count += 1
+ if "BORDERNESS_TOKEN" in self.special_tokens:
+ add_count += 1
+ if add_count > 0:
+ prepend = torch.ones(batch_size, add_count, dtype=neg_mask.dtype, device=neg_mask.device)
+ neg_mask = torch.cat([prepend, neg_mask], dim=1)
+ if neg_token_embeds.shape[1] > max_sequence_length:
+ neg_token_embeds = neg_token_embeds[:, :max_sequence_length, :]
+ neg_mask = neg_mask[:, :max_sequence_length]
+
+ chunk_negative_prompt_embeds_list.append(neg_token_embeds)
+ chunk_negative_masks_list.append(neg_mask)
+
+ # Denoise this chunk range for denoise_step_per_stage steps
+ for denoise_idx in range(denoise_step_per_stage):
+ if self.interrupt:
+ break
+
+ # Calculate timestep index for each chunk in the current window
+ # Chunks at different stages get different timesteps based on their denoise progress
+ timestep_indices = []
+ for offset in range(num_chunks_in_window):
+ # Map offset within window to time index
+ t_idx_within_window = t_start_idx + offset
+ if t_idx_within_window < t_end_idx:
+ # This chunk is actively being denoised in this window
+ t_idx = t_idx_within_window * denoise_step_per_stage + denoise_idx
+ else:
+ # This chunk is past the active window; clamp to the last timestep index (it is already nearly clean)
+ t_idx = min((window_size - 1) * denoise_step_per_stage + denoise_idx, len(timesteps) - 1)
+ timestep_indices.append(t_idx)
+
+ # Reverse order: chunks further from start are noisier
+ timestep_indices.reverse()
+
+ # Get actual timesteps (reversed order: high noise to low noise)
+ current_timesteps = timesteps[timestep_indices]
+
+ # Create per-chunk timestep tensor: [batch_size, num_chunks_in_window]
+ # Each chunk gets its own timestep based on how many times it's been denoised
+ # Clone: the prefix-video branch below may overwrite per-chunk entries in place, which is not
+ # allowed on a bare expand() view (its elements share memory across the batch dimension)
+ timestep_per_chunk = current_timesteps.unsqueeze(0).expand(batch_size, -1).clone()
+
+ # Store first timestep for progress tracking
+ self._current_timestep = current_timesteps[0]
+
+ # Extract chunk
+ latent_chunk = latents[:, :, latent_start:latent_end].to(transformer_dtype)
+
+ # Prepare distillation parameters if enabled
+ num_steps = None
+ distill_interval = None
+ distill_nearly_clean_chunk = None
+
+ if enable_distillation:
+ # distill_interval represents the time interval between denoising steps
+ distill_interval = len(timesteps) / num_inference_steps
+
+ # Determine if chunks are nearly clean (low noise) based on their timesteps
+ # Check the first active chunk's timestep (after reversing, this is the noisiest chunk being actively denoised)
+ # Normalize timestep to [0, 1] range where 0=clean, 1=noise
+ nearly_clean_chunk_t = current_timesteps[0].item() / self.scheduler.config.num_train_timesteps
+ distill_nearly_clean_chunk = nearly_clean_chunk_t < distill_nearly_clean_chunk_threshold
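+ # e.g. with num_train_timesteps=1000, a current timestep of 30 gives nearly_clean_chunk_t = 0.03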
+
+ num_steps = num_inference_steps
+
+ # Prepare per-chunk embeddings
+ # The transformer expects embeddings in shape [batch_size * num_chunks_in_window, seq_len, hidden_dim]
+ # Each chunk gets its own embedding with appropriate duration/borderness tokens
+ if chunk_prompt_embeds_list:
+ # Stack per-chunk embeddings: [num_chunks_in_window, batch_size, seq_len, hidden_dim]
+ chunk_prompt_embeds = torch.stack(chunk_prompt_embeds_list, dim=0)
+ # Reshape to [batch_size * num_chunks_in_window, seq_len, hidden_dim]
+ chunk_prompt_embeds = chunk_prompt_embeds.transpose(0, 1).flatten(0, 1)
+
+ if chunk_negative_prompt_embeds_list:
+ chunk_negative_prompt_embeds = torch.stack(chunk_negative_prompt_embeds_list, dim=0)
+ chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.transpose(0, 1).flatten(0, 1)
+ else:
+ chunk_negative_prompt_embeds = None
+ else:
+ # Fallback: use per-chunk embeddings without special tokens
+ # Extract embeddings for the current chunk range
+ chunk_prompt_embeds = prompt_embeds_per_chunk[:, chunk_start_idx:chunk_end_idx]
+ chunk_prompt_embeds = chunk_prompt_embeds.flatten(0, 1)
+ chunk_prompt_masks = (
+ prompt_masks_per_chunk[:, chunk_start_idx:chunk_end_idx].flatten(0, 1)
+ if prompt_masks_per_chunk is not None
+ else None
+ )
+
+ if negative_prompt_embeds_per_chunk is not None:
+ chunk_negative_prompt_embeds = negative_prompt_embeds_per_chunk[
+ :, chunk_start_idx:chunk_end_idx
+ ]
+ chunk_negative_prompt_embeds = chunk_negative_prompt_embeds.flatten(0, 1)
+ chunk_negative_masks = (
+ negative_masks_per_chunk[:, chunk_start_idx:chunk_end_idx].flatten(0, 1)
+ if negative_masks_per_chunk is not None
+ else None
+ )
+ else:
+ chunk_negative_prompt_embeds = None
+ chunk_negative_masks = None
+
+ # Create encoder attention mask(s) from tokenizer masks
+ if chunk_prompt_embeds_list:
+ prompt_masks_stacked = torch.stack(chunk_prompt_masks_list, dim=0).transpose(0, 1).flatten(0, 1)
+ else:
+ prompt_masks_stacked = chunk_prompt_masks
+ if prompt_masks_stacked is None:
+ prompt_masks_stacked = torch.ones(
+ batch_size * num_chunks_in_window,
+ chunk_prompt_embeds.shape[1],
+ dtype=torch.int64,
+ device=chunk_prompt_embeds.device,
+ )
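+ # view as [batch_size * num_chunks_in_window, 1, 1, text_seq_len] so the mask can broadcast
+ # over attention heads and query positions inside the transformer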
+ encoder_attention_mask = prompt_masks_stacked.to(dtype=chunk_prompt_embeds.dtype).view(
+ batch_size * num_chunks_in_window, 1, 1, -1
+ )
+
+ if self.do_classifier_free_guidance:
+ if chunk_negative_prompt_embeds_list:
+ negative_masks_stacked = (
+ torch.stack(chunk_negative_masks_list, dim=0).transpose(0, 1).flatten(0, 1)
+ )
+ else:
+ negative_masks_stacked = chunk_negative_masks
+ if negative_masks_stacked is None and chunk_negative_prompt_embeds is not None:
+ negative_masks_stacked = torch.ones(
+ batch_size * num_chunks_in_window,
+ chunk_negative_prompt_embeds.shape[1],
+ dtype=torch.int64,
+ device=chunk_negative_prompt_embeds.device,
+ )
+ encoder_attention_mask_neg = (
+ negative_masks_stacked.to(dtype=chunk_prompt_embeds.dtype).view(
+ batch_size * num_chunks_in_window, 1, 1, -1
+ )
+ if negative_masks_stacked is not None
+ else encoder_attention_mask
+ )
+
+ # Pad prefix video into latent_chunk if applicable (I2V)
+ # This ensures clean prefix frames are maintained during denoising
+ if prefix_video is not None:
+ prefix_length = prefix_video.shape[2]
+ prefix_video_start = chunk_start_idx * chunk_width
+
+ if prefix_length > prefix_video_start:
+ # Calculate how many frames to pad
+ padding_length = min(prefix_length - prefix_video_start, latent_chunk.shape[2])
+ prefix_video_end = prefix_video_start + padding_length
+
+ # Pad clean prefix frames into latent_chunk
+ latent_chunk = latent_chunk.clone()
+ latent_chunk[:, :, :padding_length] = prefix_video[
+ :, :, prefix_video_start:prefix_video_end
+ ]
+
+ # Set timesteps for clean prefix chunks to maximum (indicates "already clean")
+ # This matches original MAGI-1's try_pad_prefix_video logic
+ num_clean_chunks_in_window = padding_length // chunk_width
+ if num_clean_chunks_in_window > 0:
+ # Get max timestep from scheduler
+ max_timestep = timesteps[0]
+ timestep_per_chunk[:, :num_clean_chunks_in_window] = max_timestep
+
+ # Generate KV range for autoregressive attention
+ # Each chunk can attend to itself and all previous chunks in the sequence
+ # Shape: [batch_size * num_chunks_in_window, 2] where each row is [start_token_idx, end_token_idx]
+ chunk_token_nums = (
+ (latent_chunk.shape[2] // num_chunks_in_window) # frames per chunk
+ * (latent_chunk.shape[3] // self.transformer.config.patch_size[1]) # height tokens
+ * (latent_chunk.shape[4] // self.transformer.config.patch_size[2]) # width tokens
+ )
+ kv_range = []
+ for b in range(batch_size):
+ # batch_offset is based on the total number of chunks in the video, not chunk_end_idx
+ batch_offset = b * num_chunks
+ for c in range(num_chunks_in_window):
+ # This chunk can attend from the start of the video up to its own end
+ chunk_global_idx = chunk_start_idx + c
+ k_start = batch_offset * chunk_token_nums
+ k_end = (batch_offset + chunk_global_idx + 1) * chunk_token_nums
+ kv_range.append([k_start, k_end])
+ kv_range = torch.tensor(kv_range, dtype=torch.int32, device=device)
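+ # e.g. with batch_size=1, chunk_start_idx=2, num_chunks_in_window=2 and chunk_token_nums=N,
+ # kv_range is [[0, 3*N], [0, 4*N]]: each window chunk attends to every chunk up to and including itself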
+
+ # Predict noise (conditional)
+ # Note: MAGI-1 predicts a velocity field (flow matching), but we keep the noise_pred naming
+ # for consistency with other diffusers pipelines
+ noise_pred = self.transformer(
+ hidden_states=latent_chunk,
+ timestep=timestep_per_chunk,
+ encoder_hidden_states=chunk_prompt_embeds,
+ encoder_attention_mask=encoder_attention_mask,
+ attention_kwargs=attention_kwargs,
+ denoising_range_num=num_chunks_in_window,
+ range_num=chunk_end_idx,
+ slice_point=chunk_start_idx,
+ kv_range=kv_range,
+ num_steps=num_steps,
+ distill_interval=distill_interval,
+ distill_nearly_clean_chunk=distill_nearly_clean_chunk,
+ return_dict=False,
+ )[0]
+
+ # Classifier-free guidance: separate forward pass for unconditional
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_chunk,
+ timestep=timestep_per_chunk,
+ encoder_hidden_states=chunk_negative_prompt_embeds,
+ encoder_attention_mask=encoder_attention_mask_neg,
+ attention_kwargs=attention_kwargs,
+ denoising_range_num=num_chunks_in_window,
+ range_num=chunk_end_idx,
+ slice_point=chunk_start_idx,
+ kv_range=kv_range,
+ num_steps=num_steps,
+ distill_interval=distill_interval,
+ distill_nearly_clean_chunk=distill_nearly_clean_chunk,
+ return_dict=False,
+ )[0]
+ # Apply classifier-free guidance
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+ # CRITICAL: apply per-chunk Euler integration with a different delta_t for each chunk.
+ # This matches the integrate() function in the original MAGI-1 implementation:
+ # each chunk in the window sits at a different noise level and needs its own step size.
+
+ # Calculate per-chunk timesteps for integration
+ # Get current timesteps (t_before)
+ current_timesteps = timesteps[timestep_indices]
+
+ # Get next timesteps (t_after) - one step forward for each chunk
+ next_timestep_indices = [min(idx + 1, len(timesteps) - 1) for idx in timestep_indices]
+ next_timesteps = timesteps[next_timestep_indices]
+
+ # Convert timesteps to sigmas (matching FlowMatchEulerDiscreteScheduler)
+ current_sigmas = current_timesteps / self.scheduler.config.num_train_timesteps
+ next_sigmas = next_timesteps / self.scheduler.config.num_train_timesteps
+
+ # Calculate delta_t for each chunk: [num_chunks_in_window]
+ delta_t = next_sigmas - current_sigmas
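+ # delta_t is negative because the sigmas shrink as denoising progresses, so the Euler update
+ # x + velocity * delta_t below steps each chunk toward lower noise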
+
+ # Reshape latent_chunk and noise_pred to separate chunks
+ # From: [batch, channels, frames, height, width]
+ # To: [batch, channels, num_chunks, chunk_width, height, width]
+ batch_size_actual, num_channels, total_frames, height_latent, width_latent = latent_chunk.shape
+
+ # total_frames should already be divisible by chunk_width; handle any remainder frames
+ # below as an edge case
+ num_complete_chunks = total_frames // chunk_width
+ remainder_frames = total_frames % chunk_width
+
+ if remainder_frames == 0:
+ # Perfect division: reshape and apply per-chunk delta_t
+ latent_chunk = latent_chunk.reshape(
+ batch_size_actual,
+ num_channels,
+ num_complete_chunks,
+ chunk_width,
+ height_latent,
+ width_latent,
+ )
+ noise_pred = noise_pred.reshape(
+ batch_size_actual,
+ num_channels,
+ num_complete_chunks,
+ chunk_width,
+ height_latent,
+ width_latent,
+ )
+
+ # Apply Euler integration: x_chunk = x_chunk + velocity * delta_t
+ # delta_t shape: [num_chunks] -> broadcast to [1, 1, num_chunks, 1, 1, 1]
+ delta_t_broadcast = delta_t.reshape(1, 1, -1, 1, 1, 1).to(
+ latent_chunk.device, latent_chunk.dtype
+ )
+ latent_chunk = latent_chunk + noise_pred * delta_t_broadcast
+
+ # Reshape back to original dimensions
+ latent_chunk = latent_chunk.reshape(
+ batch_size_actual, num_channels, total_frames, height_latent, width_latent
+ )
+ else:
+ # Handle remainder frames separately (edge case for last incomplete chunk)
+ complete_frames = num_complete_chunks * chunk_width
+
+ # Process complete chunks
+ latent_chunk_complete = latent_chunk[:, :, :complete_frames]
+ noise_pred_complete = noise_pred[:, :, :complete_frames]
+
+ latent_chunk_complete = latent_chunk_complete.reshape(
+ batch_size_actual,
+ num_channels,
+ num_complete_chunks,
+ chunk_width,
+ height_latent,
+ width_latent,
+ )
+ noise_pred_complete = noise_pred_complete.reshape(
+ batch_size_actual,
+ num_channels,
+ num_complete_chunks,
+ chunk_width,
+ height_latent,
+ width_latent,
+ )
+
+ # Apply per-chunk delta_t to complete chunks
+ delta_t_broadcast = (
+ delta_t[:num_complete_chunks]
+ .reshape(1, 1, -1, 1, 1, 1)
+ .to(latent_chunk.device, latent_chunk.dtype)
+ )
+ latent_chunk_complete = latent_chunk_complete + noise_pred_complete * delta_t_broadcast
+ latent_chunk_complete = latent_chunk_complete.reshape(
+ batch_size_actual, num_channels, complete_frames, height_latent, width_latent
+ )
+
+ # Process remainder frames with last delta_t
+ if remainder_frames > 0:
+ latent_chunk_remainder = latent_chunk[:, :, complete_frames:]
+ noise_pred_remainder = noise_pred[:, :, complete_frames:]
+ delta_t_remainder = delta_t[-1].to(latent_chunk.device, latent_chunk.dtype)
+ latent_chunk_remainder = latent_chunk_remainder + noise_pred_remainder * delta_t_remainder
+
+ # Concatenate complete and remainder
+ latent_chunk = torch.cat([latent_chunk_complete, latent_chunk_remainder], dim=2)
+ else:
+ latent_chunk = latent_chunk_complete
+
+ # Write back to full latents
+ latents[:, :, latent_start:latent_end] = latent_chunk
+
+ # Update chunk denoise counts
+ for chunk_idx in range(chunk_start_idx, chunk_end_idx):
+ chunk_denoise_count[chunk_idx] += 1
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ # Use first timestep for callback (most representative)
+ callback_timestep = current_timesteps[0]
+ callback_outputs = callback_on_step_end(
+ self, stage_idx * denoise_step_per_stage + denoise_idx, callback_timestep, callback_kwargs
+ )
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return Magi1PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/magi1/pipeline_output.py b/src/diffusers/pipelines/magi1/pipeline_output.py
new file mode 100644
index 000000000000..200156cffac9
--- /dev/null
+++ b/src/diffusers/pipelines/magi1/pipeline_output.py
@@ -0,0 +1,36 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import torch
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class Magi1PipelineOutput(BaseOutput):
+ """
+ Output class for MAGI-1 pipeline.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or `List[List[np.ndarray]]`):
+ Denoised video frames from the diffusion process, as a PyTorch tensor of shape `(batch_size, num_channels,
+ num_frames, height, width)`, a NumPy array of shape `(batch_size, num_frames, height, width, num_channels)`,
+ or a nested list of per-frame NumPy arrays of length `batch_size`.
+ """
+
+ frames: Union[torch.Tensor, np.ndarray, List[List[np.ndarray]]]
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 63932221b207..e5406f040d99 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -86,6 +86,8 @@
is_kernels_available,
is_kornia_available,
is_librosa_available,
+ is_magi_attn_available,
+ is_magi_attn_version,
is_matplotlib_available,
is_nltk_available,
is_note_seq_available,
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 5d62709c28fd..07dddb0c1b2d 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -408,6 +408,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderKLMagi1(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderKLMagvit(metaclass=DummyObject):
_backends = ["torch"]
@@ -993,6 +1008,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class Magi1Transformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class MochiTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 3244ef12ef87..fd7371f1c9f8 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1742,6 +1742,51 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class Magi1ImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class Magi1Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class Magi1VideoToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class MarigoldDepthPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 97065267b004..7763380af8ab 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -226,6 +226,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_magi_attn_available, _magi_attn_version = _is_package_available("magi_attention")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
@@ -406,6 +407,10 @@ def is_flash_attn_3_available():
return _flash_attn_3_available
+def is_magi_attn_available():
+ return _magi_attn_available
+
+
def is_kornia_available():
return _kornia_available
@@ -911,6 +916,21 @@ def is_flash_attn_version(operation: str, version: str):
return compare_versions(parse(_flash_attn_version), operation, version)
+def is_magi_attn_version(operation: str, version: str):
+ """
+ Compares the current magi-attention version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _magi_attn_available:
+ return False
+ return compare_versions(parse(_magi_attn_version), operation, version)
+
+
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_magi1.py b/tests/models/autoencoders/test_models_autoencoder_kl_magi1.py
new file mode 100644
index 000000000000..ab8b45f31837
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_magi1.py
@@ -0,0 +1,155 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import AutoencoderKLMagi1
+from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLMagi1Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLMagi1
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_magi1_config(self):
+ return {
+ "base_dim": 3,
+ "z_dim": 16,
+ "dim_mult": [1, 1, 1, 1],
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (16, 16)
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+ return {"sample": image}
+
+ @property
+ def dummy_input_tiling(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (128, 128)
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_magi1_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def prepare_init_args_and_inputs_for_tiling(self):
+ init_dict = self.get_autoencoder_kl_magi1_config()
+ inputs_dict = self.dummy_input_tiling
+ return init_dict, inputs_dict
+
+ def test_enable_disable_tiling(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_tiling(96, 96, 64, 64)
+ output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE tiling should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_tiling()
+ output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ torch.allclose(output_without_tiling, output_without_tiling_2),
+ "Outputs without tiling should match the outputs after tiling is manually disabled.",
+ )
+
+ def test_enable_disable_slicing(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
+ 0.05,
+ "VAE slicing should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ torch.allclose(output_without_slicing, output_without_slicing_2),
+ "Outputs without slicing should match the outputs after slicing is manually disabled.",
+ )
+
+ @unittest.skip("Gradient checkpointing has not been implemented yet")
+ def test_gradient_checkpointing_is_applied(self):
+ pass
+
+ @unittest.skip("Test not supported")
+ def test_forward_with_norm_groups(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_inference(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_training(self):
+ pass
diff --git a/tests/models/transformers/test_models_transformer_magi1.py b/tests/models/transformers/test_models_transformer_magi1.py
new file mode 100644
index 000000000000..ed8d775e6058
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_magi1.py
@@ -0,0 +1,91 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import Magi1Transformer3DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class Magi1Transformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = Magi1Transformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 4,
+ "out_channels": 4,
+ "cross_attention_dim": 16,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"Magi1Transformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class Magi1TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = Magi1Transformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return Magi1Transformer3DTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/pipelines/magi1/test_magi1.py b/tests/pipelines/magi1/test_magi1.py
new file mode 100644
index 000000000000..3695bcbe5be7
--- /dev/null
+++ b/tests/pipelines/magi1/test_magi1.py
@@ -0,0 +1,158 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler, Magi1Pipeline, Magi1Transformer3DModel
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class Magi1PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Magi1Pipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLMagi1(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = Magi1Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+
+@slow
+@require_torch_accelerator
+class Magi1PipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_Magi1(self):
+ pass
diff --git a/tests/pipelines/magi1/test_magi1_image_to_video.py b/tests/pipelines/magi1/test_magi1_image_to_video.py
new file mode 100644
index 000000000000..50ca59f5bbf5
--- /dev/null
+++ b/tests/pipelines/magi1/test_magi1_image_to_video.py
@@ -0,0 +1,146 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import PIL
+import torch
+from transformers import AutoTokenizer, CLIPVisionModel, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLMagi1,
+ FlowMatchEulerDiscreteScheduler,
+ Magi1ImageToVideoPipeline,
+ Magi1Transformer3DModel,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+)
+
+from ..pipeline_params import (
+ TEXT_TO_IMAGE_BATCH_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class Magi1ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Magi1ImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLMagi1(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+ image_encoder = CLIPVisionModel.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ transformer = Magi1Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "image": PIL.Image.new("RGB", (16, 16)),
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
diff --git a/tests/pipelines/magi1/test_magi1_video_to_video.py b/tests/pipelines/magi1/test_magi1_video_to_video.py
new file mode 100644
index 000000000000..97efc4da0539
--- /dev/null
+++ b/tests/pipelines/magi1/test_magi1_video_to_video.py
@@ -0,0 +1,153 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import PIL
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLMagi1,
+ FlowMatchEulerDiscreteScheduler,
+ Magi1Transformer3DModel,
+ Magi1VideoToVideoPipeline,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class Magi1VideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Magi1VideoToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLMagi1(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = Magi1Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ # Create a list of PIL images to simulate video input
+ video_frames = [PIL.Image.new("RGB", (16, 16)) for _ in range(9)]
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "video": video_frames,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip(
+ "Magi1VideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors"
+ )
+ def test_model_cpu_offload_forward_pass(self):
+ pass
+
+ @unittest.skip(
+ "Magi1VideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
+ )
+ def test_save_load_float16(self):
+ pass
diff --git a/tests/single_file/test_model_magi_autoencoder_single_file.py b/tests/single_file/test_model_magi_autoencoder_single_file.py
new file mode 100644
index 000000000000..8721d884fa25
--- /dev/null
+++ b/tests/single_file/test_model_magi_autoencoder_single_file.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import AutoencoderKLMagi1
+from diffusers.utils.testing_utils import (
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+
+class AutoencoderKLMagiSingleFileTests(unittest.TestCase):
+ model_class = AutoencoderKLMagi1
+ ckpt_path = "https://huggingface.co/sand-ai/MAGI-1/blob/main/vae/diffusion_pytorch_model.safetensors"
+ repo_id = "sand-ai/MAGI-1"
+
+ @slow
+ @require_torch_gpu
+ def test_single_file_components(self):
+ model = self.model_class.from_single_file(self.ckpt_path)
+ model.to(torch_device)
+
+ batch_size = 1
+ num_frames = 2
+ num_channels = 3
+ sizes = (16, 16)
+ image = torch.randn((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ with torch.no_grad():
+ model(image, return_dict=False)
+
+ @slow
+ @require_torch_gpu
+ def test_single_file_components_from_hub(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
+ model.to(torch_device)
+
+ batch_size = 1
+ num_frames = 2
+ num_channels = 3
+ sizes = (16, 16)
+ image = torch.randn((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ with torch.no_grad():
+ model(image, return_dict=False)
diff --git a/tests/single_file/test_model_magi_transformer3d_single_file.py b/tests/single_file/test_model_magi_transformer3d_single_file.py
new file mode 100644
index 000000000000..fb6b0ae04622
--- /dev/null
+++ b/tests/single_file/test_model_magi_transformer3d_single_file.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import Magi1Transformer3DModel
+from diffusers.utils.testing_utils import (
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+
+class Magi1Transformer3DModelText2VideoSingleFileTest(unittest.TestCase):
+ model_class = Magi1Transformer3DModel
+ ckpt_path = "https://huggingface.co/sand-ai/MAGI-1/blob/main/transformer/diffusion_pytorch_model.safetensors"
+ repo_id = "sand-ai/MAGI-1"
+
+ @slow
+ @require_torch_gpu
+ def test_single_file_components(self):
+ model = self.model_class.from_single_file(self.ckpt_path)
+ model.to(torch_device)
+
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ with torch.no_grad():
+ model(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ return_dict=False,
+ )
+
+ @slow
+ @require_torch_gpu
+ def test_single_file_components_from_hub(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+ model.to(torch_device)
+
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ with torch.no_grad():
+ model(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ return_dict=False,
+ )