Commit cbb0da4

committed: update
1 parent 355509e commit cbb0da4

4 files changed: +10 -9 lines changed


src/diffusers/models/model_loading_utils.py

Lines changed: 3 additions & 5 deletions
@@ -176,11 +176,9 @@ def load_model_dict_into_meta(
     hf_quantizer=None,
     keep_in_fp32_modules=None,
 ) -> List[str]:
-    if hf_quantizer is None:
-        device = device or torch.device("cpu")
+    device = device or torch.device("cpu")
     dtype = dtype or torch.float32
     is_quantized = hf_quantizer is not None
-    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES

     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
     empty_state_dict = model.state_dict()
@@ -213,12 +211,12 @@ def load_model_dict_into_meta(
         # bnb params are flattened.
         if empty_state_dict[param_name].shape != param.shape:
             if (
-                is_quant_method_bnb
+                is_quantized
                 and hf_quantizer.pre_quantized
                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
             ):
                 hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
-            elif not is_quant_method_bnb:
+            else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
                     f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -835,7 +835,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 if hf_quantizer is None:
                     param_device = "cpu"
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
-                elif is_quant_method_bnb:
+                else:
                     param_device = torch.cuda.current_device()
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
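
The corresponding one-line change here drops the same bitsandbytes-only flag: with any quantizer active, parameters are now placed on the current CUDA device. A standalone sketch of just that device-selection decision follows; `pick_param_device` is an invented helper, not part of `from_pretrained`.

```python
# Invented helper illustrating the device selection after this change.
import torch


def pick_param_device(hf_quantizer):
    if hf_quantizer is None:
        return "cpu"
    # Any active quantizer (bitsandbytes, torchao, ...) now loads onto the current GPU.
    return torch.cuda.current_device()


print(pick_param_device(None))  # "cpu"
# With a quantizer object on a CUDA machine this would return the device index, e.g. 0:
# print(pick_param_device(some_quantizer))
```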

src/diffusers/quantizers/auto.py

Lines changed: 2 additions & 0 deletions
@@ -19,12 +19,14 @@
 from typing import Dict, Optional, Union

 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
+from .torchao import TorchAoHfQuantizer
 from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig


 AUTO_QUANTIZER_MAPPING = {
     "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
     "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
+    "torchao": TorchAoHfQuantizer,
 }

 AUTO_QUANTIZATION_CONFIG_MAPPING = {
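
The two added lines register torchao under the "torchao" key so it can be resolved like the existing bitsandbytes entries. A small sketch of how such a registry lookup might be consumed, assuming a diffusers checkout that includes this commit; `resolve_quantizer` is an invented helper, while the mapping and class names come from the diff.

```python
# Illustrative lookup against the registry extended in this commit.
# Assumes a diffusers install that contains the change; `resolve_quantizer` is invented.
from diffusers.quantizers.auto import AUTO_QUANTIZER_MAPPING


def resolve_quantizer(quant_method: str):
    if quant_method not in AUTO_QUANTIZER_MAPPING:
        raise ValueError(f"Unknown quantization method {quant_method!r}; known: {sorted(AUTO_QUANTIZER_MAPPING)}")
    return AUTO_QUANTIZER_MAPPING[quant_method]


print(resolve_quantizer("torchao"))            # TorchAoHfQuantizer
print(resolve_quantizer("bitsandbytes_4bit"))  # BnB4BitDiffusersQuantizer
```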

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 4 additions & 3 deletions
@@ -101,15 +101,16 @@ def validate_environment(self, *args, **kwargs):
     def update_torch_dtype(self, torch_dtype):
         quant_type = self.quantization_config.quant_type

-        if quant_type.startswith("int") or quant_type.startswith("uint"):
+        if quant_type.startswith("int"):
             if torch_dtype is not None and torch_dtype != torch.bfloat16:
                 logger.warning(
-                    f"Setting torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
+                    f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
+                    f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
                 )

         if torch_dtype is None:
             # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
-            logger.info(
+            logger.warning(
                 "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` "
                 "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the "
                 "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning."
