
Commit e090177

update
1 parent 4ae8691 commit e090177

File tree

docs/source/en/quantization/quanto.md
src/diffusers/models/modeling_utils.py
src/diffusers/quantizers/quanto/utils.py
tests/quantization/quanto/test_quanto.py

4 files changed, +102 -10 lines changed


docs/source/en/quantization/quanto.md

Lines changed: 13 additions & 0 deletions
@@ -47,6 +47,19 @@ image = pipe(
 image.save("output.png")
 ```
 
+## Skipping Quantization on specific modules
+
+It is possible to skip applying quantization on certain modules with the `modules_to_not_convert` argument in `QuantoConfig`. Please ensure that the module names passed to this argument match the keys of those modules in the model's `state_dict`.
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, QuantoConfig
+
+model_id = "black-forest-labs/FLUX.1-dev"
+quantization_config = QuantoConfig(weights="float8", modules_to_not_convert=["proj_out"])
+transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
+```
+
 ## Using `from_single_file` with the Quanto Backend
 
 ```python
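As a follow-up to the documentation example above, here is a minimal, illustrative sketch (not part of this commit) of how one could confirm that a module listed in `modules_to_not_convert` was left unquantized after loading. It assumes optimum-quanto exposes a `QLinear` class (the quantized replacement for `nn.Linear`) importable from the package root, as the tests in this commit use it.

```python
# Illustrative only: verify that the skipped module was not converted.
import torch
from optimum.quanto import QLinear  # assumed top-level export from optimum-quanto

from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16
)

# The key passed to `modules_to_not_convert` should still map to a regular layer,
# i.e. it should not have been swapped for a quantized QLinear module.
assert not isinstance(transformer.proj_out, QLinear)
```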

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1036,11 +1036,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         )
 
         named_buffers = model.named_buffers()
-
         unexpected_keys = load_model_dict_into_meta(
             model,
             state_dict,
             device=param_device,
+            dtype=torch_dtype,
             model_name_or_path=pretrained_model_name_or_path,
             hf_quantizer=hf_quantizer,
             keep_in_fp32_modules=keep_in_fp32_modules,

src/diffusers/quantizers/quanto/utils.py

Lines changed: 11 additions & 1 deletion
@@ -2,9 +2,11 @@
 
 import torch.nn as nn
 
-from ...utils import is_accelerate_available
+from ...utils import is_accelerate_available, logging
 
 
+logger = logging.get_logger(__name__)
+
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
@@ -47,5 +49,13 @@ def _replace_layers(model, quantization_config, modules_to_not_convert):
         return model
 
     model = _replace_layers(model, quantization_config, modules_to_not_convert)
+    has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules())
+
+    if not has_been_replaced:
+        logger.warning(
+            f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied."
+            " Please check your model architecture, or submit an issue on Github if you think this is a bug."
+            " https://github.com/huggingface/diffusers"
+        )
 
     return model
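For context on the warning added above: it fires when the replacement pass finds no `nn.Linear` modules to convert into quantized layers. The toy snippet below (plain PyTorch, not the diffusers internals; `ConvOnly` is a hypothetical model) shows the condition that would trigger it.

```python
import torch.nn as nn

# Hypothetical model with no nn.Linear layers: there is nothing for the
# Quanto replacement pass to convert, so the has_been_replaced check above
# would come up empty and the warning would be logged.
class ConvOnly(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

model = ConvOnly()
has_linear = any(isinstance(m, nn.Linear) for m in model.modules())
print(has_linear)  # False -> quantizing this model would be a no-op
```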

tests/quantization/quanto/test_quanto.py

Lines changed: 77 additions & 8 deletions
@@ -1,3 +1,4 @@
+import tempfile
 import unittest
 
 import torch
@@ -9,7 +10,6 @@
 from diffusers.utils import is_optimum_quanto_available
 from diffusers.utils.testing_utils import (
     nightly,
-    numpy_cosine_similarity_distance,
     require_accelerate,
     require_big_gpu_with_torch_cuda,
     torch_device,
@@ -29,6 +29,7 @@ class QuantoBaseTesterMixin:
     torch_dtype = torch.bfloat16
     expected_memory_use_in_gb = 5
     keep_in_fp32_module = ""
+    modules_to_not_convert = ""
 
     def get_dummy_init_kwargs(self):
         return {"weights": "float8"}
@@ -76,6 +77,22 @@ def test_keep_modules_in_fp32(self):
                     assert module.weight.dtype == torch.float32
         self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
 
+    def test_modules_to_not_convert(self):
+        init_kwargs = self.get_dummy_model_init_kwargs()
+
+        quantization_config_kwargs = self.get_dummy_init_kwargs()
+        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
+        quantization_config = QuantoConfig(**quantization_config_kwargs)
+
+        init_kwargs.update({"quantization_config": quantization_config})
+
+        model = self.model_cls.from_pretrained(**init_kwargs)
+        model.to("cuda")
+
+        for name, module in model.named_modules():
+            if name in self.modules_to_not_convert:
+                assert not isinstance(module, QLinear)
+
     def test_dtype_assignment(self):
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
         assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
@@ -99,12 +116,35 @@ def test_dtype_assignment(self):
         # This should work
         model.to("cuda")
 
+    def test_serialization(self):
+        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+        inputs = self.get_dummy_inputs()
+
+        model.to(torch_device)
+        with torch.no_grad():
+            model_output = model(**inputs)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            saved_model = self.model_cls.from_pretrained(
+                tmp_dir,
+                torch_dtype=torch.bfloat16,
+            )
+
+        saved_model.to(torch_device)
+        with torch.no_grad():
+            saved_model_output = saved_model(**inputs)
+
+        max_diff = torch.abs(model_output - saved_model_output).max()
+        assert max_diff < 1e-5
+
 
 class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
     model_id = "hf-internal-testing/tiny-flux-transformer"
     model_cls = FluxTransformer2DModel
     torch_dtype = torch.bfloat16
     keep_in_fp32_module = "proj_out"
+    modules_to_not_convert = ["proj_out"]
 
     def get_dummy_inputs(self):
         return {
@@ -130,14 +170,21 @@ def get_dummy_inputs(self):
         }
 
 
-class FluxTransformerFloat8(FluxTransformerQuantoMixin, unittest.TestCase):
+class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_use_in_gb = 10
 
     def get_dummy_init_kwargs(self):
         return {"weights": "float8"}
 
 
-class FluxTransformerInt8(FluxTransformerQuantoMixin, unittest.TestCase):
+class FluxTransformerFloat8WeightsAndActivationTest(FluxTransformerQuantoMixin, unittest.TestCase):
+    expected_memory_use_in_gb = 10
+
+    def get_dummy_init_kwargs(self):
+        return {"weights": "float8", "activations": "float8"}
+
+
+class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_use_in_gb = 10
 
     def get_dummy_init_kwargs(self):
@@ -157,20 +204,42 @@ def test_torch_compile(self):
         with torch.no_grad():
             compiled_model_output = compiled_model(**inputs).sample
 
-        max_diff = numpy_cosine_similarity_distance(
-            model_output.cpu().flatten(), compiled_model_output.cpu().flatten()
-        )
+        max_diff = torch.abs(model_output - compiled_model_output).max()
+        assert max_diff < 1e-4
+
+
+class FluxTransformerInt8WeightsAndActivationTest(FluxTransformerQuantoMixin, unittest.TestCase):
+    expected_memory_use_in_gb = 10
+
+    def get_dummy_init_kwargs(self):
+        return {"weights": "int8", "activations": "int8"}
+
+    def test_torch_compile(self):
+        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        inputs = self.get_dummy_inputs()
+
+        model.to(torch_device)
+        with torch.no_grad():
+            model_output = model(**inputs).sample
+        model.to("cpu")
+
+        compiled_model.to(torch_device)
+        with torch.no_grad():
+            compiled_model_output = compiled_model(**inputs).sample
+
+        max_diff = torch.abs(model_output - compiled_model_output).max()
         assert max_diff < 1e-4
 
 
-class FluxTransformerInt4(FluxTransformerQuantoMixin, unittest.TestCase):
+class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_use_in_gb = 6
 
     def get_dummy_init_kwargs(self):
         return {"weights": "int4"}
 
 
-class FluxTransformerInt2(FluxTransformerQuantoMixin, unittest.TestCase):
+class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_use_in_gb = 6
 
     def get_dummy_init_kwargs(self):
