Skip to content

Commit dc6bd15

Browse files
josephroccagithub-actions[bot]DN6
authored
Fix Chroma attention padding order and update docs to use lodestones/Chroma1-HD (#12508)
* [Fix] Move attention mask padding after T5 embedding * [Fix] Move attention mask padding after T5 embedding * Clean up whitespace in pipeline_chroma.py Removed unnecessary blank lines for cleaner code. * Fix * Fix * Update model to final Chroma1-HD checkpoint * Update to Chroma1-HD * Update model to Chroma1-HD * Update model to Chroma1-HD * Update Chroma model links to Chroma1-HD * Add comment about padding/masking * Fix checkpoint/repo references * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Dhruv Nair <[email protected]>
1 parent 500b9cf commit dc6bd15

File tree

5 files changed

+35
-30
lines changed

5 files changed

+35
-30
lines changed

docs/source/en/api/models/chroma_transformer.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
1212

1313
# ChromaTransformer2DModel
1414

15-
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
15+
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
1616

1717
## ChromaTransformer2DModel
1818

docs/source/en/api/pipelines/chroma.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,21 @@ specific language governing permissions and limitations under the License.
1919

2020
Chroma is a text to image generation model based on Flux.
2121

22-
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
22+
Original model checkpoints for Chroma can be found here:
23+
* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
24+
* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
25+
* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
2326

2427
> [!TIP]
2528
> Chroma can use all the same optimizations as Flux.
2629
2730
## Inference
2831

29-
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
30-
3132
```python
3233
import torch
3334
from diffusers import ChromaPipeline
3435

35-
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
36+
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
3637
pipe.enable_model_cpu_offload()
3738

3839
prompt = [
@@ -63,10 +64,10 @@ Then run the following example
6364
import torch
6465
from diffusers import ChromaTransformer2DModel, ChromaPipeline
6566

66-
model_id = "lodestones/Chroma"
67+
model_id = "lodestones/Chroma1-HD"
6768
dtype = torch.bfloat16
6869

69-
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
70+
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
7071

7172
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
7273
pipe.enable_model_cpu_offload()

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
379379
"""
380380
The Transformer model introduced in Flux, modified for Chroma.
381381
382-
Reference: https://huggingface.co/lodestones/Chroma
382+
Reference: https://huggingface.co/lodestones/Chroma1-HD
383383
384384
Args:
385385
patch_size (`int`, defaults to `1`):

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@
5353
>>> import torch
5454
>>> from diffusers import ChromaPipeline
5555
56-
>>> model_id = "lodestones/Chroma"
57-
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
56+
>>> model_id = "lodestones/Chroma1-HD"
57+
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
5858
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
5959
>>> pipe = ChromaPipeline.from_pretrained(
6060
... model_id,
@@ -158,7 +158,7 @@ class ChromaPipeline(
158158
r"""
159159
The Chroma pipeline for text-to-image generation.
160160
161-
Reference: https://huggingface.co/lodestones/Chroma/
161+
Reference: https://huggingface.co/lodestones/Chroma1-HD/
162162
163163
Args:
164164
transformer ([`ChromaTransformer2DModel`]):
@@ -233,20 +233,23 @@ def _get_t5_prompt_embeds(
233233
return_tensors="pt",
234234
)
235235
text_input_ids = text_inputs.input_ids
236-
attention_mask = text_inputs.attention_mask.clone()
236+
tokenizer_mask = text_inputs.attention_mask
237237

238-
# Chroma requires the attention mask to include one padding token
239-
seq_lengths = attention_mask.sum(dim=1)
240-
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
241-
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
238+
tokenizer_mask_device = tokenizer_mask.to(device)
242239

240+
# unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
243241
prompt_embeds = self.text_encoder(
244-
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
242+
text_input_ids.to(device),
243+
output_hidden_states=False,
244+
attention_mask=tokenizer_mask_device,
245245
)[0]
246246

247-
dtype = self.text_encoder.dtype
248247
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
249-
attention_mask = attention_mask.to(device=device)
248+
249+
# for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer
250+
seq_lengths = tokenizer_mask_device.sum(dim=1)
251+
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
252+
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
250253

251254
_, seq_len, _ = prompt_embeds.shape
252255

src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@
5353
>>> import torch
5454
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
5555
56-
>>> model_id = "lodestones/Chroma"
57-
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
56+
>>> model_id = "lodestones/Chroma1-HD"
57+
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
5858
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
5959
... model_id,
6060
... transformer=transformer,
@@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline(
170170
r"""
171171
The Chroma pipeline for image-to-image generation.
172172
173-
Reference: https://huggingface.co/lodestones/Chroma/
173+
Reference: https://huggingface.co/lodestones/Chroma1-HD/
174174
175175
Args:
176176
transformer ([`ChromaTransformer2DModel`]):
@@ -247,20 +247,21 @@ def _get_t5_prompt_embeds(
247247
return_tensors="pt",
248248
)
249249
text_input_ids = text_inputs.input_ids
250-
attention_mask = text_inputs.attention_mask.clone()
250+
tokenizer_mask = text_inputs.attention_mask
251251

252-
# Chroma requires the attention mask to include one padding token
253-
seq_lengths = attention_mask.sum(dim=1)
254-
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
255-
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
252+
tokenizer_mask_device = tokenizer_mask.to(device)
256253

257254
prompt_embeds = self.text_encoder(
258-
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
255+
text_input_ids.to(device),
256+
output_hidden_states=False,
257+
attention_mask=tokenizer_mask_device,
259258
)[0]
260259

261-
dtype = self.text_encoder.dtype
262260
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
263-
attention_mask = attention_mask.to(dtype=dtype, device=device)
261+
262+
seq_lengths = tokenizer_mask_device.sum(dim=1)
263+
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
264+
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
264265

265266
_, seq_len, _ = prompt_embeds.shape
266267

0 commit comments

Comments
 (0)