
Commit d3daf1e

Merge branch 'main' into xla_sana
2 parents 7a14bd5 + 1b202c5 commit d3daf1e

7 files changed: 472 additions, 125 deletions

.github/workflows/nightly_tests.yml

Lines changed: 2 additions & 0 deletions

@@ -359,6 +359,8 @@ jobs:
             test_location: "bnb"
           - backend: "gguf"
             test_location: "gguf"
+          - backend: "torchao"
+            test_location: "torchao"
     runs-on:
       group: aws-g6e-xlarge-plus
     container:

docs/source/en/quantization/torchao.md

Lines changed: 62 additions & 0 deletions

@@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
 The example below only quantizes the weights to int8.

 ```python
+import torch
 from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

 model_id = "black-forest-labs/FLUX.1-dev"
@@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained(
 )
 pipe.to("cuda")

+# Without quantization: ~31.447 GB
+# With quantization: ~20.40 GB
+print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
+
 prompt = "A cat holding a sign that says hello world"
 image = pipe(
     prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
@@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use

 Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.

+## Serializing and Deserializing quantized models
+
+To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, TorchAoConfig
+
+quantization_config = TorchAoConfig("int8wo")
+transformer = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/Flux.1-Dev",
+    subfolder="transformer",
+    quantization_config=quantization_config,
+    torch_dtype=torch.bfloat16,
+)
+transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
+```
+
+To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
+
+```python
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel
+
+transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
+pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
+image.save("output.png")
+```
+
+Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
+
+```python
+import torch
+from accelerate import init_empty_weights
+from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
+
+# Serialize the model
+transformer = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/Flux.1-Dev",
+    subfolder="transformer",
+    quantization_config=TorchAoConfig("uint4wo"),
+    torch_dtype=torch.bfloat16,
+)
+transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
+# ...
+
+# Load the model
+state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
+with init_empty_weights():
+    transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
+transformer.load_state_dict(state_dict, strict=True, assign=True)
+```
+
 ## Resources

 - [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 56 additions & 0 deletions

@@ -2286,6 +2286,50 @@ def unload_lora_weights(self):
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
             transformer._transformer_norm_layers = None

+        if getattr(transformer, "_overwritten_params", None) is not None:
+            overwritten_params = transformer._overwritten_params
+            module_names = set()
+
+            for param_name in overwritten_params:
+                if param_name.endswith(".weight"):
+                    module_names.add(param_name.replace(".weight", ""))
+
+            for name, module in transformer.named_modules():
+                if isinstance(module, torch.nn.Linear) and name in module_names:
+                    module_weight = module.weight.data
+                    module_bias = module.bias.data if module.bias is not None else None
+                    bias = module_bias is not None
+
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    current_param_weight = overwritten_params[f"{name}.weight"]
+                    in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
+                    with torch.device("meta"):
+                        original_module = torch.nn.Linear(
+                            in_features,
+                            out_features,
+                            bias=bias,
+                            dtype=module_weight.dtype,
+                        )
+
+                    tmp_state_dict = {"weight": current_param_weight}
+                    if module_bias is not None:
+                        tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
+                    original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
+                    setattr(parent_module, current_module_name, original_module)
+
+                    del tmp_state_dict
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(current_param_weight.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
+
     @classmethod
     def _maybe_expand_transformer_param_shape_or_error_(
         cls,
@@ -2312,6 +2356,8 @@ def _maybe_expand_transformer_param_shape_or_error_(

         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
+        overwritten_params = {}
+
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -2386,6 +2432,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
                         f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
                     )

+                # For `unload_lora_weights()`.
+                # TODO: this could lead to more memory overhead if the number of overwritten params
+                # are large. Should be revisited later and tackled through a `discard_original_layers` arg.
+                overwritten_params[f"{current_module_name}.weight"] = module_weight
+                if module_bias is not None:
+                    overwritten_params[f"{current_module_name}.bias"] = module_bias
+
+        if len(overwritten_params) > 0:
+            transformer._overwritten_params = overwritten_params
+
         return has_param_with_shape_update

     @classmethod
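For context, a minimal sketch of the user-facing flow this change enables: loading a Control LoRA into a regular Flux pipeline expands `x_embedder`, and unloading now restores the original layer shapes from `_overwritten_params`. The repo ids below are only illustrative, not part of the diff.

```python
import torch
from diffusers import FluxPipeline

# Illustrative checkpoints; substitute whichever Flux base model and Control LoRA you use.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
print(pipe.transformer.x_embedder.in_features)  # pre-expansion value

# Loading a Control LoRA expands x_embedder's input features; the original weights
# are stashed on `transformer._overwritten_params` during shape expansion.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
print(pipe.transformer.x_embedder.in_features)  # expanded

# With this change, unloading restores the original nn.Linear modules and the
# related config attributes (e.g. `in_channels`) instead of keeping the expanded shapes.
pipe.unload_lora_weights()
print(pipe.transformer.x_embedder.in_features)  # back to the pre-expansion value
```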

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 4 deletions

@@ -718,10 +718,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             hf_quantizer = None

         if hf_quantizer is not None:
-            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
-            if is_bnb_quantization_method and device_map is not None:
+            if device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
                 )

             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
@@ -820,7 +819,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     revision=revision,
                     subfolder=subfolder or "",
                 )
-                if hf_quantizer is not None and is_bnb_quantization_method:
+                # TODO: https://github.com/huggingface/diffusers/issues/10013
+                if hf_quantizer is not None:
                     model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                     logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                     is_sharded = False
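A small sketch of what this check now means in practice: an explicit `device_map` combined with any quantization config raises `NotImplementedError`, no longer only for bitsandbytes. The checkpoint id and quant type below are examples, not taken from the diff.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

try:
    # Any non-None device_map together with a quantization config now hits the
    # NotImplementedError raised above, regardless of the quantization backend.
    FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=TorchAoConfig("int8wo"),
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
except NotImplementedError as err:
    print(err)  # device_map is not yet supported for quantized models
```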

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ def validate_environment(self, *args, **kwargs):
     def update_torch_dtype(self, torch_dtype):
         quant_type = self.quantization_config.quant_type

-        if quant_type.startswith("int"):
+        if quant_type.startswith("int") or quant_type.startswith("uint"):
             if torch_dtype is not None and torch_dtype != torch.bfloat16:
                 logger.warning(
                     f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "

tests/lora/test_lora_layers_flux.py

Lines changed: 66 additions & 0 deletions

@@ -558,6 +558,72 @@ def test_load_regular_lora(self):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
         self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))

+    def test_lora_unload_with_parameter_expanded_shapes(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+        components["transformer"] = transformer
+        pipe = FluxPipeline(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        control_image = inputs.pop("control_image")
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        control_pipe = self.pipeline_class(**components)
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        inputs["control_image"] = control_image
+        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+        control_pipe.unload_lora_weights()
+        self.assertTrue(
+            control_pipe.transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
+        )
+        loaded_pipe = FluxPipeline.from_pipe(control_pipe)
+        self.assertTrue(
+            loaded_pipe.transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
+        )
+        inputs.pop("control_image")
+        unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features)
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
