
Commit b78a36c

make style
1 parent 64cbf11 commit b78a36c

4 files changed, +75 −44 lines changed


src/diffusers/quantizers/auto.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from typing import Dict, Optional, Union

 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
-from .quantization_config import BitsAndBytesConfig, TorchAoConfig, QuantizationConfigMixin, QuantizationMethod
+from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig


 AUTO_QUANTIZER_MAPPING = {

src/diffusers/quantizers/quantization_config.py

Lines changed: 35 additions & 20 deletions
@@ -400,13 +400,15 @@ class TorchAoConfig(QuantizationConfigMixin):

     Args:
         quant_type (`str`):
-            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
+            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and
+            `int8_dynamic_activation_int8_weight`.
         modules_to_not_convert (`list`, *optional*, default to `None`):
-            The list of modules to not quantize, useful for quantizing models that explicitly require to have
-            some modules left in their original precision.
+            The list of modules to not quantize, useful for quantizing models that explicitly require to have some
+            modules left in their original precision.
         kwargs (`Dict[str, Any]`, *optional*):
-            The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
-            `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
+            The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
+            supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
+            documentation of arguments can be found in
             https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques

     Example:
@@ -415,15 +417,17 @@ class TorchAoConfig(QuantizationConfigMixin):
     TODO(aryan): update
     quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
     # int4_weight_only quant is only working with *torch.bfloat16* dtype right now
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config
+    )
     ```
     """

     def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
         self.quant_method = QuantizationMethod.TORCHAO
         self.quant_type = quant_type
         self.modules_to_not_convert = modules_to_not_convert
-
+
         # When we load from serialized config, "quant_type_kwargs" will be the key
         if "quant_type_kwargs" in kwargs:
             self.quant_type_kwargs = kwargs["quant_type_kwargs"]
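As a quick orientation for the reformatted docstring above, here is a minimal usage sketch. It is not part of the commit: the import path simply follows this file's location, `group_size=32` is one of the `int4_weight_only` keyword arguments the docstring names, and torchao must be installed for the kwarg validation in `__init__` to run.

```python
# Minimal sketch (not part of the commit): construct the config the docstring describes.
# Requires `pip install torchao`; extra kwargs such as `group_size` are validated against
# the selected quant_type and stored on the config.
from diffusers.quantizers.quantization_config import TorchAoConfig

quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
print(quantization_config.quant_type)              # "int4_weight_only"
print(quantization_config.modules_to_not_convert)  # None unless passed explicitly
```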
@@ -448,7 +452,7 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = Non

         if len(unsupported_kwargs) > 0:
             raise ValueError(
-                f"The quantization method \"{method}\" does not supported the following keyword arguments: "
+                f'The quantization method "{method}" does not supported the following keyword arguments: '
                 f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
             )

@@ -460,16 +464,17 @@ def _get_torchao_quant_type_to_method(cls):

         if is_torchao_available():
             from torchao.quantization import (
-                int4_weight_only,
-                int8_dynamic_activation_int8_weight,
-                int8_dynamic_activation_int4_weight,
-                int8_weight_only,
                 float8_dynamic_activation_float8_weight,
                 float8_static_activation_float8_weight,
                 float8_weight_only,
                 fpx_weight_only,
+                int4_weight_only,
+                int8_dynamic_activation_int4_weight,
+                int8_dynamic_activation_int8_weight,
+                int8_weight_only,
                 uintx_weight_only,
             )
+
             # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
             from torchao.quantization.observer import PerRow, PerTensor

@@ -502,8 +507,10 @@ def _get_torchao_quant_type_to_method(cls):
         def generate_float8dq_types(dtype: torch.dtype):
             name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
             types = {}
-
-            types[f"float8dq_{name}_a8w8"] = partial(float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype)
+
+            types[f"float8dq_{name}_a8w8"] = partial(
+                float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype
+            )
             for activation_granularity_cls in [PerTensor, PerRow]:
                 for weight_granularity_cls in [PerTensor, PerRow]:
                     activation_name = "t" if activation_granularity_cls is PerTensor else "r"
@@ -526,22 +533,22 @@ def generate_float8dq_types(dtype: torch.dtype):
                         weight_dtype=dtype,
                         granularity=(activation_granularity_cls(), weight_granularity_cls()),
                     )
-
+
             return types

         def generate_fpx_quantization_types(bits: int):
             types = {}
-
+
             for ebits in range(1, bits):
                 mbits = bits - ebits - 1
                 types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
                 types[f"fp{bits}_e{ebits}m{mbits}_a16w{bits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
-
+
             non_sign_bits = bits - 1
             default_ebits = (non_sign_bits + 1) // 2
             default_mbits = non_sign_bits - default_ebits
             types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
-
+
             return types

         # TODO(aryan): handle cuda capability and torch 2.2/2.3
@@ -561,11 +568,19 @@ def generate_fpx_quantization_types(bits: int):
             # float8_e5m2 weight + float8 activation (dynamic)
             "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
             "float8dq": float8_dynamic_activation_float8_weight,
-            "float8dq_e5m2": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e5m2, weight_dtype=torch.float8_e5m2),
+            "float8dq_e5m2": partial(
+                float8_dynamic_activation_float8_weight,
+                activation_dtype=torch.float8_e5m2,
+                weight_dtype=torch.float8_e5m2,
+            ),
             "float8_a8w8": float8_dynamic_activation_float8_weight,
             **generate_float8dq_types(torch.float8_e5m2),
             # float8_e4m3 weight + float8 activation (dynamic)
-            "float8dq_e4m3": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn),
+            "float8dq_e4m3": partial(
+                float8_dynamic_activation_float8_weight,
+                activation_dtype=torch.float8_e4m3fn,
+                weight_dtype=torch.float8_e4m3fn,
+            ),
             **generate_float8dq_types(torch.float8_e4m3fn),
             # float8 weight + float8 activation (static)
             "float8_static_activation_float8_weight": float8_static_activation_float8_weight,

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 38 additions & 22 deletions
@@ -19,11 +19,13 @@

 import importlib
 import types
-from packaging import version
 from typing import TYPE_CHECKING, Any, Dict, List, Union

-from ..base import DiffusersQuantizer
+from packaging import version
+
 from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
+from ..base import DiffusersQuantizer
+

 if TYPE_CHECKING:
     from ...models.modeling_utils import ModelMixin
@@ -69,10 +71,12 @@ def __init__(self, quantization_config, **kwargs):

     def validate_environment(self, *args, **kwargs):
         if not is_torchao_available():
-            raise ImportError("Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`")
+            raise ImportError(
+                "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
+            )

         self.offload = False
-
+
         device_map = kwargs.get("device_map", None)
         if isinstance(device_map, dict):
             if "cpu" in device_map.values() or "disk" in device_map.values():
@@ -83,7 +87,7 @@ def validate_environment(self, *args, **kwargs):
                 )
             else:
                 self.offload = True
-
+
         if self.pre_quantized:
             weights_only = kwargs.get("weights_only", None)
             if weights_only:
@@ -96,29 +100,41 @@ def validate_environment(self, *args, **kwargs):

     def update_torch_dtype(self, torch_dtype):
         quant_type = self.quantization_config.quant_type
-
+
         if quant_type.startswith("int") or quant_type.startswith("uint"):
             if torch_dtype is not None and torch_dtype != torch.bfloat16:
                 logger.warning(
                     f"Setting torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
                 )
-
+
         if torch_dtype is None:
             # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
             logger.info(
-                f"Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` "
-                f"to enable model loading in different precisions. Pass your own `torch_dtype` to specify the "
-                f"dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning."
+                "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` "
+                "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the "
+                "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning."
             )
             torch_dtype = torch.bfloat16
-
+
         return torch_dtype

     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
-        supported_dtypes = (torch.int8, torch.float8_e4m3fn, torch.float8_e5m2, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8)
+        supported_dtypes = (
+            torch.int8,
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+            torch.uint1,
+            torch.uint2,
+            torch.uint3,
+            torch.uint4,
+            torch.uint5,
+            torch.uint6,
+            torch.uint7,
+            torch.uint8,
+        )
         if isinstance(target_dtype, supported_dtypes):
             return target_dtype
-
+
         raise ValueError(
             f"You are using `device_map='auto'` on a TorchAO quantized model but a suitable target dtype "
             f"could not be inferred. The supported target_dtypes are: {supported_dtypes}. If you think the "
@@ -161,8 +177,8 @@ def create_quantized_param(
         unexpected_keys: List[str],
     ):
         r"""
-        Each nn.Linear layer that needs to be quantized is processsed here.
-        First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
+        Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor,
+        then we move it to the target device. Finally, we quantize the module.
         """
         from torchao.quantization import quantize_

@@ -187,22 +203,22 @@ def _process_model_before_weight_loading(

         if not isinstance(self.modules_to_not_convert, list):
             self.modules_to_not_convert = [self.modules_to_not_convert]
-
+
         self.modules_to_not_convert.extend(keep_in_fp32_modules)

         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
         if isinstance(device_map, dict) and len(device_map.keys()) > 1:
             keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
             self.modules_to_not_convert.extend(keys_on_cpu)
-
+
         # Purge `None`.
         # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
         # in case of diffusion transformer models. For language models and others alike, `lm_head`
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

         model.config.quantization_config = self.quantization_config
-
+
     def _process_model_after_weight_loading(self, model: "ModelMixin"):
         return model

@@ -213,21 +229,21 @@ def is_serializable(self, safe_serialization=None):
                 "torchao quantized model does not support safe serialization, please set `safe_serialization` to False."
             )
             return False
-
+
         _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
             "0.25.0"
         )
-
+
         if not _is_torchao_serializable:
             logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
-
+
         if self.offload and self.quantization_config.modules_to_not_convert is None:
             logger.warning(
                 "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
                 "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
             )
             return False
-
+
         return _is_torchao_serializable

     @property
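The serializability gate touched above hinges only on the installed `huggingface_hub` version. A standalone sketch (not part of the commit) of that check, using the same `packaging` and `importlib.metadata` calls as the file:

```python
# Standalone sketch (not part of the commit): the huggingface_hub version gate used by
# is_serializable. Serializing torchao-quantized models needs huggingface_hub >= 0.25.0.
import importlib.metadata

from packaging import version

is_new_enough = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse("0.25.0")
print(is_new_enough)
```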

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -82,11 +82,11 @@
     is_sentencepiece_available,
     is_tensorboard_available,
     is_timm_available,
-    is_torchao_available,
     is_torch_available,
     is_torch_npu_available,
     is_torch_version,
     is_torch_xla_available,
+    is_torchao_available,
     is_torchsde_available,
     is_torchvision_available,
     is_transformers_available,
