2 changes: 1 addition & 1 deletion docs/source/en/api/models/chroma_transformer.md
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.

# ChromaTransformer2DModel

-A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
+A modified Flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)

## ChromaTransformer2DModel

10 changes: 5 additions & 5 deletions docs/source/en/api/pipelines/chroma.md
@@ -19,20 +19,20 @@ specific language governing permissions and limitations under the License.

Chroma is a text-to-image generation model based on Flux.

-Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
+Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma1-HD).

Member:
This is intentional: if any other official compatible checkpoint is released under https://huggingface.co/lodestones/, users will likely notice it.

Contributor (author):
https://huggingface.co/lodestones/Chroma is actually a deprecated repo rather than a hub for all Chroma models; I don't think there is currently a "hub" repo. Did you mean that it should link to https://huggingface.co/lodestones rather than https://huggingface.co/lodestones/Chroma or https://huggingface.co/lodestones/Chroma1-HD?

Contributor (author):
I made some changes; I'm unsure whether this addresses your concerns.


> [!TIP]
> Chroma can use all the same optimizations as Flux.
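
As a sketch of what carries over (assuming the generic `DiffusionPipeline` helper `enable_model_cpu_offload` and the VAE's `enable_slicing`/`enable_tiling` behave for `ChromaPipeline` exactly as they do for Flux):

```python
import torch
from diffusers import ChromaPipeline

pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)

# Offload idle submodules to the CPU during inference (generic DiffusionPipeline helper).
pipe.enable_model_cpu_offload()

# Slice/tile the VAE decode to bound peak memory on large images
# (standard AutoencoderKL helpers; assumed to apply unchanged here).
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
```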

## Inference

-The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
+The Diffusers version of Chroma is based on [`Chroma1-HD`](https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors), which is available in the [Chroma1-HD repository](https://huggingface.co/lodestones/Chroma1-HD).

```python
import torch
from diffusers import ChromaPipeline

-pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
+pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = [
@@ -63,10 +63,10 @@ Then run the following example
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline

-model_id = "lodestones/Chroma"
+model_id = "lodestones/Chroma1-HD"
dtype = torch.bfloat16

-transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
+transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)

pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
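
Continuing the example above, which is cut off by the collapsed diff, a hedged sketch of how the loaded pipeline would then be invoked (prompt and sampling values are illustrative assumptions, not taken from this PR):

```python
# Illustrative sampling call; parameter values are assumptions, not from the PR.
prompt = "A high-resolution photograph of a red fox in a snowy forest"

image = pipe(
    prompt=prompt,
    num_inference_steps=28,
    guidance_scale=4.0,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("chroma_single_file.png")
```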
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_chroma.py
@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
"""
The Transformer model introduced in Flux, modified for Chroma.

-Reference: https://huggingface.co/lodestones/Chroma
+Reference: https://huggingface.co/lodestones/Chroma1-HD

Args:
patch_size (`int`, defaults to `1`):
23 changes: 12 additions & 11 deletions src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -53,8 +53,8 @@
>>> import torch
>>> from diffusers import ChromaPipeline

->>> model_id = "lodestones/Chroma"
->>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+>>> model_id = "lodestones/Chroma1-HD"
+>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> pipe = ChromaPipeline.from_pretrained(
... model_id,
@@ -158,7 +158,7 @@ class ChromaPipeline(
r"""
The Chroma pipeline for text-to-image generation.

-Reference: https://huggingface.co/lodestones/Chroma/
+Reference: https://huggingface.co/lodestones/Chroma1-HD/

Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -233,20 +233,21 @@ def _get_t5_prompt_embeds(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
-attention_mask = text_inputs.attention_mask.clone()
+tokenizer_mask = text_inputs.attention_mask

-# Chroma requires the attention mask to include one padding token
-seq_lengths = attention_mask.sum(dim=1)
-mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
+tokenizer_mask_device = tokenizer_mask.to(device)

prompt_embeds = self.text_encoder(
-text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+text_input_ids.to(device),
+output_hidden_states=False,
+attention_mask=tokenizer_mask_device,
)[0]

dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-attention_mask = attention_mask.to(device=device)

+seq_lengths = tokenizer_mask_device.sum(dim=1)
+mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)

_, seq_len, _ = prompt_embeds.shape
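
As context for the change above, a small self-contained sketch (toy tensors, not code from this PR) of what the rebuilt mask computes: because the comparison uses `<=` rather than `<`, each row keeps every real token plus exactly one trailing padding position, which is the "one padding token" the original comment refers to.

```python
import torch

# Toy tokenizer output: 2 prompts padded to length 6 (1 = real token, 0 = padding).
tokenizer_mask = torch.tensor([
    [1, 1, 1, 0, 0, 0],  # 3 real tokens
    [1, 1, 1, 1, 1, 0],  # 5 real tokens
])

seq_lengths = tokenizer_mask.sum(dim=1)  # tensor([3, 5])
mask_indices = torch.arange(tokenizer_mask.size(1)).unsqueeze(0).expand(2, -1)

# Inclusive comparison keeps indices 0..seq_length: one extra padding slot per row.
attention_mask = mask_indices <= seq_lengths.unsqueeze(1)
print(attention_mask.long())
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1]])
```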

25 changes: 14 additions & 11 deletions src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -53,8 +53,8 @@
>>> import torch
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline

->>> model_id = "lodestones/Chroma"
->>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+>>> model_id = "lodestones/Chroma1-HD"
+>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... model_id,
... transformer=transformer,
@@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline(
r"""
The Chroma pipeline for image-to-image generation.

-Reference: https://huggingface.co/lodestones/Chroma/
+Reference: https://huggingface.co/lodestones/Chroma1-HD/

Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -247,20 +247,23 @@ def _get_t5_prompt_embeds(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
-attention_mask = text_inputs.attention_mask.clone()
+tokenizer_mask = text_inputs.attention_mask

-# Chroma requires the attention mask to include one padding token
-seq_lengths = attention_mask.sum(dim=1)
-mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+tokenizer_mask_device = tokenizer_mask.to(device)

prompt_embeds = self.text_encoder(
-text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+text_input_ids.to(device),
+output_hidden_states=False,
+attention_mask=tokenizer_mask_device,
)[0]

dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-attention_mask = attention_mask.to(dtype=dtype, device=device)

+seq_lengths = tokenizer_mask_device.sum(dim=1)
+mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(
+    batch_size, -1
+)
+attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)

_, seq_len, _ = prompt_embeds.shape
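
Since the img2img docstring above is truncated by the collapsed diff, here is a hedged end-to-end sketch of using `ChromaImg2ImgPipeline`; the input image URL, `strength`, and step count are illustrative assumptions, not values from this PR:

```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
from diffusers.utils import load_image

model_id = "lodestones/Chroma1-HD"
ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"

transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
pipe = ChromaImg2ImgPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# Any RGB image works; this URL is an assumption for illustration.
init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)

image = pipe(
    prompt="a fantasy landscape, detailed matte painting",
    image=init_image,
    strength=0.7,            # illustrative: lower values preserve more of the input
    num_inference_steps=28,  # illustrative
).images[0]
image.save("chroma_img2img.png")
```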
