85 changes: 85 additions & 0 deletions examples/chroma_generation.py
@@ -0,0 +1,85 @@
"""
Example script for generating images with the Chroma model.
"""

import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline, AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5EncoderModel, T5TokenizerFast

def generate_with_chroma():
# Model paths
chroma_path = "lodestones/Chroma" # or local path to safetensors
vae_path = "black-forest-labs/FLUX.1-schnell" # Chroma uses Flux VAE
text_encoder_path = "google/t5-v1_1-xxl" # T5 XXL encoder

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

print("Loading models...")

# Load VAE from Flux
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae",
torch_dtype=dtype
).to(device)

# Load T5 encoder
text_encoder = T5EncoderModel.from_pretrained(
text_encoder_path,
torch_dtype=dtype
).to(device)

tokenizer = T5TokenizerFast.from_pretrained(text_encoder_path)

# Load Chroma transformer
# Option 1: From HuggingFace Hub (when available)
# transformer = ChromaTransformer2DModel.from_pretrained(
# chroma_path,
# torch_dtype=dtype
# ).to(device)

# Option 2: From single file
transformer = ChromaTransformer2DModel.from_single_file(
"path/to/chroma-unlocked-v29.safetensors",
torch_dtype=dtype
).to(device)

# Create scheduler
scheduler = FlowMatchEulerDiscreteScheduler()

# Create pipeline
pipe = ChromaPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler
)

    # Enable memory optimizations; model CPU offload manages device placement from here on
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

# Generate image
prompt = "A majestic lion made of galaxies and stardust, cosmic art style"

print(f"Generating image with prompt: {prompt}")

image = pipe(
prompt=prompt,
height=1024,
width=1024,
num_inference_steps=20,
guidance_scale=3.5,
generator=torch.Generator(device=device).manual_seed(42)
).images[0]

# Save image
image.save("chroma_output.png")
print("Image saved as chroma_output.png")

if __name__ == "__main__":
generate_with_chroma()
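The script above wires every component together by hand. If the Hub repository ships a complete diffusers-format pipeline (an assumption here; the actual `lodestones/Chroma` layout may only provide the single-file checkpoint), the same result could be reached with one `from_pretrained` call. A minimal sketch under that assumption:

```python
import torch
from diffusers import ChromaPipeline

# Hypothetical one-call load; assumes the repo hosts a full diffusers pipeline layout.
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="A majestic lion made of galaxies and stardust, cosmic art style",
    num_inference_steps=20,
    guidance_scale=3.5,
).images[0]
image.save("chroma_output_from_pretrained.png")
```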
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -159,6 +159,7 @@
"AutoencoderTiny",
"AutoModel",
"CacheMixin",
"ChromaTransformer2DModel",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"CogView4Transformer2DModel",
@@ -351,6 +352,7 @@
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"ChromaPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
@@ -764,6 +766,7 @@
AutoencoderTiny,
AutoModel,
CacheMixin,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
@@ -935,6 +938,7 @@
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
ChromaPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
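These export additions make both new classes importable from the package root. A two-line smoke test:

```python
# Both names should resolve once the exports above are in place.
from diffusers import ChromaPipeline, ChromaTransformer2DModel

print(ChromaPipeline.__name__, ChromaTransformer2DModel.__name__)
```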
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -29,6 +29,7 @@
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
@@ -138,6 +139,10 @@
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"ChromaTransformer2DModel": {
"checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}


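This registry entry is what lets `ChromaTransformer2DModel.from_single_file(...)` locate the new converter. A rough sketch of the conversion step it wires up, assuming the registry dict in this file is named `SINGLE_FILE_LOADABLE_CLASSES` (the real loader additionally infers the model config and handles dtype and device placement):

```python
from safetensors.torch import load_file

from diffusers.loaders.single_file_model import SINGLE_FILE_LOADABLE_CLASSES

# Placeholder path from the example script; any Chroma single-file checkpoint would do.
raw_state_dict = load_file("path/to/chroma-unlocked-v29.safetensors")

# Look up the converter registered for the class and remap the checkpoint keys.
entry = SINGLE_FILE_LOADABLE_CLASSES["ChromaTransformer2DModel"]
diffusers_state_dict = entry["checkpoint_mapping_fn"](raw_state_dict)
print(sorted(diffusers_state_dict)[:5])  # peek at the converted key names
```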
194 changes: 194 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
@@ -3329,3 +3329,197 @@ def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

return checkpoint


# Chroma checkpoint converter (based on Ednaordinary's implementation)
def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
"""
    Convert a ChromaTransformer2DModel checkpoint to the diffusers format.
    Maps keys from the original checkpoint layout to the diffusers naming convention.
"""
converted_state_dict = {}
keys = list(checkpoint.keys())

# Handle model.diffusion_model prefix removal (common in many checkpoints)
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    # Detect the Chroma variant via the distilled guidance layer keys (similar to the Flux converter's chroma variant handling)
variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"

for k in list(checkpoint.keys()):
if variant == "chroma" and "distilled_guidance_layer." in k:
new_key = k
if k.startswith("distilled_guidance_layer.norms"):
new_key = k.replace(".scale", ".weight")
elif k.startswith("distilled_guidance_layer.layer"):
new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
converted_state_dict[new_key] = checkpoint.pop(k)

# Get number of layers from checkpoint
num_layers = 0
num_single_layers = 0

    # Infer block counts from the highest block index present in the checkpoint
    if any("double_blocks." in k for k in checkpoint):
        num_layers = max(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k) + 1
    if any("single_blocks." in k for k in checkpoint):
        num_single_layers = max(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k) + 1

    # Helper to swap the (shift, scale) halves produced by AdaLayerNorm into diffusers' (scale, shift) order,
    # applied to both the 2D weight and the 1D bias so the two stay consistent
    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

# Convert time and text embeddings
if "time_in.in_layer.weight" in checkpoint:
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("time_in.in_layer.weight")
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("time_in.out_layer.weight")
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")

if "vector_in.in_layer.weight" in checkpoint:
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("vector_in.out_layer.weight")
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")

# Convert guidance embeddings if present
if "guidance_in.in_layer.weight" in checkpoint:
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop("guidance_in.in_layer.weight")
converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop("guidance_in.in_layer.bias")
converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop("guidance_in.out_layer.weight")
converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop("guidance_in.out_layer.bias")

# Convert context and x embedders
if "txt_in.weight" in checkpoint:
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")

if "img_in.weight" in checkpoint:
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")

# Convert double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."

# Convert norms
if f"double_blocks.{i}.img_mod.lin.weight" in checkpoint:
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mod.lin.weight")
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mod.lin.bias")

if f"double_blocks.{i}.txt_mod.lin.weight" in checkpoint:
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(f"double_blocks.{i}.txt_mod.lin.weight")
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(f"double_blocks.{i}.txt_mod.lin.bias")

# Convert attention layers
if f"double_blocks.{i}.img_attn.qkv.weight" in checkpoint:
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0)

converted_state_dict[f"{block_prefix}attn.to_q.weight"] = sample_q
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = sample_q_bias
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = sample_k
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = sample_k_bias
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = sample_v
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = sample_v_bias

if f"double_blocks.{i}.txt_attn.qkv.weight" in checkpoint:
context_q, context_k, context_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0)
context_q_bias, context_k_bias, context_v_bias = torch.chunk(checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0)

converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = context_q
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = context_q_bias
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = context_k
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = context_k_bias
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = context_v
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = context_v_bias

# Convert QK norms
if f"double_blocks.{i}.img_attn.norm.query_norm.scale" in checkpoint:
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(f"double_blocks.{i}.img_attn.norm.query_norm.scale")
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(f"double_blocks.{i}.img_attn.norm.key_norm.scale")

if f"double_blocks.{i}.txt_attn.norm.query_norm.scale" in checkpoint:
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(f"double_blocks.{i}.txt_attn.norm.query_norm.scale")
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(f"double_blocks.{i}.txt_attn.norm.key_norm.scale")

# Convert output projections
if f"double_blocks.{i}.img_attn.proj.weight" in checkpoint:
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(f"double_blocks.{i}.img_attn.proj.weight")
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(f"double_blocks.{i}.img_attn.proj.bias")

if f"double_blocks.{i}.txt_attn.proj.weight" in checkpoint:
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(f"double_blocks.{i}.txt_attn.proj.weight")
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(f"double_blocks.{i}.txt_attn.proj.bias")

# Convert MLPs
if f"double_blocks.{i}.img_mlp.0.weight" in checkpoint:
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.weight")
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")

if f"double_blocks.{i}.txt_mlp.0.weight" in checkpoint:
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(f"double_blocks.{i}.txt_mlp.0.weight")
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.txt_mlp.0.bias")
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.txt_mlp.2.weight")
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.txt_mlp.2.bias")

# Convert single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."

# Convert norms
if f"single_blocks.{i}.modulation.lin.weight" in checkpoint:
converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(f"single_blocks.{i}.modulation.lin.weight")
converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(f"single_blocks.{i}.modulation.lin.bias")

# Convert combined QKV and MLP
if f"single_blocks.{i}.linear1.weight" in checkpoint:
inner_dim = 3072
mlp_ratio = 4.0
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
q_bias, k_bias, v_bias, mlp_bias = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0)

converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = mlp
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = mlp_bias

# Convert QK norms
if f"single_blocks.{i}.norm.query_norm.scale" in checkpoint:
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(f"single_blocks.{i}.norm.query_norm.scale")
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(f"single_blocks.{i}.norm.key_norm.scale")

# Convert output projection
if f"single_blocks.{i}.linear2.weight" in checkpoint:
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")

# Convert final layer
if "final_layer.linear.weight" in checkpoint:
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

if "final_layer.adaLN_modulation.1.weight" in checkpoint:
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.adaLN_modulation.1.weight"))
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.adaLN_modulation.1.bias"))

# Add any remaining keys from the original checkpoint
for key, value in checkpoint.items():
if key not in converted_state_dict:
converted_state_dict[key] = value

return converted_state_dict
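A quick way to sanity-check the key remapping is to run the converter on a tiny synthetic state dict: the fused `img_attn.qkv` tensors should come back split into separate `to_q`/`to_k`/`to_v` entries. A minimal sketch (not part of the PR's test suite; shapes are arbitrary):

```python
import torch

from diffusers.loaders.single_file_utils import convert_chroma_transformer_checkpoint_to_diffusers

# Hypothetical one-block checkpoint with a fused QKV projection of width 3 * 8
fake_checkpoint = {
    "double_blocks.0.img_attn.qkv.weight": torch.randn(3 * 8, 8),
    "double_blocks.0.img_attn.qkv.bias": torch.randn(3 * 8),
}
converted = convert_chroma_transformer_checkpoint_to_diffusers(fake_checkpoint)

assert converted["transformer_blocks.0.attn.to_q.weight"].shape == (8, 8)
assert converted["transformer_blocks.0.attn.to_q.bias"].shape == (8,)
```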
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -74,6 +74,7 @@
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
@@ -150,6 +151,7 @@
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
@@ -17,6 +17,7 @@
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_cosmos import CosmosTransformer3DModel