From 016316a57a8a0df9cacca66af2fa290862b8d5c4 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 10:20:19 +0200 Subject: [PATCH 1/7] mirage pipeline first commit --- src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_mirage.py | 489 ++++++++++++++ src/diffusers/pipelines/__init__.py | 1 + src/diffusers/pipelines/mirage/__init__.py | 4 + .../pipelines/mirage/pipeline_mirage.py | 629 ++++++++++++++++++ .../pipelines/mirage/pipeline_output.py | 35 + .../test_models_transformer_mirage.py | 252 +++++++ 9 files changed, 1413 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_mirage.py create mode 100644 src/diffusers/pipelines/mirage/__init__.py create mode 100644 src/diffusers/pipelines/mirage/pipeline_mirage.py create mode 100644 src/diffusers/pipelines/mirage/pipeline_output.py create mode 100644 tests/models/transformers/test_models_transformer_mirage.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8867250deda8..6fc6ac5f3ebd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -224,6 +224,7 @@ "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", + "MirageTransformer2DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 457f70448af3..279e69216b1b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -93,6 +93,7 @@ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] + _import_structure["transformers.transformer_mirage"] = ["MirageTransformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..ebe0d0c9b8e1 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -29,6 +29,7 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel + from .transformer_mirage import MirageTransformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py new file mode 100644 index 000000000000..39c569cbb26b --- /dev/null +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -0,0 +1,489 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union, Tuple +import torch +import math +from torch import Tensor, nn +from torch.nn.functional import fold, unfold +from einops import rearrange +from einops.layers.torch import Rearrange + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..modeling_outputs import Transformer2DModelOutput +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers + + +logger = logging.get_logger(__name__) + + +# Mirage Layer Components +def get_image_ids(bs: int, h: int, w: int, patch_size: int, device: torch.device) -> Tensor: + img_ids = torch.zeros(h // patch_size, w // patch_size, 2, device=device) + img_ids[..., 0] = torch.arange(h // patch_size, device=device)[:, None] + img_ids[..., 1] = torch.arange(w // patch_size, device=device)[None, :] + return img_ids.reshape((h // patch_size) * (w // patch_size), 2).unsqueeze(0).repeat(bs, 1, 1) + + +def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq) + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2) + + def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = pos.unsqueeze(-1) * omega.unsqueeze(0) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = self.rope_rearrange(out) + return out.float() + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor: + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: 
+ x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms * self.scale).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.lin = nn.Linear(dim, 6 * dim, bias=True) + nn.init.constant_(self.lin.weight, 0) + nn.init.constant_(self.lin.bias, 0) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1) + return ModulationOut(*out[:3]), ModulationOut(*out[3:]) + + +class MirageBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + + self._fsdp_wrap = True + self._activation_checkpointing = True + + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.hidden_size = hidden_size + + # img qkv + self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False) + self.attn_out = nn.Linear(hidden_size, hidden_size, bias=False) + self.qk_norm = QKNorm(self.head_dim) + + # txt kv + self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False) + self.k_norm = RMSNorm(self.head_dim) + + + # mlp + self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) + self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) + self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False) + self.mlp_act = nn.GELU(approximate="tanh") + + self.modulation = Modulation(hidden_size) + self.spatial_cond_kv_proj: None | nn.Linear = None + + def attn_forward( + self, + img: Tensor, + txt: Tensor, + pe: Tensor, + modulation: ModulationOut, + spatial_conditioning: None | Tensor = None, + attention_mask: None | Tensor = None, + ) -> Tensor: + # image tokens proj and norm + img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift + + img_qkv = self.img_qkv_proj(img_mod) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.qk_norm(img_q, img_k, img_v) + + # txt tokens proj and norm + txt_kv = self.txt_kv_proj(txt) + txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads) + txt_k = self.k_norm(txt_k) + + # compute attention + img_q, img_k = apply_rope(img_q, pe), apply_rope(img_k, pe) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + # optional spatial conditioning tokens + cond_len = 0 + if self.spatial_cond_kv_proj is not None: + assert spatial_conditioning is not None + cond_kv = self.spatial_cond_kv_proj(spatial_conditioning) + cond_k, cond_v = rearrange(cond_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads) + cond_k = apply_rope(cond_k, pe) + cond_len = cond_k.shape[2] + k = 
torch.cat((cond_k, k), dim=2) + v = torch.cat((cond_v, v), dim=2) + + # build additive attention bias + attn_bias: Tensor | None = None + attn_mask: Tensor | None = None + + # build multiplicative 0/1 mask for provided attention_mask over [cond?, text, image] keys + if attention_mask is not None: + bs, _, l_img, _ = img_q.shape + l_txt = txt_k.shape[2] + l_all = k.shape[2] + + assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}" + assert ( + attention_mask.shape[-1] == l_txt + ), f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" + + device = img_q.device + + ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device) + cond_mask = torch.ones((bs, cond_len), dtype=torch.bool, device=device) + + mask_parts = [ + cond_mask, + attention_mask.to(torch.bool), + ones_img, + ] + joint_mask = torch.cat(mask_parts, dim=-1) # (B, L_all) + + # repeat across heads and query positions + attn_mask = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1) # (B,H,L_img,L_all) + + attn = torch.nn.functional.scaled_dot_product_attention( + img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask + ) + attn = rearrange(attn, "B H L D -> B L (H D)") + attn = self.attn_out(attn) + + return attn + + def ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor: + x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift + return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)) + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + spatial_conditioning: Tensor | None = None, + attention_mask: Tensor | None = None, + **_: dict[str, Any], + ) -> Tensor: + mod_attn, mod_mlp = self.modulation(vec) + + img = img + mod_attn.gate * self.attn_forward( + img, + txt, + pe, + mod_attn, + spatial_conditioning=spatial_conditioning, + attention_mask=attention_mask, + ) + img = img + mod_mlp.gate * self.ffn_forward(img, mod_mlp) + return img + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + nn.init.constant_(self.adaLN_modulation[1].weight, 0) + nn.init.constant_(self.adaLN_modulation[1].bias, 0) + nn.init.constant_(self.linear.weight, 0) + nn.init.constant_(self.linear.bias, 0) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +@dataclass +class MirageParams: + in_channels: int + patch_size: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + axes_dim: list[int] + theta: int + time_factor: float = 1000.0 + time_max_period: int = 10_000 + conditioning_block_ids: list[int] | None = None + + +def img2seq(img: Tensor, patch_size: int) -> Tensor: + """Flatten an image into a sequence of patches""" + return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) + + +def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: + """Revert img2seq""" + if isinstance(shape, tuple): + shape = shape[-2:] + elif isinstance(shape, torch.Tensor): + shape = 
(int(shape[0]), int(shape[1])) + else: + raise NotImplementedError(f"shape type {type(shape)} not supported") + return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) + + +class MirageTransformer2DModel(ModelMixin, ConfigMixin): + """Mirage Transformer model with IP-Adapter support.""" + + config_name = "config.json" + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 16, + patch_size: int = 2, + context_in_dim: int = 2304, + hidden_size: int = 1792, + mlp_ratio: float = 3.5, + num_heads: int = 28, + depth: int = 16, + axes_dim: list = None, + theta: int = 10000, + time_factor: float = 1000.0, + time_max_period: int = 10000, + conditioning_block_ids: list = None, + **kwargs + ): + super().__init__() + + if axes_dim is None: + axes_dim = [32, 32] + + # Create MirageParams from the provided arguments + params = MirageParams( + in_channels=in_channels, + patch_size=patch_size, + context_in_dim=context_in_dim, + hidden_size=hidden_size, + mlp_ratio=mlp_ratio, + num_heads=num_heads, + depth=depth, + axes_dim=axes_dim, + theta=theta, + time_factor=time_factor, + time_max_period=time_max_period, + conditioning_block_ids=conditioning_block_ids, + ) + + self.params = params + self.in_channels = params.in_channels + self.patch_size = params.patch_size + self.out_channels = self.in_channels * self.patch_size**2 + + self.time_factor = params.time_factor + self.time_max_period = params.time_max_period + + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + + pe_dim = params.hidden_size // params.num_heads + + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + conditioning_block_ids: list[int] = params.conditioning_block_ids or list(range(params.depth)) + + self.blocks = nn.ModuleList( + [ + MirageBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + ) + for i in range(params.depth) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: + """Timestep independent stuff""" + txt = self.txt_in(txt) + img = img2seq(image_latent, self.patch_size) + bs, _, h, w = image_latent.shape + img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device) + pe = self.pe_embedder(img_ids) + return img, txt, pe + + def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: + return self.time_in( + timestep_embedding( + t=timestep, dim=256, max_period=self.time_max_period, time_factor=self.time_factor + ).to(dtype) + ) + + def forward_transformers( + self, + image_latent: Tensor, + cross_attn_conditioning: Tensor, + timestep: Optional[Tensor] = None, + time_embedding: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + **block_kwargs: Any, + ) -> Tensor: + img = self.img_in(image_latent) + + if time_embedding is not None: + vec = time_embedding + else: + 
if timestep is None: + raise ValueError("Please provide either a timestep or a timestep_embedding") + vec = self.compute_timestep_embedding(timestep, dtype=img.dtype) + + for block in self.blocks: + img = block( + img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs + ) + + img = self.final_layer(img, vec) + return img + + def forward( + self, + image_latent: Tensor, + timestep: Tensor, + cross_attn_conditioning: Tensor, + micro_conditioning: Tensor, + cross_attn_mask: None | Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + img_seq, txt, pe = self.process_inputs(image_latent, cross_attn_conditioning) + img_seq = self.forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask) + output = seq2img(img_seq, self.patch_size, image_latent.shape) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 190c7871d270..7b7ebb633c3b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -144,6 +144,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] + _import_structure["mirage"] = ["MiragePipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", diff --git a/src/diffusers/pipelines/mirage/__init__.py b/src/diffusers/pipelines/mirage/__init__.py new file mode 100644 index 000000000000..4fd8ad191b3f --- /dev/null +++ b/src/diffusers/pipelines/mirage/__init__.py @@ -0,0 +1,4 @@ +from .pipeline_mirage import MiragePipeline +from .pipeline_output import MiragePipelineOutput + +__all__ = ["MiragePipeline", "MiragePipelineOutput"] \ No newline at end of file diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py new file mode 100644 index 000000000000..126eab07977c --- /dev/null +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -0,0 +1,629 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
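+
+# Note: the classifier-free guidance used by the denoising loop in `MiragePipeline.__call__`
+# runs the transformer once on a doubled batch of (unconditional, conditional) inputs and
+# blends the two noise predictions. A minimal sketch of that blend, assuming a stacked
+# prediction of shape (2 * batch, channels, height, width) with the unconditional half first
+# (the helper name and shapes are illustrative, not part of the public API):
+#
+#     import torch
+#
+#     def blend_cfg(noise_both: torch.Tensor, guidance_scale: float) -> torch.Tensor:
+#         noise_uncond, noise_text = noise_both.chunk(2, dim=0)
+#         return noise_uncond + guidance_scale * (noise_text - noise_uncond)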
+ +import inspect +import os +from typing import Any, Callable, Dict, List, Optional, Union + +import html +import re +import urllib.parse as ul + +import ftfy +import torch +from transformers import ( + AutoTokenizer, + GemmaTokenizerFast, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, AutoencoderDC +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MiragePipelineOutput + +try: + from ...models.transformers.transformer_mirage import MirageTransformer2DModel +except ImportError: + MirageTransformer2DModel = None + +logger = logging.get_logger(__name__) + + +class TextPreprocessor: + """Text preprocessing utility for MiragePipeline.""" + + def __init__(self): + """Initialize text preprocessor.""" + self.bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + r"\\" + r"\/" + r"\*" + r"]{1,}" + ) + + def clean_text(self, text: str) -> str: + """Clean text using comprehensive text processing logic.""" + # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py + text = str(text) + text = ul.unquote_plus(text) + text = text.strip().lower() + text = re.sub("", "person", text) + + # Remove all urls: + text = re.sub( + r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))", + "", + text, + ) # regex for urls + + # @ + text = re.sub(r"@[\w\d]+\b", "", text) + + # 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs + text = re.sub(r"[\u31c0-\u31ef]+", "", text) + text = re.sub(r"[\u31f0-\u31ff]+", "", text) + text = re.sub(r"[\u3200-\u32ff]+", "", text) + text = re.sub(r"[\u3300-\u33ff]+", "", text) + text = re.sub(r"[\u3400-\u4dbf]+", "", text) + text = re.sub(r"[\u4dc0-\u4dff]+", "", text) + text = re.sub(r"[\u4e00-\u9fff]+", "", text) + + # все виды тире / all types of dash --> "-" + text = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", + "-", + text, + ) + + # кавычки к одному стандарту + text = re.sub(r"[`´«»""¨]", '"', text) + text = re.sub(r"['']", "'", text) + + # " and & + text = re.sub(r""?", "", text) + text = re.sub(r"&", "", text) + + # ip addresses: + text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text) + + # article ids: + text = re.sub(r"\d:\d\d\s+$", "", text) + + # \n + text = re.sub(r"\\n", " ", text) + + # "#123", "#12345..", "123456.." + text = re.sub(r"#\d{1,3}\b", "", text) + text = re.sub(r"#\d{5,}\b", "", text) + text = re.sub(r"\b\d{6,}\b", "", text) + + # filenames: + text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text) + + # Clean punctuation + text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT""" + text = re.sub(r"[\.]{2,}", r" ", text) + + text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT + text = re.sub(r"\s+\.\s+", r" ", text) # " . 
" + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, text)) > 3: + text = re.sub(regex2, " ", text) + + # Basic cleaning + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + text = text.strip() + + # Clean alphanumeric patterns + text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640 + text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc + text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231 + + # Common spam patterns + text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text) + text = re.sub(r"(free\s)?download(\sfree)?", "", text) + text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text) + text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text) + text = re.sub(r"\bpage\s+\d+\b", "", text) + + text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a... + text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text) + + # Final cleanup + text = re.sub(r"\b\s+\:\s+", r": ", text) + text = re.sub(r"(\D[,\./])\b", r"\1 ", text) + text = re.sub(r"\s+", " ", text) + + text.strip() + + text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text) + text = re.sub(r"^[\'\_,\-\:;]", r"", text) + text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text) + text = re.sub(r"^\.\S+$", "", text) + + return text.strip() + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import MiragePipeline + >>> from diffusers.models import AutoencoderKL, AutoencoderDC + >>> from transformers import T5GemmaModel, GemmaTokenizerFast + + >>> # Load pipeline directly with from_pretrained + >>> pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") + + >>> # Or initialize pipeline components manually + >>> transformer = MirageTransformer2DModel.from_pretrained("path/to/transformer") + >>> scheduler = FlowMatchEulerDiscreteScheduler() + >>> # Load T5Gemma encoder + >>> t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") + >>> text_encoder = t5gemma_model.encoder + >>> tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") + >>> vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") + + >>> pipe = MiragePipeline( + ... transformer=transformer, + ... scheduler=scheduler, + ... text_encoder=text_encoder, + ... tokenizer=tokenizer, + ... vae=vae + ... ) + >>> pipe.to("cuda") + >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" + >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] + >>> image.save("mirage_output.png") + ``` +""" + + +class MiragePipeline( + DiffusionPipeline, + LoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + Pipeline for text-to-image generation using Mirage Transformer. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + transformer ([`MirageTransformer2DModel`]): + The Mirage transformer model to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + text_encoder ([`T5EncoderModel`]): + Standard text encoder model for encoding prompts. + tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]): + Tokenizer for the text encoder. 
+ vae ([`AutoencoderKL`] or [`AutoencoderDC`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents"] + _optional_components = [] + + # Component configurations for automatic loading + config_name = "model_index.json" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + """ + Override from_pretrained to ensure T5GemmaEncoder is available for loading. + + This ensures that T5GemmaEncoder from transformers is accessible in the module namespace + during component loading, which is required for MiragePipeline checkpoints that use + T5GemmaEncoder as the text encoder. + """ + # Ensure T5GemmaEncoder is available for loading + import transformers + if not hasattr(transformers, 'T5GemmaEncoder'): + try: + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + transformers.T5GemmaEncoder = T5GemmaEncoder + except ImportError: + # T5GemmaEncoder not available in this transformers version + pass + + # Proceed with standard loading + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + + + def __init__( + self, + transformer: MirageTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder: Union[T5EncoderModel, Any], + tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], + vae: Union[AutoencoderKL, AutoencoderDC], + ): + super().__init__() + + if MirageTransformer2DModel is None: + raise ImportError( + "MirageTransformer2DModel is not available. Please ensure the transformer_mirage module is properly installed." 
+ ) + + # Store standard components + self.text_encoder = text_encoder + self.tokenizer = tokenizer + + # Initialize text preprocessor + self.text_preprocessor = TextPreprocessor() + + self.register_modules( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + ) + + # Enhance VAE with universal properties for both AutoencoderKL and AutoencoderDC + self._enhance_vae_properties() + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + + def _enhance_vae_properties(self): + """Add universal properties to VAE for consistent interface across AutoencoderKL and AutoencoderDC.""" + if not hasattr(self, "vae") or self.vae is None: + return + + # Set spatial_compression_ratio property + if hasattr(self.vae, "spatial_compression_ratio"): + # AutoencoderDC already has this property + pass + elif hasattr(self.vae, "config") and hasattr(self.vae.config, "block_out_channels"): + # AutoencoderKL: calculate from block_out_channels + self.vae.spatial_compression_ratio = 2 ** (len(self.vae.config.block_out_channels) - 1) + else: + # Fallback + self.vae.spatial_compression_ratio = 8 + + # Set scaling_factor property with safe defaults + if hasattr(self.vae, "config"): + self.vae.scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) + else: + self.vae.scaling_factor = 0.18215 + + # Set shift_factor property with safe defaults (0.0 for AutoencoderDC) + if hasattr(self.vae, "config"): + shift_factor = getattr(self.vae.config, "shift_factor", None) + if shift_factor is None: # AutoencoderDC case + self.vae.shift_factor = 0.0 + else: + self.vae.shift_factor = shift_factor + else: + self.vae.shift_factor = 0.0 + + # Set latent_channels property (like VaeTower does) + if hasattr(self.vae, "config") and hasattr(self.vae.config, "latent_channels"): + # AutoencoderDC has latent_channels in config + self.vae.latent_channels = int(self.vae.config.latent_channels) + elif hasattr(self.vae, "config") and hasattr(self.vae.config, "in_channels"): + # AutoencoderKL has in_channels in config + self.vae.latent_channels = int(self.vae.config.in_channels) + else: + # Fallback based on VAE type - DC-AE typically has 32, AutoencoderKL has 4/16 + if hasattr(self.vae, "spatial_compression_ratio") and self.vae.spatial_compression_ratio == 32: + self.vae.latent_channels = 32 # DC-AE default + else: + self.vae.latent_channels = 4 # AutoencoderKL default + + @property + def vae_scale_factor(self): + """Compatibility property that returns spatial compression ratio.""" + return getattr(self.vae, "spatial_compression_ratio", 8) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ): + """Prepare initial latents for the diffusion process.""" + if latents is None: + latent_height, latent_width = height // self.vae.spatial_compression_ratio, width // self.vae.spatial_compression_ratio + shape = (batch_size, num_channels_latents, latent_height, latent_width) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # FlowMatchEulerDiscreteScheduler doesn't use init_noise_sigma scaling + return latents + + def encode_prompt(self, prompt: Union[str, List[str]], device: torch.device): + """Encode text prompt using standard text encoder and tokenizer.""" + if isinstance(prompt, 
str): + prompt = [prompt] + + return self._encode_prompt_standard(prompt, device) + + def _encode_prompt_standard(self, prompt: List[str], device: torch.device): + """Encode prompt using standard text encoder and tokenizer with batch processing.""" + # Clean text using modular preprocessor + cleaned_prompts = [self.text_preprocessor.clean_text(text) for text in prompt] + cleaned_uncond_prompts = [self.text_preprocessor.clean_text("") for _ in prompt] + + # Batch conditional and unconditional prompts together for efficiency + all_prompts = cleaned_prompts + cleaned_uncond_prompts + + # Tokenize all prompts in one batch + tokens = self.tokenizer( + all_prompts, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + + input_ids = tokens["input_ids"].to(device) + attention_mask = tokens["attention_mask"].bool().to(device) + + # Encode all prompts in one batch + with torch.no_grad(): + # Disable autocast like in TextTower + with torch.autocast("cuda", enabled=False): + emb = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + # Use last hidden state (matching TextTower's use_last_hidden_state=True default) + all_embeddings = emb["last_hidden_state"] + + # Split back into conditional and unconditional + batch_size = len(prompt) + text_embeddings = all_embeddings[:batch_size] + uncond_text_embeddings = all_embeddings[batch_size:] + + cross_attn_mask = attention_mask[:batch_size] + uncond_cross_attn_mask = attention_mask[batch_size:] + + return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask + + def check_inputs( + self, + prompt: Union[str, List[str]], + height: int, + width: int, + guidance_scale: float, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ): + """Check that all inputs are in correct format.""" + if height % self.vae.spatial_compression_ratio != 0 or width % self.vae.spatial_compression_ratio != 0: + raise ValueError(f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}.") + + if guidance_scale < 1.0: + raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}") + + if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. 
height (`int`, *optional*, defaults to 256):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 256):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 4.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` in equation 2 of the [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`. A
+                higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                latents tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.mirage.MiragePipelineOutput`] instead of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
+                `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed
+                in the `._callback_tensor_inputs` attribute.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.mirage.MiragePipelineOutput`] or `tuple`: [`~pipelines.mirage.MiragePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+            generated images.
+        """
+
+        # 0. Default height and width
+        height = height or 256
+        width = width or 256
+
+        # 1. 
Check inputs + self.check_inputs( + prompt, + height, + width, + guidance_scale, + callback_on_step_end_tensor_inputs, + ) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError("prompt must be provided as a string or list of strings") + + device = self._execution_device + + # 2. Encode input prompt + text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt( + prompt, device + ) + + # 3. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 4. Prepare latent variables + num_channels_latents = self.vae.latent_channels # From your transformer config + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 5. Prepare extra step kwargs + extra_step_kwargs = {} + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = 0.0 + + # 6. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Duplicate latents for CFG + latents_in = torch.cat([latents, latents], dim=0) + + # Cross-attention batch (uncond, cond) + ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) + ca_mask = None + if cross_attn_mask is not None and uncond_cross_attn_mask is not None: + ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) + + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) + + # Process inputs for transformer + img_seq, txt, pe = self.transformer.process_inputs(latents_in, ca_embed) + + # Forward through transformer layers + img_seq = self.transformer.forward_transformers( + img_seq, txt, time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), + pe=pe, attention_mask=ca_mask + ) + + # Convert back to image format + from ...models.transformers.transformer_mirage import seq2img + noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) + + # Apply CFG + noise_uncond, noise_text = noise_both.chunk(2, dim=0) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_on_step_end(self, i, t, callback_kwargs) + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 8. 
Post-processing + if output_type == "latent": + image = latents + else: + # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) + latents = (latents / self.vae.scaling_factor) + self.vae.shift_factor + # Decode using VAE (AutoencoderKL or AutoencoderDC) + image = self.vae.decode(latents, return_dict=False)[0] + # Use standard image processor for post-processing + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return MiragePipelineOutput(images=image) \ No newline at end of file diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/mirage/pipeline_output.py new file mode 100644 index 000000000000..e5cdb2a40924 --- /dev/null +++ b/src/diffusers/pipelines/mirage/pipeline_output.py @@ -0,0 +1,35 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class MiragePipelineOutput(BaseOutput): + """ + Output class for Mirage pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py new file mode 100644 index 000000000000..11accdaecbee --- /dev/null +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
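+
+# Shape bookkeeping exercised by these tests: MirageTransformer2DModel patchifies a latent with
+# `img2seq`, turning a (B, C, H, W) tensor into (B, (H / p) * (W / p), C * p**2) tokens, and
+# `seq2img` inverts it. A small worked example with the dummy sizes assumed below
+# (16 latent channels, 32x32 latents, patch_size 2):
+#
+#     batch, channels, height, width, patch = 1, 16, 32, 32, 2
+#     seq_len = (height // patch) * (width // patch)   # 256 tokens
+#     token_dim = channels * patch ** 2                # 64 features per token
+#     # img2seq(latent, patch).shape == (batch, seq_len, token_dim)
+#     # seq2img(img2seq(latent, patch), patch, latent.shape).shape == latent.shape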
+ +import unittest + +import torch + +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel, MirageParams + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class MirageTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = MirageTransformer2DModel + main_input_name = "image_latent" + + @property + def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 4, 4) + + @property + def output_shape(self): + return (16, 4, 4) + + def prepare_dummy_input(self, height=32, width=32): + batch_size = 1 + num_latent_channels = 16 + sequence_length = 16 + embedding_dim = 1792 + + image_latent = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device) + cross_attn_conditioning = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + micro_conditioning = torch.randn((batch_size, embedding_dim)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "image_latent": image_latent, + "timestep": timestep, + "cross_attn_conditioning": cross_attn_conditioning, + "micro_conditioning": micro_conditioning, + } + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 16, + "patch_size": 2, + "context_in_dim": 1792, + "hidden_size": 1792, + "mlp_ratio": 3.5, + "num_heads": 28, + "depth": 4, # Smaller depth for testing + "axes_dim": [32, 32], + "theta": 10_000, + } + inputs_dict = self.prepare_dummy_input() + return init_dict, inputs_dict + + def test_forward_signature(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + # Test forward + outputs = model(**inputs_dict) + + self.assertIsNotNone(outputs) + expected_shape = inputs_dict["image_latent"].shape + self.assertEqual(outputs.shape, expected_shape) + + def test_mirage_params_initialization(self): + # Test model initialization + model = MirageTransformer2DModel( + in_channels=16, + patch_size=2, + context_in_dim=1792, + hidden_size=1792, + mlp_ratio=3.5, + num_heads=28, + depth=4, + axes_dim=[32, 32], + theta=10_000, + ) + self.assertEqual(model.config.in_channels, 16) + self.assertEqual(model.config.hidden_size, 1792) + self.assertEqual(model.config.num_heads, 28) + + def test_model_with_dict_config(self): + # Test model initialization with from_config + config_dict = { + "in_channels": 16, + "patch_size": 2, + "context_in_dim": 1792, + "hidden_size": 1792, + "mlp_ratio": 3.5, + "num_heads": 28, + "depth": 4, + "axes_dim": [32, 32], + "theta": 10_000, + } + + model = MirageTransformer2DModel.from_config(config_dict) + self.assertEqual(model.config.in_channels, 16) + self.assertEqual(model.config.hidden_size, 1792) + + def test_process_inputs(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + img_seq, txt, pe = model.process_inputs( + inputs_dict["image_latent"], + inputs_dict["cross_attn_conditioning"] + ) + + # Check shapes + batch_size = inputs_dict["image_latent"].shape[0] + height, width = inputs_dict["image_latent"].shape[2:] + patch_size = init_dict["patch_size"] + expected_seq_len = (height // patch_size) * (width // patch_size) + + 
self.assertEqual(img_seq.shape, (batch_size, expected_seq_len, init_dict["in_channels"] * patch_size**2)) + self.assertEqual(txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"])) + # Check that pe has the correct batch size, sequence length and some embedding dimension + self.assertEqual(pe.shape[0], batch_size) # batch size + self.assertEqual(pe.shape[1], 1) # unsqueeze(1) in EmbedND + self.assertEqual(pe.shape[2], expected_seq_len) # sequence length + self.assertEqual(pe.shape[-2:], (2, 2)) # rope rearrange output + + def test_forward_transformers(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + # Process inputs first + img_seq, txt, pe = model.process_inputs( + inputs_dict["image_latent"], + inputs_dict["cross_attn_conditioning"] + ) + + # Test forward_transformers + output_seq = model.forward_transformers( + img_seq, + txt, + timestep=inputs_dict["timestep"], + pe=pe + ) + + # Check output shape + expected_out_channels = init_dict["in_channels"] * init_dict["patch_size"]**2 + self.assertEqual(output_seq.shape, (img_seq.shape[0], img_seq.shape[1], expected_out_channels)) + + def test_attention_mask(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # Create attention mask + batch_size = inputs_dict["cross_attn_conditioning"].shape[0] + seq_len = inputs_dict["cross_attn_conditioning"].shape[1] + attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool).to(torch_device) + attention_mask[:, seq_len//2:] = False # Mask second half + + with torch.no_grad(): + outputs = model( + **inputs_dict, + cross_attn_mask=attention_mask + ) + + self.assertIsNotNone(outputs) + expected_shape = inputs_dict["image_latent"].shape + self.assertEqual(outputs.shape, expected_shape) + + def test_invalid_config(self): + # Test invalid configuration - hidden_size not divisible by num_heads + with self.assertRaises(ValueError): + MirageTransformer2DModel( + in_channels=16, + patch_size=2, + context_in_dim=1792, + hidden_size=1793, # Not divisible by 28 + mlp_ratio=3.5, + num_heads=28, + depth=4, + axes_dim=[32, 32], + theta=10_000, + ) + + # Test invalid axes_dim that doesn't sum to pe_dim + with self.assertRaises(ValueError): + MirageTransformer2DModel( + in_channels=16, + patch_size=2, + context_in_dim=1792, + hidden_size=1792, + mlp_ratio=3.5, + num_heads=28, + depth=4, + axes_dim=[30, 30], # Sum = 60, but pe_dim = 1792/28 = 64 + theta=10_000, + ) + + def test_gradient_checkpointing_enable(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + # Enable gradient checkpointing + model.enable_gradient_checkpointing() + + # Check that _activation_checkpointing is set + for block in model.blocks: + self.assertTrue(hasattr(block, '_activation_checkpointing')) + + def test_from_config(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # Create model from config + model = self.model_class.from_config(init_dict) + self.assertIsInstance(model, self.model_class) + self.assertEqual(model.config.in_channels, init_dict["in_channels"]) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 4ac274be3d7647655437c6b810d1daa5c650f093 Mon Sep 17 00:00:00 2001 From: David Bertoin 
Date: Fri, 26 Sep 2025 11:51:14 +0200 Subject: [PATCH 2/7] use attention processors --- src/diffusers/models/attention_processor.py | 58 +++++++++++++ .../models/transformers/transformer_mirage.py | 86 ++++++++++++++++--- 2 files changed, 134 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 990245de1742..08e80e4329ba 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5609,6 +5609,63 @@ def __new__(cls, *args, **kwargs): return processor +class MirageAttnProcessor2_0: + r""" + Processor for implementing Mirage-style attention with multi-source tokens and RoPE. + Properly integrates with diffusers Attention module while handling Mirage-specific logic. + """ + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError("MirageAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Apply Mirage attention using standard diffusers interface. + + Expected tensor formats from MirageBlock.attn_forward(): + - hidden_states: Image queries with RoPE applied [B, H, L_img, D] + - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D] + (concatenated keys and values from text + image + spatial conditioning) + - attention_mask: Custom attention mask [B, H, L_img, L_all] or None + """ + + if encoder_hidden_states is None: + raise ValueError( + "MirageAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " + "This should be provided by MirageBlock.attn_forward()." 
+ ) + + # Unpack the combined key+value tensor + # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values] + key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D] + + # Apply scaled dot-product attention with Mirage's processed tensors + # hidden_states is image queries [B, H, L_img, D] + attn_output = torch.nn.functional.scaled_dot_product_attention( + hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask + ) + + # Reshape from [B, H, L_img, D] to [B, L_img, H*D] + batch_size, num_heads, seq_len, head_dim = attn_output.shape + attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim) + + # Apply output projection using the diffusers Attention module + attn_output = attn.to_out[0](attn_output) + if len(attn.to_out) > 1: + attn_output = attn.to_out[1](attn_output) # dropout if present + + return attn_output + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, @@ -5657,6 +5714,7 @@ def __new__(cls, *args, **kwargs): PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, + MirageAttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 39c569cbb26b..0225b9532aff 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..modeling_outputs import Transformer2DModelOutput +from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers @@ -159,13 +160,21 @@ def __init__( # img qkv self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False) - self.attn_out = nn.Linear(hidden_size, hidden_size, bias=False) self.qk_norm = QKNorm(self.head_dim) # txt kv self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False) self.k_norm = RMSNorm(self.head_dim) + self.attention = Attention( + query_dim=hidden_size, + heads=num_heads, + dim_head=self.head_dim, + bias=False, + out_bias=False, + processor=MirageAttnProcessor2_0(), + ) + # mlp self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -214,15 +223,11 @@ def attn_forward( k = torch.cat((cond_k, k), dim=2) v = torch.cat((cond_v, v), dim=2) - # build additive attention bias - attn_bias: Tensor | None = None - attn_mask: Tensor | None = None - # build multiplicative 0/1 mask for provided attention_mask over [cond?, text, image] keys + attn_mask: Tensor | None = None if attention_mask is not None: bs, _, l_img, _ = img_q.shape l_txt = txt_k.shape[2] - l_all = k.shape[2] assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}" assert ( @@ -244,11 +249,13 @@ def attn_forward( # repeat across heads and query positions attn_mask = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1) # (B,H,L_img,L_all) - attn = torch.nn.functional.scaled_dot_product_attention( - img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask + kv_packed = torch.cat([k, v], dim=-1) + + attn = self.attention( + 
hidden_states=img_q, + encoder_hidden_states=kv_packed, + attention_mask=attn_mask, ) - attn = rearrange(attn, "B H L D -> B L (H D)") - attn = self.attn_out(attn) return attn @@ -413,6 +420,65 @@ def __init__( self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: """Timestep independent stuff""" txt = self.txt_in(txt) From 904debcd11de7c6103e091b3223cd459b03d05a1 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 12:50:19 +0200 Subject: [PATCH 3/7] use diffusers rmsnorm --- .../models/transformers/transformer_mirage.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 0225b9532aff..f4199da1edcc 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -26,12 +26,12 @@ from ..modeling_outputs import Transformer2DModelOutput from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..normalization import RMSNorm logger = logging.get_logger(__name__) -# Mirage Layer Components def get_image_ids(bs: int, h: int, w: int, patch_size: int, device: torch.device) -> Tensor: img_ids = torch.zeros(h // patch_size, w // patch_size, 2, device=device) img_ids[..., 0] = torch.arange(h // patch_size, device=device)[:, None] @@ -93,23 +93,13 @@ def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int): - super().__init__() - self.scale = nn.Parameter(torch.ones(dim)) - - def forward(self, x: Tensor) -> Tensor: - x_dtype = x.dtype - x = x.float() - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms * self.scale).to(dtype=x_dtype) class QKNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() - self.query_norm = RMSNorm(dim) - self.key_norm = RMSNorm(dim) + self.query_norm = RMSNorm(dim, eps=1e-6) + self.key_norm = RMSNorm(dim, eps=1e-6) def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: q = self.query_norm(q) @@ -164,7 +154,7 @@ def __init__( # txt kv self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False) - self.k_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim, eps=1e-6) self.attention = Attention( query_dim=hidden_size, From 122115adb1305834b298e677ae30fcef65c4fd35 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 14:26:50 +0200 Subject: [PATCH 4/7] use diffusers timestep embedding method --- .../models/transformers/transformer_mirage.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index f4199da1edcc..916559eb47ac 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -27,6 +27,7 @@ from ..attention_processor import Attention, AttentionProcessor, 
MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..normalization import RMSNorm +from ..embeddings import get_timestep_embedding logger = logging.get_logger(__name__) @@ -71,15 +72,6 @@ def forward(self, ids: Tensor) -> Tensor: return emb.unsqueeze(1) -def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor: - t = time_factor * t - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding class MLPEmbedder(nn.Module): @@ -480,8 +472,12 @@ def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[T def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: return self.time_in( - timestep_embedding( - t=timestep, dim=256, max_period=self.time_max_period, time_factor=self.time_factor + get_timestep_embedding( + timesteps=timestep, + embedding_dim=256, + max_period=self.time_max_period, + scale=self.time_factor, + flip_sin_to_cos=True # Match original cos, sin order ).to(dtype) ) From e3fe0e8e1f79216cfe83719debc1ed33dfb3e788 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 15:17:11 +0200 Subject: [PATCH 5/7] remove MirageParams --- .../models/transformers/transformer_mirage.py | 64 +++++-------------- .../pipelines/mirage/pipeline_output.py | 2 +- .../test_models_transformer_mirage.py | 8 +-- 3 files changed, 22 insertions(+), 52 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 916559eb47ac..396e000524ec 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -288,20 +288,6 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor: return x -@dataclass -class MirageParams: - in_channels: int - patch_size: int - context_in_dim: int - hidden_size: int - mlp_ratio: float - num_heads: int - depth: int - axes_dim: list[int] - theta: int - time_factor: float = 1000.0 - time_max_period: int = 10_000 - conditioning_block_ids: list[int] | None = None def img2seq(img: Tensor, patch_size: int) -> Tensor: @@ -348,55 +334,39 @@ def __init__( if axes_dim is None: axes_dim = [32, 32] - # Create MirageParams from the provided arguments - params = MirageParams( - in_channels=in_channels, - patch_size=patch_size, - context_in_dim=context_in_dim, - hidden_size=hidden_size, - mlp_ratio=mlp_ratio, - num_heads=num_heads, - depth=depth, - axes_dim=axes_dim, - theta=theta, - time_factor=time_factor, - time_max_period=time_max_period, - conditioning_block_ids=conditioning_block_ids, - ) - - self.params = params - self.in_channels = params.in_channels - self.patch_size = params.patch_size + # Store parameters directly + self.in_channels = in_channels + self.patch_size = patch_size self.out_channels = self.in_channels * self.patch_size**2 - self.time_factor = params.time_factor - self.time_max_period = params.time_max_period + self.time_factor = time_factor + self.time_max_period = time_max_period - if params.hidden_size % params.num_heads != 0: - raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + if hidden_size % 
num_heads != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}") - pe_dim = params.hidden_size // params.num_heads + pe_dim = hidden_size // num_heads - if sum(params.axes_dim) != pe_dim: - raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + if sum(axes_dim) != pe_dim: + raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}") - self.hidden_size = params.hidden_size - self.num_heads = params.num_heads - self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) - self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + self.txt_in = nn.Linear(context_in_dim, self.hidden_size) - conditioning_block_ids: list[int] = params.conditioning_block_ids or list(range(params.depth)) + conditioning_block_ids: list[int] = conditioning_block_ids or list(range(depth)) self.blocks = nn.ModuleList( [ MirageBlock( self.hidden_size, self.num_heads, - mlp_ratio=params.mlp_ratio, + mlp_ratio=mlp_ratio, ) - for i in range(params.depth) + for i in range(depth) ] ) diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/mirage/pipeline_output.py index e5cdb2a40924..dfb55821d142 100644 --- a/src/diffusers/pipelines/mirage/pipeline_output.py +++ b/src/diffusers/pipelines/mirage/pipeline_output.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Union import numpy as np import PIL.Image diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py index 11accdaecbee..5e7b0bd165a6 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -17,7 +17,7 @@ import torch -from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel, MirageParams +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -88,9 +88,9 @@ def test_forward_signature(self): self.assertIsNotNone(outputs) expected_shape = inputs_dict["image_latent"].shape - self.assertEqual(outputs.shape, expected_shape) + self.assertEqual(outputs.sample.shape, expected_shape) - def test_mirage_params_initialization(self): + def test_model_initialization(self): # Test model initialization model = MirageTransformer2DModel( in_channels=16, @@ -196,7 +196,7 @@ def test_attention_mask(self): self.assertIsNotNone(outputs) expected_shape = inputs_dict["image_latent"].shape - self.assertEqual(outputs.shape, expected_shape) + self.assertEqual(outputs.sample.shape, expected_shape) def test_invalid_config(self): # Test invalid configuration - hidden_size not divisible by num_heads From 85ae87b9311a1432f43c2928389c8eafc86c0991 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 16:35:56 +0200 Subject: [PATCH 6/7] checkpoint conversion script --- scripts/convert_mirage_to_diffusers.py | 312 +++++++++++++++++++++++++ 1 file changed, 312 insertions(+) create mode 100644 
scripts/convert_mirage_to_diffusers.py diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py new file mode 100644 index 000000000000..85716e69ff92 --- /dev/null +++ b/scripts/convert_mirage_to_diffusers.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +""" +Script to convert Mirage checkpoint from original codebase to diffusers format. +""" + +import argparse +import json +import os +import shutil +import sys +import torch +from safetensors.torch import save_file +from transformers import GemmaTokenizerFast + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.mirage import MiragePipeline + +def load_reference_config(vae_type: str) -> dict: + """Load transformer config from existing pipeline checkpoint.""" + + if vae_type == "flux": + config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated/transformer/config.json" + elif vae_type == "dc-ae": + config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated/transformer/config.json" + else: + raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'") + + if not os.path.exists(config_path): + raise FileNotFoundError(f"Reference config not found: {config_path}") + + with open(config_path, 'r') as f: + config = json.load(f) + + print(f"✓ Loaded {vae_type} config: in_channels={config['in_channels']}") + return config + +def create_parameter_mapping() -> dict: + """Create mapping from old parameter names to new diffusers names.""" + + # Key mappings for structural changes + mapping = {} + + # RMSNorm: scale -> weight + for i in range(16): # 16 layers + mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.qk_norm.key_norm.weight" + mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.k_norm.weight" + + # Attention: attn_out -> attention.to_out.0 + mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight" + + return mapping + +def convert_checkpoint_parameters(old_state_dict: dict) -> dict: + """Convert old checkpoint parameters to new diffusers format.""" + + print("Converting checkpoint parameters...") + + mapping = create_parameter_mapping() + converted_state_dict = {} + + # First, print available keys to understand structure + print("Available keys in checkpoint:") + for key in sorted(old_state_dict.keys())[:10]: # Show first 10 keys + print(f" {key}") + if len(old_state_dict) > 10: + print(f" ... 
and {len(old_state_dict) - 10} more") + + for key, value in old_state_dict.items(): + new_key = key + + # Apply specific mappings if needed + if key in mapping: + new_key = mapping[key] + print(f" Mapped: {key} -> {new_key}") + + # Handle img_qkv_proj -> split to to_q, to_k, to_v + if "img_qkv_proj.weight" in key: + print(f" Found QKV projection: {key}") + # Split QKV weight into separate Q, K, V projections + qkv_weight = value + hidden_size = qkv_weight.shape[1] + q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) + + # Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0) + parts = key.split('.') + layer_idx = None + for i, part in enumerate(parts): + if part == 'blocks' and i + 1 < len(parts) and parts[i+1].isdigit(): + layer_idx = parts[i+1] + break + + if layer_idx is not None: + converted_state_dict[f"blocks.{layer_idx}.attention.to_q.weight"] = q_weight + converted_state_dict[f"blocks.{layer_idx}.attention.to_k.weight"] = k_weight + converted_state_dict[f"blocks.{layer_idx}.attention.to_v.weight"] = v_weight + print(f" Split QKV for layer {layer_idx}") + + # Also keep the original img_qkv_proj for backward compatibility + converted_state_dict[new_key] = value + else: + converted_state_dict[new_key] = value + + print(f"✓ Converted {len(converted_state_dict)} parameters") + return converted_state_dict + + +def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> MirageTransformer2DModel: + """Create and load MirageTransformer2DModel from old checkpoint.""" + + print(f"Loading checkpoint from: {checkpoint_path}") + + # Load old checkpoint + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + old_checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # Handle different checkpoint formats + if isinstance(old_checkpoint, dict): + if 'model' in old_checkpoint: + state_dict = old_checkpoint['model'] + elif 'state_dict' in old_checkpoint: + state_dict = old_checkpoint['state_dict'] + else: + state_dict = old_checkpoint + else: + state_dict = old_checkpoint + + print(f"✓ Loaded checkpoint with {len(state_dict)} parameters") + + # Convert parameter names if needed + converted_state_dict = convert_checkpoint_parameters(state_dict) + + # Create transformer with config + print("Creating MirageTransformer2DModel...") + transformer = MirageTransformer2DModel(**config) + + # Load state dict + print("Loading converted parameters...") + missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False) + + if missing_keys: + print(f"⚠ Missing keys: {missing_keys}") + if unexpected_keys: + print(f"⚠ Unexpected keys: {unexpected_keys}") + + if not missing_keys and not unexpected_keys: + print("✓ All parameters loaded successfully!") + + return transformer + +def copy_pipeline_components(vae_type: str, output_path: str): + """Copy VAE, scheduler, text encoder, and tokenizer from reference pipeline.""" + + if vae_type == "flux": + ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated" + else: # dc-ae + ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated" + + components = ["vae", "scheduler", "text_encoder", "tokenizer"] + + for component in components: + src_path = os.path.join(ref_pipeline, component) + dst_path = os.path.join(output_path, component) + + if os.path.exists(src_path): + if 
os.path.isdir(src_path): + shutil.copytree(src_path, dst_path, dirs_exist_ok=True) + else: + shutil.copy2(src_path, dst_path) + print(f"✓ Copied {component}") + else: + print(f"⚠ Component not found: {src_path}") + +def create_model_index(vae_type: str, output_path: str): + """Create model_index.json for the pipeline.""" + + if vae_type == "flux": + vae_class = "AutoencoderKL" + else: # dc-ae + vae_class = "AutoencoderDC" + + model_index = { + "_class_name": "MiragePipeline", + "_diffusers_version": "0.31.0.dev0", + "_name_or_path": os.path.basename(output_path), + "scheduler": [ + "diffusers", + "FlowMatchEulerDiscreteScheduler" + ], + "text_encoder": [ + "transformers", + "T5GemmaEncoder" + ], + "tokenizer": [ + "transformers", + "GemmaTokenizerFast" + ], + "transformer": [ + "diffusers", + "MirageTransformer2DModel" + ], + "vae": [ + "diffusers", + vae_class + ] + } + + model_index_path = os.path.join(output_path, "model_index.json") + with open(model_index_path, 'w') as f: + json.dump(model_index, f, indent=2) + + print(f"✓ Created model_index.json") + +def main(args): + # Validate inputs + if not os.path.exists(args.checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}") + + # Load reference config based on VAE type + config = load_reference_config(args.vae_type) + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + print(f"✓ Output directory: {args.output_path}") + + # Create transformer from checkpoint + transformer = create_transformer_from_checkpoint(args.checkpoint_path, config) + + # Save transformer + transformer_path = os.path.join(args.output_path, "transformer") + os.makedirs(transformer_path, exist_ok=True) + + # Save config + with open(os.path.join(transformer_path, "config.json"), 'w') as f: + json.dump(config, f, indent=2) + + # Save model weights as safetensors + state_dict = transformer.state_dict() + save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) + print(f"✓ Saved transformer to {transformer_path}") + + # Copy other pipeline components + copy_pipeline_components(args.vae_type, args.output_path) + + # Create model index + create_model_index(args.vae_type, args.output_path) + + # Verify the pipeline can be loaded + try: + pipeline = MiragePipeline.from_pretrained(args.output_path) + print(f"Pipeline loaded successfully!") + print(f"Transformer: {type(pipeline.transformer).__name__}") + print(f"VAE: {type(pipeline.vae).__name__}") + print(f"Text Encoder: {type(pipeline.text_encoder).__name__}") + print(f"Scheduler: {type(pipeline.scheduler).__name__}") + + # Display model info + num_params = sum(p.numel() for p in pipeline.transformer.parameters()) + print(f"✓ Transformer parameters: {num_params:,}") + + except Exception as e: + print(f"Pipeline verification failed: {e}") + return False + + print("Conversion completed successfully!") + print(f"Converted pipeline saved to: {args.output_path}") + print(f"VAE type: {args.vae_type}") + + + return True + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Mirage checkpoint to diffusers format") + + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to the original Mirage checkpoint (.pth file)" + ) + + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Output directory for the converted diffusers pipeline" + ) + + parser.add_argument( + "--vae_type", + type=str, + choices=["flux", "dc-ae"], + required=True, + help="VAE type to 
use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)" + ) + + args = parser.parse_args() + + try: + success = main(args) + if not success: + sys.exit(1) + except Exception as e: + print(f"❌ Conversion failed: {e}") + import traceback + traceback.print_exc() + sys.exit(1) \ No newline at end of file From 9a697d06b70eaa4e0c9f1f1b5bca6209c65b005b Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 17:00:55 +0200 Subject: [PATCH 7/7] ruff formating --- scripts/convert_mirage_to_diffusers.py | 83 ++++++++----------- .../models/transformers/transformer_mirage.py | 41 ++++----- src/diffusers/pipelines/mirage/__init__.py | 3 +- .../pipelines/mirage/pipeline_mirage.py | 50 +++++++---- .../pipelines/mirage/pipeline_output.py | 2 +- .../test_models_transformer_mirage.py | 30 +++---- 6 files changed, 100 insertions(+), 109 deletions(-) diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py index 85716e69ff92..5e2a2ff768f4 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ b/scripts/convert_mirage_to_diffusers.py @@ -8,16 +8,17 @@ import os import shutil import sys + import torch from safetensors.torch import save_file -from transformers import GemmaTokenizerFast -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.pipelines.mirage import MiragePipeline + def load_reference_config(vae_type: str) -> dict: """Load transformer config from existing pipeline checkpoint.""" @@ -31,12 +32,13 @@ def load_reference_config(vae_type: str) -> dict: if not os.path.exists(config_path): raise FileNotFoundError(f"Reference config not found: {config_path}") - with open(config_path, 'r') as f: + with open(config_path, "r") as f: config = json.load(f) print(f"✓ Loaded {vae_type} config: in_channels={config['in_channels']}") return config + def create_parameter_mapping() -> dict: """Create mapping from old parameter names to new diffusers names.""" @@ -54,6 +56,7 @@ def create_parameter_mapping() -> dict: return mapping + def convert_checkpoint_parameters(old_state_dict: dict) -> dict: """Convert old checkpoint parameters to new diffusers format.""" @@ -82,15 +85,14 @@ def convert_checkpoint_parameters(old_state_dict: dict) -> dict: print(f" Found QKV projection: {key}") # Split QKV weight into separate Q, K, V projections qkv_weight = value - hidden_size = qkv_weight.shape[1] q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) # Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0) - parts = key.split('.') + parts = key.split(".") layer_idx = None for i, part in enumerate(parts): - if part == 'blocks' and i + 1 < len(parts) and parts[i+1].isdigit(): - layer_idx = parts[i+1] + if part == "blocks" and i + 1 < len(parts) and parts[i + 1].isdigit(): + layer_idx = parts[i + 1] break if layer_idx is not None: @@ -117,14 +119,14 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - old_checkpoint = torch.load(checkpoint_path, map_location='cpu') + old_checkpoint = torch.load(checkpoint_path, map_location="cpu") # Handle different checkpoint formats if isinstance(old_checkpoint, dict): - if 'model' in old_checkpoint: 
- state_dict = old_checkpoint['model'] - elif 'state_dict' in old_checkpoint: - state_dict = old_checkpoint['state_dict'] + if "model" in old_checkpoint: + state_dict = old_checkpoint["model"] + elif "state_dict" in old_checkpoint: + state_dict = old_checkpoint["state_dict"] else: state_dict = old_checkpoint else: @@ -153,6 +155,7 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi return transformer + def copy_pipeline_components(vae_type: str, output_path: str): """Copy VAE, scheduler, text encoder, and tokenizer from reference pipeline.""" @@ -176,6 +179,7 @@ def copy_pipeline_components(vae_type: str, output_path: str): else: print(f"⚠ Component not found: {src_path}") + def create_model_index(vae_type: str, output_path: str): """Create model_index.json for the pipeline.""" @@ -188,33 +192,19 @@ def create_model_index(vae_type: str, output_path: str): "_class_name": "MiragePipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), - "scheduler": [ - "diffusers", - "FlowMatchEulerDiscreteScheduler" - ], - "text_encoder": [ - "transformers", - "T5GemmaEncoder" - ], - "tokenizer": [ - "transformers", - "GemmaTokenizerFast" - ], - "transformer": [ - "diffusers", - "MirageTransformer2DModel" - ], - "vae": [ - "diffusers", - vae_class - ] + "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], + "text_encoder": ["transformers", "T5GemmaEncoder"], + "tokenizer": ["transformers", "GemmaTokenizerFast"], + "transformer": ["diffusers", "MirageTransformer2DModel"], + "vae": ["diffusers", vae_class], } model_index_path = os.path.join(output_path, "model_index.json") - with open(model_index_path, 'w') as f: + with open(model_index_path, "w") as f: json.dump(model_index, f, indent=2) - print(f"✓ Created model_index.json") + print("✓ Created model_index.json") + def main(args): # Validate inputs @@ -236,7 +226,7 @@ def main(args): os.makedirs(transformer_path, exist_ok=True) # Save config - with open(os.path.join(transformer_path, "config.json"), 'w') as f: + with open(os.path.join(transformer_path, "config.json"), "w") as f: json.dump(config, f, indent=2) # Save model weights as safetensors @@ -253,7 +243,7 @@ def main(args): # Verify the pipeline can be loaded try: pipeline = MiragePipeline.from_pretrained(args.output_path) - print(f"Pipeline loaded successfully!") + print("Pipeline loaded successfully!") print(f"Transformer: {type(pipeline.transformer).__name__}") print(f"VAE: {type(pipeline.vae).__name__}") print(f"Text Encoder: {type(pipeline.text_encoder).__name__}") @@ -271,24 +261,18 @@ def main(args): print(f"Converted pipeline saved to: {args.output_path}") print(f"VAE type: {args.vae_type}") - return True + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Mirage checkpoint to diffusers format") parser.add_argument( - "--checkpoint_path", - type=str, - required=True, - help="Path to the original Mirage checkpoint (.pth file)" + "--checkpoint_path", type=str, required=True, help="Path to the original Mirage checkpoint (.pth file)" ) parser.add_argument( - "--output_path", - type=str, - required=True, - help="Output directory for the converted diffusers pipeline" + "--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline" ) parser.add_argument( @@ -296,7 +280,7 @@ def main(args): type=str, choices=["flux", "dc-ae"], required=True, - help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)" + 
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)", ) args = parser.parse_args() @@ -306,7 +290,8 @@ def main(args): if not success: sys.exit(1) except Exception as e: - print(f"❌ Conversion failed: {e}") + print(f"Conversion failed: {e}") import traceback + traceback.print_exc() - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 396e000524ec..923d44d4f1ec 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -13,21 +13,21 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Dict, Optional, Union, Tuple +from typing import Any, Dict, Optional, Tuple, Union + import torch -import math -from torch import Tensor, nn -from torch.nn.functional import fold, unfold from einops import rearrange from einops.layers.torch import Rearrange +from torch import Tensor, nn +from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from ..modeling_outputs import Transformer2DModelOutput -from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..normalization import RMSNorm +from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ..embeddings import get_timestep_embedding +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm logger = logging.get_logger(__name__) @@ -72,8 +72,6 @@ def forward(self, ids: Tensor) -> Tensor: return emb.unsqueeze(1) - - class MLPEmbedder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int): super().__init__() @@ -85,8 +83,6 @@ def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) - - class QKNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() @@ -157,7 +153,6 @@ def __init__( processor=MirageAttnProcessor2_0(), ) - # mlp self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) @@ -212,9 +207,9 @@ def attn_forward( l_txt = txt_k.shape[2] assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}" - assert ( - attention_mask.shape[-1] == l_txt - ), f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" + assert attention_mask.shape[-1] == l_txt, ( + f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" + ) device = img_q.device @@ -234,8 +229,8 @@ def attn_forward( kv_packed = torch.cat([k, v], dim=-1) attn = self.attention( - hidden_states=img_q, - encoder_hidden_states=kv_packed, + hidden_states=img_q, + encoder_hidden_states=kv_packed, attention_mask=attn_mask, ) @@ -288,8 +283,6 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor: return x - - def img2seq(img: Tensor, patch_size: int) -> Tensor: """Flatten an image into a sequence of patches""" return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) @@ -327,7 +320,7 @@ def __init__( time_factor: float = 1000.0, time_max_period: int = 10000, conditioning_block_ids: list = None, - **kwargs + 
**kwargs, ): super().__init__() @@ -447,7 +440,7 @@ def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Te embedding_dim=256, max_period=self.time_max_period, scale=self.time_factor, - flip_sin_to_cos=True # Match original cos, sin order + flip_sin_to_cos=True, # Match original cos, sin order ).to(dtype) ) @@ -470,9 +463,7 @@ def forward_transformers( vec = self.compute_timestep_embedding(timestep, dtype=img.dtype) for block in self.blocks: - img = block( - img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs - ) + img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs) img = self.final_layer(img, vec) return img diff --git a/src/diffusers/pipelines/mirage/__init__.py b/src/diffusers/pipelines/mirage/__init__.py index 4fd8ad191b3f..cba951057370 100644 --- a/src/diffusers/pipelines/mirage/__init__.py +++ b/src/diffusers/pipelines/mirage/__init__.py @@ -1,4 +1,5 @@ from .pipeline_mirage import MiragePipeline from .pipeline_output import MiragePipelineOutput -__all__ = ["MiragePipeline", "MiragePipelineOutput"] \ No newline at end of file + +__all__ = ["MiragePipeline", "MiragePipelineOutput"] diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index 126eab07977c..c4a4783c5f38 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import html import inspect import os -from typing import Any, Callable, Dict, List, Optional, Union - -import html import re import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union import ftfy import torch @@ -31,7 +30,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, AutoencoderDC +from ...models import AutoencoderDC, AutoencoderKL from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, @@ -41,6 +40,7 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import MiragePipelineOutput + try: from ...models.transformers.transformer_mirage import MirageTransformer2DModel except ImportError: @@ -55,7 +55,19 @@ class TextPreprocessor: def __init__(self): """Initialize text preprocessor.""" self.bad_punct_regex = re.compile( - r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + r"\\" + r"\/" + r"\*" + r"]{1,}" + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + r"\\" + + r"\/" + + r"\*" + + r"]{1,}" ) def clean_text(self, text: str) -> str: @@ -93,7 +105,7 @@ def clean_text(self, text: str) -> str: ) # кавычки к одному стандарту - text = re.sub(r"[`´«»""¨]", '"', text) + text = re.sub(r"[`´«»" "¨]", '"', text) text = re.sub(r"['']", "'", text) # " and & @@ -243,9 +255,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P """ # Ensure T5GemmaEncoder is available for loading import transformers - if not hasattr(transformers, 'T5GemmaEncoder'): + + if not hasattr(transformers, "T5GemmaEncoder"): try: from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + transformers.T5GemmaEncoder = T5GemmaEncoder except ImportError: # T5GemmaEncoder not available in this transformers 
version @@ -254,7 +268,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Proceed with standard loading return super().from_pretrained(pretrained_model_name_or_path, **kwargs) - def __init__( self, transformer: MirageTransformer2DModel, @@ -333,7 +346,7 @@ def _enhance_vae_properties(self): if hasattr(self.vae, "spatial_compression_ratio") and self.vae.spatial_compression_ratio == 32: self.vae.latent_channels = 32 # DC-AE default else: - self.vae.latent_channels = 4 # AutoencoderKL default + self.vae.latent_channels = 4 # AutoencoderKL default @property def vae_scale_factor(self): @@ -353,7 +366,10 @@ def prepare_latents( ): """Prepare initial latents for the diffusion process.""" if latents is None: - latent_height, latent_width = height // self.vae.spatial_compression_ratio, width // self.vae.spatial_compression_ratio + latent_height, latent_width = ( + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) shape = (batch_size, num_channels_latents, latent_height, latent_width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -424,7 +440,9 @@ def check_inputs( ): """Check that all inputs are in correct format.""" if height % self.vae.spatial_compression_ratio != 0 or width % self.vae.spatial_compression_ratio != 0: - raise ValueError(f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}." + ) if guidance_scale < 1.0: raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}") @@ -584,12 +602,16 @@ def __call__( # Forward through transformer layers img_seq = self.transformer.forward_transformers( - img_seq, txt, time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), - pe=pe, attention_mask=ca_mask + img_seq, + txt, + time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), + pe=pe, + attention_mask=ca_mask, ) # Convert back to image format from ...models.transformers.transformer_mirage import seq2img + noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) # Apply CFG @@ -626,4 +648,4 @@ def __call__( if not return_dict: return (image,) - return MiragePipelineOutput(images=image) \ No newline at end of file + return MiragePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/mirage/pipeline_output.py index dfb55821d142..e41c8e3bea00 100644 --- a/src/diffusers/pipelines/mirage/pipeline_output.py +++ b/src/diffusers/pipelines/mirage/pipeline_output.py @@ -32,4 +32,4 @@ class MiragePipelineOutput(BaseOutput): num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
""" - images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py index 5e7b0bd165a6..0085627aa7e4 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -133,8 +133,7 @@ def test_process_inputs(self): with torch.no_grad(): img_seq, txt, pe = model.process_inputs( - inputs_dict["image_latent"], - inputs_dict["cross_attn_conditioning"] + inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) # Check shapes @@ -144,7 +143,9 @@ def test_process_inputs(self): expected_seq_len = (height // patch_size) * (width // patch_size) self.assertEqual(img_seq.shape, (batch_size, expected_seq_len, init_dict["in_channels"] * patch_size**2)) - self.assertEqual(txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"])) + self.assertEqual( + txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"]) + ) # Check that pe has the correct batch size, sequence length and some embedding dimension self.assertEqual(pe.shape[0], batch_size) # batch size self.assertEqual(pe.shape[1], 1) # unsqueeze(1) in EmbedND @@ -160,20 +161,14 @@ def test_forward_transformers(self): with torch.no_grad(): # Process inputs first img_seq, txt, pe = model.process_inputs( - inputs_dict["image_latent"], - inputs_dict["cross_attn_conditioning"] + inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) # Test forward_transformers - output_seq = model.forward_transformers( - img_seq, - txt, - timestep=inputs_dict["timestep"], - pe=pe - ) + output_seq = model.forward_transformers(img_seq, txt, timestep=inputs_dict["timestep"], pe=pe) # Check output shape - expected_out_channels = init_dict["in_channels"] * init_dict["patch_size"]**2 + expected_out_channels = init_dict["in_channels"] * init_dict["patch_size"] ** 2 self.assertEqual(output_seq.shape, (img_seq.shape[0], img_seq.shape[1], expected_out_channels)) def test_attention_mask(self): @@ -186,13 +181,10 @@ def test_attention_mask(self): batch_size = inputs_dict["cross_attn_conditioning"].shape[0] seq_len = inputs_dict["cross_attn_conditioning"].shape[1] attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool).to(torch_device) - attention_mask[:, seq_len//2:] = False # Mask second half + attention_mask[:, seq_len // 2 :] = False # Mask second half with torch.no_grad(): - outputs = model( - **inputs_dict, - cross_attn_mask=attention_mask - ) + outputs = model(**inputs_dict, cross_attn_mask=attention_mask) self.assertIsNotNone(outputs) expected_shape = inputs_dict["image_latent"].shape @@ -237,7 +229,7 @@ def test_gradient_checkpointing_enable(self): # Check that _activation_checkpointing is set for block in model.blocks: - self.assertTrue(hasattr(block, '_activation_checkpointing')) + self.assertTrue(hasattr(block, "_activation_checkpointing")) def test_from_config(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() @@ -249,4 +241,4 @@ def test_from_config(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()