Commit f35ec17

Merge remote-tracking branch '11698/chroma' into chroma-final
2 parents: 19733af + 381e64b

6 files changed, +24 −13 lines changed


.github/workflows/pr_style_bot.yml

Lines changed: 1 addition & 1 deletion
@@ -14,4 +14,4 @@ jobs:
     with:
       python_quality_dependencies: "[quality]"
     secrets:
-      bot_token: ${{ secrets.GITHUB_TOKEN }}
+      bot_token: ${{ secrets.HF_STYLE_BOT_ACTION }}

docs/source/en/quantization/torchao.md

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
 
 For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
 
+> [!TIP]
+> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
+
 torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
 
 The `TorchAoConfig` class accepts three parameters:
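The new tip pairs with the `TorchAoConfig` pattern already shown earlier on that docs page. Below is a minimal sketch of FP8 weight-only quantization combined with `torch.compile`; the `"float8wo_e4m3"` quant type string and the FLUX.1-dev model ID are assumptions drawn from the surrounding torchao documentation, not from this diff, so verify them against the supported-types table for your installed versions.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"

# FP8 weight-only quantization; the exact quant_type string may differ
# between torchao/diffusers versions.
quantization_config = TorchAoConfig("float8wo_e4m3")

transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Compiling the quantized transformer is the combination the tip recommends
# for GPUs with compute capability >= 8.9.
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

image = pipe("A cat holding a sign that says hello world").images[0]
```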

src/diffusers/loaders/single_file_utils.py

Lines changed: 2 additions & 2 deletions
@@ -3323,8 +3323,8 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
     num_guidance_layers = (
-        list({int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k})[-1] + 1
-    )  # noqa: C401
+        list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1  # noqa: C401
+    )
     mlp_ratio = 4.0
     inner_dim = 3072
 
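The counting idiom in this hunk reads the layer index out of each checkpoint key and takes the largest index plus one. A small self-contained sketch with a toy, made-up checkpoint dict, written with `max()` as an order-independent equivalent of the `list(set(...))[-1]` idiom:

```python
# Toy checkpoint keys: the layer index is the second (or, for the guidance
# layers, third) dot-separated component of each key name.
checkpoint = {
    "double_blocks.0.img_attn.qkv.weight": None,
    "double_blocks.1.img_attn.qkv.weight": None,
    "single_blocks.0.linear1.weight": None,
    "distilled_guidance_layer.layers.0.in_layer.weight": None,
    "distilled_guidance_layer.layers.4.in_layer.weight": None,
}

# Equivalent to the diff's list(set(...))[-1] + 1, but max() does not depend
# on set iteration order.
num_layers = max(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k) + 1
num_guidance_layers = (
    max(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k) + 1
)

print(num_layers, num_guidance_layers)  # 2 5
```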

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -532,6 +532,7 @@
         )
         from .aura_flow import AuraFlowPipeline
         from .blip_diffusion import BlipDiffusionPipeline
+        from .chroma import ChromaPipeline
         from .cogvideo import (
             CogVideoXFunControlPipeline,
             CogVideoXImageToVideoPipeline,

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 2 additions & 10 deletions
@@ -182,7 +182,6 @@ def __init__(
         transformer: ChromaTransformer2DModel,
         image_encoder: CLIPVisionModelWithProjection = None,
         feature_extractor: CLIPImageProcessor = None,
-        variant: str = "flux",
     ):
         super().__init__()
 
@@ -220,24 +219,20 @@ def _get_t5_prompt_embeds(
 
         text_inputs = self.tokenizer(
             prompt,
-            padding="max_length",
+            padding=False,
             max_length=max_sequence_length,
             truncation=True,
             return_length=False,
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
-        text_input_ids = text_inputs.input_ids
+        text_input_ids = text_inputs.input_ids + self.tokenizer.pad_token_id
 
         prompt_embeds = self.text_encoder(
             text_input_ids.to(device),
             output_hidden_states=False,
-            attention_mask=text_inputs.attention_mask.to(device),
         )[0]
 
-        max_len = min(text_inputs.attention_mask.sum() + 1, max_sequence_length)
-        prompt_embeds = prompt_embeds[:, :max_len]
-
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
@@ -397,7 +392,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latent_image_ids
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
         latent_image_ids = torch.zeros(height, width, 3)
@@ -412,7 +406,6 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
 
         return latent_image_ids.to(device=device, dtype=dtype)
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
@@ -421,7 +414,6 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
 
         return latents
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
     @staticmethod
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
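Once a converted checkpoint is available, the new pipeline would be driven roughly as in the sketch below. This is illustrative only: the repository ID is a placeholder (this commit does not reference a specific checkpoint), and the call signature is assumed to mirror FluxPipeline, from which much of this file is copied; the 512-token cap comes from `check_inputs` in the hunk above.

```python
import torch
from diffusers import ChromaPipeline  # registered in pipelines/__init__.py above

# "<org>/<chroma-checkpoint>" is a placeholder repo ID, not a real model card.
pipe = ChromaPipeline.from_pretrained("<org>/<chroma-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Arguments assumed to follow the Flux-style call signature.
image = pipe(
    prompt="A photo of a cat wearing a tiny wizard hat",
    num_inference_steps=28,
    max_sequence_length=512,  # check_inputs rejects values above 512
).images[0]
image.save("chroma.png")
```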

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
@@ -272,6 +272,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class ChromaPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class CLIPImageProjection(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
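The stub above follows the library's dummy-object convention: when the `torch`/`transformers` backends are missing, the top-level import resolves to this placeholder and any attempt to use it raises an informative error instead of an `AttributeError`. A simplified, self-contained sketch of that pattern (not the library's actual implementation) looks like this:

```python
# Minimal stand-in for diffusers' DummyObject / requires_backends pattern,
# shown only to illustrate why the ChromaPipeline stub above exists.
def requires_backends(obj, backends):
    # The real helper checks whether each backend is importable; this sketch
    # always raises, to show the failure mode users would see.
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class DummyObject(type):
    """Metaclass for placeholder classes exposed when optional backends are absent."""

    def __getattr__(cls, name):
        # Fires for any undefined class attribute, e.g. ChromaPipeline.some_method.
        requires_backends(cls, cls._backends)


class ChromaPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])


# ChromaPipeline()            -> ImportError: ChromaPipeline requires ... torch, transformers
# ChromaPipeline.save_config  -> ImportError via DummyObject.__getattr__
```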
