@@ -139,14 +139,34 @@ def update_torch_dtype(self, torch_dtype):
         return torch_dtype
 
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        quant_type = self.quantization_config.quant_type
+
+        if quant_type.startswith("int8") or quant_type.startswith("int4"):
+            # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
+            return torch.int8
+        elif quant_type == "uintx_weight_only":
+            return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
+        elif quant_type.startswith("uint"):
+            return {
+                1: torch.uint1,
+                2: torch.uint2,
+                3: torch.uint3,
+                4: torch.uint4,
+                5: torch.uint5,
+                6: torch.uint6,
+                7: torch.uint7,
+            }[int(quant_type[4])]
+        elif quant_type.startswith("float") or quant_type.startswith("fp"):
+            return torch.bfloat16
+
         if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
             return target_dtype
 
         # We need one of the supported dtypes to be selected in order for accelerate to determine
-        # the total size of modules/parameters for auto device placement. This method will not be
-        # called when device_map is not "auto".
+        # the total size of modules/parameters for auto device placement.
+        possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"]
         raise ValueError(
-            f"You are using `device_map='auto'` on a TorchAO quantized model but a suitable target dtype "
+            f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype "
            f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the "
            f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
        )
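
For context, a minimal sketch of how this code path is exercised. The checkpoint and quant type below are illustrative, not taken from this commit; the point is that any `device_map` from the list above makes accelerate ask the quantizer for a concrete per-parameter dtype, which the new mapping in `adjust_target_dtype` supplies (`torch.int8` for an `int8_weight_only` config):

```python
# Sketch, assuming diffusers with TorchAO support and torchao installed.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quant_config = TorchAoConfig("int8_weight_only")

# With device_map="auto" (or "balanced", "balanced_low_0", "sequential"),
# accelerate estimates module sizes for placement using the dtype returned
# by adjust_target_dtype; for this quant type that is torch.int8.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```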