
Commit e9fccb6

update
1 parent 7d9d1dc commit e9fccb6

File tree

1 file changed: +23 −3 lines changed


src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 23 additions & 3 deletions
@@ -139,14 +139,34 @@ def update_torch_dtype(self, torch_dtype):
         return torch_dtype
 
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        quant_type = self.quantization_config.quant_type
+
+        if quant_type.startswith("int8") or quant_type.startswith("int4"):
+            # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
+            return torch.int8
+        elif quant_type == "uintx_weight_only":
+            return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
+        elif quant_type.startswith("uint"):
+            return {
+                1: torch.uint1,
+                2: torch.uint2,
+                3: torch.uint3,
+                4: torch.uint4,
+                5: torch.uint5,
+                6: torch.uint6,
+                7: torch.uint7,
+            }[int(quant_type[4])]
+        elif quant_type.startswith("float") or quant_type.startswith("fp"):
+            return torch.bfloat16
+
         if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
             return target_dtype
 
         # We need one of the supported dtypes to be selected in order for accelerate to determine
-        # the total size of modules/parameters for auto device placement. This method will not be
-        # called when device_map is not "auto".
+        # the total size of modules/parameters for auto device placement.
+        possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"]
         raise ValueError(
-            f"You are using `device_map='auto'` on a TorchAO quantized model but a suitable target dtype "
+            f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype "
             f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the "
             f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
         )
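For illustration, here is a minimal standalone sketch of the dispatch this commit adds, not part of the diff itself: infer_target_dtype is a hypothetical helper name, and the torch.uint1 through torch.uint7 dtypes require a recent PyTorch (2.3+). It shows which storage dtype each TorchAO quant_type string resolves to; the lookup on quant_type[4] works because the sub-byte types are named like "uint4_weight_only", so the character at index 4 is the bit width.

import torch

# Hypothetical standalone mirror of the dispatch added above (illustrative only).
def infer_target_dtype(quant_type, quant_type_kwargs=None):
    quant_type_kwargs = quant_type_kwargs or {}
    if quant_type.startswith("int8") or quant_type.startswith("int4"):
        # int4 weights are packed into torch.int8 storage; torch has no torch.int4
        return torch.int8
    elif quant_type == "uintx_weight_only":
        return quant_type_kwargs.get("dtype", torch.uint8)
    elif quant_type.startswith("uint"):
        # names look like "uint4_weight_only": the character at index 4 is the bit width
        return {
            1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4,
            5: torch.uint5, 6: torch.uint6, 7: torch.uint7,
        }[int(quant_type[4])]
    elif quant_type.startswith("float") or quant_type.startswith("fp"):
        return torch.bfloat16
    raise ValueError(f"cannot infer a storage dtype for {quant_type!r}")

assert infer_target_dtype("int4_weight_only") is torch.int8
assert infer_target_dtype("uint6_weight_only") is torch.uint6
assert infer_target_dtype("float8_weight_only") is torch.bfloat16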

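The practical effect of the change is that accelerate can infer a per-parameter size for any of the automatic placement strategies, not just device_map="auto". A hedged usage sketch of that scenario follows, assuming diffusers' TorchAoConfig integration; the checkpoint ID, the "int8_weight_only" quant_type string, and "balanced" placement support for this model are assumptions, not something the commit shows.

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Illustrative only: checkpoint ID, quant_type, and "balanced" support are assumptions.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
    device_map="balanced",  # accelerate uses the inferred target dtype to size modules
)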