Merged (changes from 7 commits)
7 changes: 4 additions & 3 deletions src/diffusers/models/modeling_utils.py

@@ -719,9 +719,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         if hf_quantizer is not None:
             is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
-            if is_bnb_quantization_method and device_map is not None:
+            is_torchao_quantization_method = hf_quantizer.quantization_config.quant_method.value == "torchao"
+            if (is_bnb_quantization_method or is_torchao_quantization_method) and device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes and torchao models. Support for providing `device_map` as an input will be added in the future."
                 )

             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
@@ -820,7 +821,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     revision=revision,
                     subfolder=subfolder or "",
                 )
-                if hf_quantizer is not None and is_bnb_quantization_method:
+                if hf_quantizer is not None:
                     model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                     logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                     is_sharded = False
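For illustration only, a minimal sketch of the user-facing effect of the two hunks above, assuming diffusers' TorchAoConfig and the FLUX.1-dev transformer (a sharded checkpoint) as the example model; the repo id, subfolder, and quant type below are assumptions for this sketch, not part of the PR:

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8_weight_only")

# With the check relaxed to `hf_quantizer is not None`, sharded checkpoints are merged
# before loading for torchao as well, so this load is expected to behave like the
# existing bitsandbytes path.
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed sharded checkpoint, for illustration only
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Passing an explicit device_map alongside a torchao config is now rejected,
# matching the existing bitsandbytes behaviour.
try:
    FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
except NotImplementedError as err:
    print(err)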
2 changes: 1 addition & 1 deletion src/diffusers/quantizers/torchao/torchao_quantizer.py

@@ -132,7 +132,7 @@ def validate_environment(self, *args, **kwargs):
     def update_torch_dtype(self, torch_dtype):
         quant_type = self.quantization_config.quant_type

-        if quant_type.startswith("int"):
+        if quant_type.startswith("int") or quant_type.startswith("uint"):
             if torch_dtype is not None and torch_dtype != torch.bfloat16:
                 logger.warning(
                     f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
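Likewise, a hedged sketch of what the broadened startswith("uint") check means in practice: uintx quant types are now treated like the int4/int8 types when resolving torch_dtype. The "uint4wo" shorthand and the checkpoint below are assumptions for illustration:

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# "uint4wo" is assumed to be one of the uintx shorthand types the torchao backend accepts.
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint, for illustration only
    subfolder="transformer",
    quantization_config=TorchAoConfig("uint4wo"),
    # No torch_dtype passed: update_torch_dtype is expected to fall back to torch.bfloat16,
    # the same default already used for the int* quant types; passing e.g. torch.float16
    # would instead trigger the warning shown in the diff.
)
print(model.dtype)  # expected: torch.bfloat16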