Commit 542a603

update

1 parent e95ac9d commit 542a603

9 files changed: +987 −278 lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -159,6 +159,7 @@
         "AutoencoderTiny",
         "AutoModel",
         "CacheMixin",
+        "ChromaTransformer2DModel",
         "CogVideoXTransformer3DModel",
         "CogView3PlusTransformer2DModel",
         "CogView4Transformer2DModel",

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions

@@ -29,6 +29,7 @@
     convert_animatediff_checkpoint_to_diffusers,
     convert_auraflow_transformer_checkpoint_to_diffusers,
     convert_autoencoder_dc_checkpoint_to_diffusers,
+    convert_chroma_transformer_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
     convert_hidream_transformer_to_diffusers,
@@ -138,6 +139,10 @@
         "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "ChromaTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_chroma_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
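This registry entry is what routes single-file loading for the new class through convert_chroma_transformer_to_diffusers. A minimal sketch of how that path would be exercised; the checkpoint filename is a placeholder, not something shipped with this commit:

import torch
from diffusers import ChromaTransformer2DModel

# Placeholder path to an original-format Chroma transformer checkpoint.
ckpt_path = "path/to/chroma_transformer.safetensors"

# from_single_file looks the class name up in the registry above and applies the
# registered checkpoint_mapping_fn before instantiating the diffusers model.
transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)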

src/diffusers/loaders/single_file_utils.py

Lines changed: 201 additions & 77 deletions
Large diffs are not rendered by default.
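The new convert_chroma_transformer_to_diffusers helper lives in this file, but its diff is collapsed. Conversion functions in single_file_utils.py generally remap original state-dict keys onto the diffusers module layout; the sketch below illustrates only that general pattern with placeholder key names, not the actual Chroma mapping introduced here:

# Illustrative pattern only; key names are placeholders, not the real Chroma mapping.
def convert_example_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    for key, value in checkpoint.items():
        # Rename keys from the original checkpoint layout to the diffusers layout.
        new_key = key.replace("final_layer.", "norm_out.")  # placeholder rename
        converted_state_dict[new_key] = value
    return converted_state_dict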

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -74,6 +74,7 @@
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
     _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
+    _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
     _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
@@ -150,6 +151,7 @@
     from .transformers import (
         AllegroTransformer3DModel,
         AuraFlowTransformer2DModel,
+        ChromaTransformer2DModel,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
         CogView4Transformer2DModel,

src/diffusers/models/embeddings.py

Lines changed: 0 additions & 48 deletions

@@ -1637,35 +1637,6 @@ def forward(self, timestep, guidance, pooled_projection):
         return conditioning


-class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
-    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
-        super().__init__()
-
-        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-
-        self.register_buffer(
-            "mod_proj",
-            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
-            persistent=False,
-        )
-
-    def forward(
-        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
-    ) -> torch.Tensor:
-        mod_index_length = self.mod_proj.shape[0]
-        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
-        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
-
-        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
-        timestep_guidance = (
-            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
-        )
-        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
-
-        return input_vec
-
-
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()
@@ -2259,25 +2230,6 @@ def forward(self, caption):
         return hidden_states


-class ChromaApproximator(nn.Module):
-    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
-        super().__init__()
-        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
-        self.layers = nn.ModuleList(
-            [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
-        )
-        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
-        self.out_proj = nn.Linear(hidden_dim, out_dim)
-
-    def forward(self, x):
-        x = self.in_proj(x)
-
-        for layer, norms in zip(self.layers, self.norms):
-            x = x + layer(norms(x))
-
-        return self.out_proj(x)
-
-
 class IPAdapterPlusImageProjectionBlock(nn.Module):
     def __init__(
         self,
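Both Chroma-specific modules are removed from the shared embeddings file; given the new transformer_chroma module imported elsewhere in this commit, they presumably now live alongside ChromaTransformer2DModel (that file's contents are not shown here). The removed ChromaApproximator is a pre-norm residual MLP stack: project in, apply n_layers blocks of RMSNorm followed by a SiLU feed-forward with a residual add, then project out. A self-contained sketch of that pattern, with a plain SiLU MLP standing in for PixArtAlphaTextProjection and arbitrary small dimensions:

import torch
import torch.nn as nn

class ResidualApproximatorSketch(nn.Module):
    # Mirrors the structure of the removed ChromaApproximator:
    # in_proj -> n_layers x (RMSNorm -> MLP -> residual add) -> out_proj.
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
        self.layers = nn.ModuleList(
            # Stand-in for PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu").
            [nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim)) for _ in range(n_layers)]
        )
        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])  # nn.RMSNorm needs PyTorch >= 2.4
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.in_proj(x)
        for layer, norm in zip(self.layers, self.norms):
            x = x + layer(norm(x))  # pre-norm residual update
        return self.out_proj(x)

# Arbitrary shapes chosen only for illustration.
approx = ResidualApproximatorSketch(in_dim=64, out_dim=96, hidden_dim=128, n_layers=5)
out = approx(torch.randn(2, 344, 64))
print(out.shape)  # torch.Size([2, 344, 96])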

src/diffusers/models/normalization.py

Lines changed: 0 additions & 44 deletions

@@ -374,50 +374,6 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         return x


-class AdaLayerNormContinuousPruned(nn.Module):
-    r"""
-    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
-
-    Args:
-        embedding_dim (`int`): Embedding dimension to use during projection.
-        conditioning_embedding_dim (`int`): Dimension of the input condition.
-        elementwise_affine (`bool`, defaults to `True`):
-            Boolean flag to denote if affine transformation should be applied.
-        eps (`float`, defaults to 1e-5): Epsilon factor.
-        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
-        norm_type (`str`, defaults to `"layer_norm"`):
-            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
-    """
-
-    def __init__(
-        self,
-        embedding_dim: int,
-        conditioning_embedding_dim: int,
-        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
-        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
-        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
-        # However, this is how it was implemented in the original code, and it's rather likely you should
-        # set `elementwise_affine` to False.
-        elementwise_affine=True,
-        eps=1e-5,
-        bias=True,
-        norm_type="layer_norm",
-    ):
-        super().__init__()
-        if norm_type == "layer_norm":
-            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
-        elif norm_type == "rms_norm":
-            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
-        else:
-            raise ValueError(f"unknown norm_type {norm_type}")
-
-    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
-        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
-        shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
-        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
-        return x
-
-
 class AdaLayerNormContinuous(nn.Module):
     r"""
     Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
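Compared with AdaLayerNormContinuous (kept below), the removed pruned variant has no internal linear projection: the caller passes an already-projected conditioning embedding whose two rows hold the shift and scale. A minimal numeric sketch of that modulation step, mirroring the removed forward with arbitrary shapes:

import torch
import torch.nn as nn

batch, seq_len, dim = 2, 16, 64
x = torch.randn(batch, seq_len, dim)

# Pre-projected conditioning for one layer: first chunk is shift, second is scale.
emb = torch.randn(1, 2, dim)

norm = nn.LayerNorm(dim, elementwise_affine=False)
shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)  # each [1, dim]
out = norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]        # broadcasts to [batch, seq_len, dim]
print(out.shape)  # torch.Size([2, 16, 64])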

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
 from .t5_film_transformer import T5FilmDecoder
 from .transformer_2d import Transformer2DModel
 from .transformer_allegro import AllegroTransformer3DModel
+from .transformer_chroma import ChromaTransformer2DModel
 from .transformer_cogview3plus import CogView3PlusTransformer2DModel
 from .transformer_cogview4 import CogView4Transformer2DModel
 from .transformer_cosmos import CosmosTransformer3DModel
