Changes from 7 commits
70 commits
ec5449f
Support both huggingface_hub `v0.x` and `v1.x` (#12389)
Wauplin Sep 25, 2025
016316a
mirage pipeline first commit
Sep 26, 2025
4ac274b
use attention processors
Sep 26, 2025
904debc
use diffusers rmsnorm
Sep 26, 2025
122115a
use diffusers timestep embedding method
Sep 26, 2025
4588bbe
[CI] disable installing transformers from main in ci for now. (#12397)
sayakpaul Sep 26, 2025
e3fe0e8
remove MirageParams
Sep 26, 2025
85ae87b
checkpoint conversion script
Sep 26, 2025
9a697d0
ruff formating
Sep 26, 2025
9c09445
[docs] slight edits to the attention backends docs. (#12394)
sayakpaul Sep 26, 2025
041501a
[docs] remove docstrings from repeated methods in `lora_pipeline.py` …
sayakpaul Sep 26, 2025
19085ac
Don't skip Qwen model tests for group offloading with disk (#12382)
sayakpaul Sep 29, 2025
0a15111
Fix #12116: preserve boolean dtype for attention masks in ChromaPipe…
akshay-babbar Sep 29, 2025
64a5187
[quantization] feat: support aobaseconfig classes in `TorchAOConfig` …
sayakpaul Sep 29, 2025
ccedeca
[docs] Distributed inference (#12285)
stevhliu Sep 29, 2025
c07fcf7
[docs] Model formats (#12256)
stevhliu Sep 29, 2025
76d4e41
[modular]some small fix (#12307)
yiyixuxu Sep 29, 2025
20fd00b
[Tests] Add single file tester mixin for Models and remove unittest d…
DN6 Sep 30, 2025
0e12ba7
fix 3 xpu failures uts w/ latest pytorch (#12408)
yao-matrix Sep 30, 2025
b596545
Install latest prerelease from huggingface_hub when installing transf…
Wauplin Sep 30, 2025
d7a1a03
[docs] CP (#12331)
stevhliu Sep 30, 2025
cc5b31f
[docs] Migrate syntax (#12390)
stevhliu Sep 30, 2025
34fa9dd
remove dependencies to old checkpoints
Sep 30, 2025
5cc965a
remove old checkpoints dependency
Sep 30, 2025
d79cd8f
move default height and width in checkpoint config
Sep 30, 2025
f2759fd
add docstrings
Sep 30, 2025
394f725
if conditions and raised as ValueError instead of asserts
Sep 30, 2025
54fb063
small fix
Sep 30, 2025
c49fafb
nit remove try block at import
Sep 30, 2025
7e7df35
mirage pipeline doc
Sep 30, 2025
814d710
[tests] cache non lora pipeline outputs. (#12298)
sayakpaul Oct 1, 2025
9ae5b62
[ci] xfail failing tests in CI. (#12418)
sayakpaul Oct 2, 2025
b429796
[core] conditionally import torch distributed stuff. (#12420)
sayakpaul Oct 2, 2025
7242b5f
FIX Test to ignore warning for enable_lora_hotswap (#12421)
BenjaminBossan Oct 2, 2025
941ac9c
[training-scripts] Make more examples UV-compatible (follow up on #12…
linoytsaban Oct 3, 2025
2b7deff
fix scale_shift_factor being on cpu for wan and ltx (#12347)
vladmandic Oct 5, 2025
c3675d4
[core] support QwenImage Edit Plus in modular (#12416)
sayakpaul Oct 5, 2025
ce90f9b
[FIX] Text to image training peft version (#12434)
SahilCarterr Oct 6, 2025
7f3e9b8
make flux ready for mellon (#12419)
sayakpaul Oct 6, 2025
cf4b97b
[perf] Cache version checks (#12399)
cbensimon Oct 6, 2025
0974b4c
[i18n-KO] Fix typo and update translation in ethical_guidelines.md (#…
braintrue Oct 6, 2025
2d69bac
handle offload_state_dict when initing transformers models (#12438)
sayakpaul Oct 7, 2025
de03851
update doc
Oct 7, 2025
a69aa4b
rename model to photon
Oct 7, 2025
1066de8
[Qwen LoRA training] fix bug when offloading (#12440)
linoytsaban Oct 7, 2025
2dc3167
Align Flux modular more and more with Qwen modular (#12445)
sayakpaul Oct 8, 2025
35e538d
fix dockerfile definitions. (#12424)
sayakpaul Oct 8, 2025
345864e
fix more torch.distributed imports (#12425)
sayakpaul Oct 8, 2025
9e099a7
mirage pipeline first commit
Sep 26, 2025
6e10ed4
use attention processors
Sep 26, 2025
866c6de
use diffusers rmsnorm
Sep 26, 2025
4e8b647
use diffusers timestep embedding method
Sep 26, 2025
472ad97
remove MirageParams
Sep 26, 2025
97a231e
checkpoint conversion script
Sep 26, 2025
35d721f
ruff formating
Sep 26, 2025
775a115
remove dependencies to old checkpoints
Sep 30, 2025
1c6c25c
remove old checkpoints dependency
Sep 30, 2025
b0d965c
move default height and width in checkpoint config
Sep 30, 2025
235fe49
add docstrings
Sep 30, 2025
a6ff579
if conditions and raised as ValueError instead of asserts
Sep 30, 2025
3a91503
small fix
Sep 30, 2025
e200cf6
nit remove try block at import
Sep 30, 2025
2ea8976
mirage pipeline doc
Sep 30, 2025
26429a3
update doc
Oct 7, 2025
0abe136
rename model to photon
Oct 7, 2025
fe0e3d5
add text tower and vae in checkpoint
Oct 8, 2025
855b068
update doc
Oct 8, 2025
d2c6bdd
Merge branch 'mirage' of https://github.com/Photoroom/diffusers into …
Oct 8, 2025
89beae8
update photon doc
Oct 8, 2025
2df0e2f
ruff fixes
Oct 8, 2025
297 changes: 297 additions & 0 deletions scripts/convert_mirage_to_diffusers.py
@@ -0,0 +1,297 @@
#!/usr/bin/env python3
"""
Script to convert Mirage checkpoint from original codebase to diffusers format.
"""

import argparse
import json
import os
import shutil
import sys

import torch
from safetensors.torch import save_file


sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))

from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel
from diffusers.pipelines.mirage import MiragePipeline


def load_reference_config(vae_type: str) -> dict:
"""Load transformer config from existing pipeline checkpoint."""

if vae_type == "flux":
config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated/transformer/config.json"
elif vae_type == "dc-ae":
config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated/transformer/config.json"
Collaborator: Should we change these hardcoded paths?

Collaborator (Author): Absolutely!
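One way to address this concern would be to take the reference config location as a command-line option rather than a hardcoded path; a minimal, hypothetical sketch (note that later commits in this PR remove the dependency on the reference checkpoints altogether):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--reference_config_dir",  # hypothetical flag, not part of this diff
    type=str,
    required=True,
    help="Directory containing the reference transformer config.json for the chosen VAE type",
)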

else:
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")

if not os.path.exists(config_path):
raise FileNotFoundError(f"Reference config not found: {config_path}")

with open(config_path, "r") as f:
config = json.load(f)

print(f"✓ Loaded {vae_type} config: in_channels={config['in_channels']}")
return config


def create_parameter_mapping() -> dict:
"""Create mapping from old parameter names to new diffusers names."""

# Key mappings for structural changes
mapping = {}

# RMSNorm: scale -> weight
for i in range(16): # 16 layers
Collaborator: Use a MIRAGE_NUM_LAYERS: int = 16 constant at the top?

Collaborator (Author): I changed it to come from a config instead.
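A minimal sketch of that config-driven approach, assuming the mapping helper is handed the transformer config (names are hypothetical):

def create_parameter_mapping(config: dict) -> dict:
    mapping = {}
    num_layers = config.get("num_layers", 16)  # read the layer count from the config instead of hardcoding 16
    for i in range(num_layers):
        mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight"
        # ... remaining per-layer entries as below
    return mapping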

mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.qk_norm.key_norm.weight"
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.k_norm.weight"

# Attention: attn_out -> attention.to_out.0
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"

return mapping


def convert_checkpoint_parameters(old_state_dict: dict) -> dict:
Collaborator: Probably good to use more specific types, like Dict[str, str]. Import Dict from typing, because the diffusers library still supports Python 3.8, and built-in generics such as dict[str, str] are only supported from Python 3.9 onward.
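A short sketch of the 3.8-compatible annotations the reviewer suggests (illustrative only):

from typing import Dict

import torch

def create_parameter_mapping() -> Dict[str, str]: ...

def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: ...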

"""Convert old checkpoint parameters to new diffusers format."""

print("Converting checkpoint parameters...")

mapping = create_parameter_mapping()
converted_state_dict = {}

# First, print available keys to understand structure
print("Available keys in checkpoint:")
for key in sorted(old_state_dict.keys())[:10]: # Show first 10 keys
print(f" {key}")
if len(old_state_dict) > 10:
print(f" ... and {len(old_state_dict) - 10} more")

for key, value in old_state_dict.items():
new_key = key

# Apply specific mappings if needed
if key in mapping:
new_key = mapping[key]
print(f" Mapped: {key} -> {new_key}")

# Handle img_qkv_proj -> split to to_q, to_k, to_v
if "img_qkv_proj.weight" in key:
print(f" Found QKV projection: {key}")
# Split QKV weight into separate Q, K, V projections
qkv_weight = value
q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)

# Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0)
parts = key.split(".")
layer_idx = None
for i, part in enumerate(parts):
if part == "blocks" and i + 1 < len(parts) and parts[i + 1].isdigit():
layer_idx = parts[i + 1]
break

if layer_idx is not None:
converted_state_dict[f"blocks.{layer_idx}.attention.to_q.weight"] = q_weight
converted_state_dict[f"blocks.{layer_idx}.attention.to_k.weight"] = k_weight
converted_state_dict[f"blocks.{layer_idx}.attention.to_v.weight"] = v_weight
print(f" Split QKV for layer {layer_idx}")

# Also keep the original img_qkv_proj for backward compatibility
converted_state_dict[new_key] = value
else:
converted_state_dict[new_key] = value

print(f"✓ Converted {len(converted_state_dict)} parameters")
return converted_state_dict


def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> MirageTransformer2DModel:
"""Create and load MirageTransformer2DModel from old checkpoint."""

print(f"Loading checkpoint from: {checkpoint_path}")

# Load old checkpoint
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

old_checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Handle different checkpoint formats
if isinstance(old_checkpoint, dict):
if "model" in old_checkpoint:
state_dict = old_checkpoint["model"]
elif "state_dict" in old_checkpoint:
state_dict = old_checkpoint["state_dict"]
else:
state_dict = old_checkpoint
else:
state_dict = old_checkpoint

print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")

# Convert parameter names if needed
converted_state_dict = convert_checkpoint_parameters(state_dict)

# Create transformer with config
print("Creating MirageTransformer2DModel...")
transformer = MirageTransformer2DModel(**config)

# Load state dict
print("Loading converted parameters...")
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)

if missing_keys:
print(f"⚠ Missing keys: {missing_keys}")
if unexpected_keys:
print(f"⚠ Unexpected keys: {unexpected_keys}")

if not missing_keys and not unexpected_keys:
print("✓ All parameters loaded successfully!")

return transformer


def copy_pipeline_components(vae_type: str, output_path: str):
"""Copy VAE, scheduler, text encoder, and tokenizer from reference pipeline."""

if vae_type == "flux":
ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated"
else: # dc-ae
ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated"
Collaborator: Should we change these hardcoded paths?

Collaborator (Author): I removed all dependencies on the previous reference pipeline; this was a mistake.


components = ["vae", "scheduler", "text_encoder", "tokenizer"]

for component in components:
src_path = os.path.join(ref_pipeline, component)
dst_path = os.path.join(output_path, component)

if os.path.exists(src_path):
if os.path.isdir(src_path):
shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
else:
shutil.copy2(src_path, dst_path)
print(f"✓ Copied {component}")
else:
print(f"⚠ Component not found: {src_path}")


def create_model_index(vae_type: str, output_path: str):
"""Create model_index.json for the pipeline."""

if vae_type == "flux":
vae_class = "AutoencoderKL"
else: # dc-ae
vae_class = "AutoencoderDC"

model_index = {
"_class_name": "MiragePipeline",
"_diffusers_version": "0.31.0.dev0",
"_name_or_path": os.path.basename(output_path),
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
"text_encoder": ["transformers", "T5GemmaEncoder"],
"tokenizer": ["transformers", "GemmaTokenizerFast"],
"transformer": ["diffusers", "MirageTransformer2DModel"],
"vae": ["diffusers", vae_class],
}

model_index_path = os.path.join(output_path, "model_index.json")
with open(model_index_path, "w") as f:
json.dump(model_index, f, indent=2)

print("✓ Created model_index.json")


def main(args):
# Validate inputs
if not os.path.exists(args.checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")

# Load reference config based on VAE type
config = load_reference_config(args.vae_type)

# Create output directory
os.makedirs(args.output_path, exist_ok=True)
print(f"✓ Output directory: {args.output_path}")

# Create transformer from checkpoint
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)

# Save transformer
transformer_path = os.path.join(args.output_path, "transformer")
os.makedirs(transformer_path, exist_ok=True)

# Save config
with open(os.path.join(transformer_path, "config.json"), "w") as f:
json.dump(config, f, indent=2)

# Save model weights as safetensors
state_dict = transformer.state_dict()
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
print(f"✓ Saved transformer to {transformer_path}")

# Copy other pipeline components
copy_pipeline_components(args.vae_type, args.output_path)

# Create model index
create_model_index(args.vae_type, args.output_path)

# Verify the pipeline can be loaded
try:
pipeline = MiragePipeline.from_pretrained(args.output_path)
print("Pipeline loaded successfully!")
print(f"Transformer: {type(pipeline.transformer).__name__}")
print(f"VAE: {type(pipeline.vae).__name__}")
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
print(f"Scheduler: {type(pipeline.scheduler).__name__}")

# Display model info
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
print(f"✓ Transformer parameters: {num_params:,}")

except Exception as e:
print(f"Pipeline verification failed: {e}")
return False

print("Conversion completed successfully!")
print(f"Converted pipeline saved to: {args.output_path}")
print(f"VAE type: {args.vae_type}")

return True


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mirage checkpoint to diffusers format")

parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original Mirage checkpoint (.pth file)"
)

parser.add_argument(
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
)

parser.add_argument(
"--vae_type",
type=str,
choices=["flux", "dc-ae"],
required=True,
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
)

args = parser.parse_args()

try:
success = main(args)
if not success:
sys.exit(1)
except Exception as e:
print(f"Conversion failed: {e}")
import traceback

traceback.print_exc()
sys.exit(1)
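For reference, a hedged sketch of driving the converter programmatically through the argparse flags defined above (all paths are placeholders):

import argparse

args = argparse.Namespace(
    checkpoint_path="mirage_checkpoint.pth",  # placeholder: original .pth checkpoint
    output_path="mirage-diffusers-pipeline",  # placeholder: output pipeline directory
    vae_type="flux",                          # "flux" (AutoencoderKL) or "dc-ae" (AutoencoderDC)
)
# main(args) would then save the converted transformer, copy the remaining pipeline
# components, write model_index.json, and reload the pipeline as a sanity check.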
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -224,6 +224,7 @@
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
"MirageTransformer2DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -93,6 +93,7 @@
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mirage"] = ["MirageTransformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
58 changes: 58 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -5609,6 +5609,63 @@ def __new__(cls, *args, **kwargs):
return processor


class MirageAttnProcessor2_0:
r"""
Processor for implementing Mirage-style attention with multi-source tokens and RoPE.
Properly integrates with diffusers Attention module while handling Mirage-specific logic.
"""

def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("MirageAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: "Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Apply Mirage attention using standard diffusers interface.

Expected tensor formats from MirageBlock.attn_forward():
- hidden_states: Image queries with RoPE applied [B, H, L_img, D]
- encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D]
(concatenated keys and values from text + image + spatial conditioning)
- attention_mask: Custom attention mask [B, H, L_img, L_all] or None
"""

if encoder_hidden_states is None:
raise ValueError(
"MirageAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. "
"This should be provided by MirageBlock.attn_forward()."
)

# Unpack the combined key+value tensor
# encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values]
key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D]

# Apply scaled dot-product attention with Mirage's processed tensors
# hidden_states is image queries [B, H, L_img, D]
attn_output = torch.nn.functional.scaled_dot_product_attention(
hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask
)

# Reshape from [B, H, L_img, D] to [B, L_img, H*D]
batch_size, num_heads, seq_len, head_dim = attn_output.shape
attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim)

# Apply output projection using the diffusers Attention module
attn_output = attn.to_out[0](attn_output)
if len(attn.to_out) > 1:
attn_output = attn.to_out[1](attn_output) # dropout if present

return attn_output
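To make the expected inputs concrete, here is a minimal, hypothetical sketch of how a caller such as MirageBlock.attn_forward() might pack keys and values into the single tensor this processor unpacks with .chunk(2, dim=-1) (shapes are placeholders):

import torch

B, H, L_img, L_all, D = 1, 8, 256, 512, 64   # hypothetical sizes
query = torch.randn(B, H, L_img, D)          # image queries, RoPE already applied
key = torch.randn(B, H, L_all, D)            # keys for text + image + spatial conditioning
value = torch.randn(B, H, L_all, D)          # matching values
packed_kv = torch.cat([key, value], dim=-1)  # [B, H, L_all, 2*D]
# processor(attn, hidden_states=query, encoder_hidden_states=packed_kv)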


ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
@@ -5657,6 +5714,7 @@ def __new__(cls, *args, **kwargs):
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
MirageAttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
@@ -29,6 +29,7 @@
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mirage import MirageTransformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel