
Commit 1245393

Merge branch 'main' into xpu-precision-ut

2 parents: b35ba59 + 42077e6

File tree

7 files changed: +144 −34 lines changed


docs/source/en/api/pipelines/chroma.md

Lines changed: 49 additions & 17 deletions

````diff
@@ -27,9 +27,36 @@ Chroma can use all the same optimizations as Flux.
 
 </Tip>
 
-## Inference (Single File)
+## Inference
 
-The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
+The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
+
+```python
+import torch
+from diffusers import ChromaPipeline
+
+pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload()
+
+prompt = [
+    "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
+]
+negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
+
+image = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    generator=torch.Generator("cpu").manual_seed(433),
+    num_inference_steps=40,
+    guidance_scale=3.0,
+    num_images_per_prompt=1,
+).images[0]
+image.save("chroma.png")
+```
+
+## Loading from a single file
+
+To use updated model checkpoints that are not in the Diffusers format, you can use the `ChromaTransformer2DModel` class to load the model from a single file in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
 
 The following example demonstrates how to run Chroma from a single file.
 
@@ -38,34 +65,39 @@ Then run the following example
 ```python
 import torch
 from diffusers import ChromaTransformer2DModel, ChromaPipeline
-from transformers import T5EncoderModel
 
-bfl_repo = "black-forest-labs/FLUX.1-dev"
+model_id = "lodestones/Chroma"
 dtype = torch.bfloat16
 
-transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
-
-text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
-tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)
-
-pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
+transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
 
+pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
 pipe.enable_model_cpu_offload()
 
-prompt = "A cat holding a sign that says hello world"
+prompt = [
+    "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
+]
+negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
+
 image = pipe(
-    prompt,
-    guidance_scale=4.0,
-    output_type="pil",
-    num_inference_steps=26,
-    generator=torch.Generator("cpu").manual_seed(0)
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    generator=torch.Generator("cpu").manual_seed(433),
+    num_inference_steps=40,
+    guidance_scale=3.0,
 ).images[0]
 
-image.save("image.png")
+image.save("chroma-single-file.png")
 ```
 
 ## ChromaPipeline
 
 [[autodoc]] ChromaPipeline
 - all
 - __call__
+
+## ChromaImg2ImgPipeline
+
+[[autodoc]] ChromaImg2ImgPipeline
+- all
+- __call__
````
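The new doc section mentions community finetunes and quantized checkpoints. As a companion to the diff above, here is a minimal sketch of how a quantized single-file checkpoint could be loaded, assuming a community GGUF export of Chroma exists at the hypothetical URL below; the `GGUFQuantizationConfig` route mirrors the documented Flux single-file workflow and is not itself part of this commit.

```python
import torch
from diffusers import ChromaPipeline, ChromaTransformer2DModel, GGUFQuantizationConfig

# Hypothetical community GGUF export; substitute a real checkpoint URL.
ckpt_path = "https://huggingface.co/<user>/<repo>/blob/main/chroma-unlocked-v37-Q8_0.gguf"

transformer = ChromaTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = ChromaPipeline.from_pretrained(
    "lodestones/Chroma", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
```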

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 9 additions & 8 deletions

````diff
@@ -52,20 +52,21 @@
 >>> import torch
 >>> from diffusers import ChromaPipeline
 
+>>> model_id = "lodestones/Chroma"
 >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
 >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
->>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
->>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
->>> pipe = ChromaImg2ImgPipeline.from_pretrained(
-...     "black-forest-labs/FLUX.1-schnell",
+>>> pipe = ChromaPipeline.from_pretrained(
+...     model_id,
 ...     transformer=transformer,
-...     text_encoder=text_encoder,
-...     tokenizer=tokenizer,
 ...     torch_dtype=torch.bfloat16,
 ... )
 >>> pipe.enable_model_cpu_offload()
->>> prompt = "A cat holding a sign that says hello world"
->>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
+>>> prompt = [
+...     "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
+... ]
+>>> negative_prompt = [
+...     "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
+... ]
 >>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
 >>> image.save("chroma.png")
 ```
````

src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py

Lines changed: 4 additions & 9 deletions

````diff
@@ -51,26 +51,21 @@
 ```py
 >>> import torch
 >>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
->>> from transformers import AutoModel, Autotokenizer
 
+>>> model_id = "lodestones/Chroma"
 >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
->>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
->>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
->>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
 >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
-...     "black-forest-labs/FLUX.1-schnell",
+...     model_id,
 ...     transformer=transformer,
-...     text_encoder=text_encoder,
-...     tokenizer=tokenizer,
 ...     torch_dtype=torch.bfloat16,
 ... )
 >>> pipe.enable_model_cpu_offload()
->>> image = load_image(
+>>> init_image = load_image(
 ...     "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
 ... )
 >>> prompt = "a scenic fantasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
 >>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
->>> image = pipe(prompt, image=image, negative_prompt=negative_prompt).images[0]
+>>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
 >>> image.save("chroma-img2img.png")
 ```
 """
````

src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -44,6 +44,8 @@ def retrieve_latents(
 
 
 class LTXLatentUpsamplePipeline(DiffusionPipeline):
+    model_cpu_offload_seq = ""
+
     def __init__(
         self,
         vae: AutoencoderKLLTXVideo,
```
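For context on the one-line change above: `DiffusionPipeline` subclasses declare the order in which `enable_model_cpu_offload()` moves components between the accelerator and CPU as a `"->"`-separated class attribute. A minimal sketch of the convention follows, with an illustrative class and component names not taken from this commit; the empty string here appears to let the upsampler pipeline, which has no multi-model ordering, expose the offload API without declaring a sequence.

```python
from diffusers import DiffusionPipeline

class MyTextToImagePipeline(DiffusionPipeline):
    # Components are placed on the accelerator in this order during
    # __call__, each offloaded back to CPU before the next one runs.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
```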

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -1131,3 +1131,26 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
             break
     if has_transformers_component and not is_transformers_version(">", "4.47.1"):
         raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
+
+
+def _maybe_warn_for_wrong_component_in_quant_config(pipe_init_dict, quant_config):
+    if quant_config is None:
+        return
+
+    actual_pipe_components = set(pipe_init_dict.keys())
+    missing = ""
+    quant_components = None
+    if getattr(quant_config, "components_to_quantize", None) is not None:
+        quant_components = set(quant_config.components_to_quantize)
+    elif getattr(quant_config, "quant_mapping", None) is not None and isinstance(quant_config.quant_mapping, dict):
+        quant_components = set(quant_config.quant_mapping.keys())
+
+    if quant_components and not quant_components.issubset(actual_pipe_components):
+        missing = quant_components - actual_pipe_components
+
+    if missing:
+        logger.warning(
+            f"The following components in the quantization config {missing} will be ignored "
+            "as they do not belong to the underlying pipeline. Acceptable values for the pipeline "
+            f"components are: {', '.join(actual_pipe_components)}."
+        )
```
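The helper above is wired into `DiffusionPipeline.from_pretrained` in the next file. A minimal sketch of the behavior it adds, grounded in the tests at the end of this commit; the model id is illustrative and any pipeline repository would do.

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# "foo" is not a component of the pipeline, so loading logs the new
# warning and ignores it; "transformer" is still quantized as requested.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer", "foo"],
)
pipe = DiffusionPipeline.from_pretrained(
    "lodestones/Chroma",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```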

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -88,6 +88,7 @@
     _identify_model_variants,
     _maybe_raise_error_for_incorrect_transformers,
     _maybe_raise_warning_for_inpainting,
+    _maybe_warn_for_wrong_component_in_quant_config,
     _resolve_custom_pipeline_and_cls,
     _unwrap_model,
     _update_init_kwargs_with_connected_pipeline,
@@ -984,6 +985,7 @@ def load_module(name, value):
 
         # 7. Load each module in the pipeline
         current_device_map = None
+        _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
         for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
             # 7.1 device_map shenanigans
             if final_device_map is not None and len(final_device_map) > 0:
```

tests/quantization/test_pipeline_level_quantization.py

Lines changed: 55 additions & 0 deletions

```diff
@@ -16,10 +16,13 @@
 import unittest
 
 import torch
+from parameterized import parameterized
 
 from diffusers import DiffusionPipeline, QuantoConfig
 from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     is_transformers_available,
     require_accelerate,
     require_bitsandbytes_version_greater,
@@ -188,3 +191,55 @@ def test_saving_loading(self):
         output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
 
         self.assertTrue(torch.allclose(output_1, output_2))
+
+    @parameterized.expand(["quant_kwargs", "quant_mapping"])
+    def test_warn_invalid_component(self, method):
+        invalid_component = "foo"
+        if method == "quant_kwargs":
+            components_to_quantize = ["transformer", invalid_component]
+            quant_config = PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=components_to_quantize,
+            )
+        else:
+            quant_config = PipelineQuantizationConfig(
+                quant_mapping={
+                    "transformer": QuantoConfig("int8"),
+                    invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
+                }
+            )
+
+        logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
+        logger.setLevel(logging.WARNING)
+        with CaptureLogger(logger) as cap_logger:
+            _ = DiffusionPipeline.from_pretrained(
+                self.model_name,
+                quantization_config=quant_config,
+                torch_dtype=torch.bfloat16,
+            )
+        self.assertTrue(invalid_component in cap_logger.out)
+
+    @parameterized.expand(["quant_kwargs", "quant_mapping"])
+    def test_no_quantization_for_all_invalid_components(self, method):
+        invalid_component = "foo"
+        if method == "quant_kwargs":
+            components_to_quantize = [invalid_component]
+            quant_config = PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=components_to_quantize,
+            )
+        else:
+            quant_config = PipelineQuantizationConfig(
+                quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
+            )
+
+        pipe = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            quantization_config=quant_config,
+            torch_dtype=torch.bfloat16,
+        )
+        for name, component in pipe.components.items():
+            if isinstance(component, torch.nn.Module):
+                self.assertTrue(not hasattr(component.config, "quantization_config"))
```
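A condensed sketch of the `CaptureLogger` idiom the new tests rely on, with the logger name taken from the diff above; the warning text is a stand-in for the real call made during pipeline loading.

```python
from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

# Capture warnings emitted by the module that hosts the new helper.
logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
logger.setLevel(logging.WARNING)

with CaptureLogger(logger) as cap_logger:
    logger.warning("components {'foo'} will be ignored")  # stand-in warning

assert "foo" in cap_logger.out
```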
