 logger = logging.get_logger(__name__)


-class QuantoQuantizer(DiffusersQuantizer):
+class NunChakuQuantizer(DiffusersQuantizer):
     r"""
-    Diffusers Quantizer for Optimum Quanto
+    Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku)
     """

     use_keep_in_fp32_modules = True
@@ -68,7 +68,7 @@ def validate_environment(self, *args, **kwargs):
         # device_map = kwargs.get("device_map", None)
         # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
         #     raise ValueError(
-        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
+        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the nunchaku backend"
         #     )

     def check_if_quantized_param(
@@ -79,7 +79,6 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        # Quanto imports diffusers internally. This is here to prevent circular imports
         from nunchaku.models.linear import SVDQW4A4Linear

         module, tensor_name = get_module_from_name(model, param_name)
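For context on the hunk above: `check_if_quantized_param` resolves the module that owns a parameter and tests whether it is a quantized layer (here, nunchaku's `SVDQW4A4Linear`). Below is a minimal, self-contained sketch of that pattern; `ToyQuantizedLinear` and the reimplemented `get_module_from_name` are illustrative stand-ins, not the actual diffusers/nunchaku internals, and the real hook also receives `state_dict` and `**kwargs`.

```python
# Minimal sketch of the check_if_quantized_param pattern (illustrative only:
# ToyQuantizedLinear stands in for nunchaku's SVDQW4A4Linear, and this
# get_module_from_name mirrors the diffusers helper in spirit).
import torch.nn as nn


def get_module_from_name(model: nn.Module, param_name: str):
    """Resolve e.g. "1.weight" -> (module owning the param, "weight")."""
    module_path, _, tensor_name = param_name.rpartition(".")
    module = model.get_submodule(module_path) if module_path else model
    return module, tensor_name


class ToyQuantizedLinear(nn.Linear):
    """Stand-in for a quantized layer such as SVDQW4A4Linear."""


def check_if_quantized_param(model: nn.Module, param_name: str) -> bool:
    # A param counts as quantized when its owning module is a quantized
    # layer; the real quantizer also inspects the checkpoint's state_dict.
    module, _ = get_module_from_name(model, param_name)
    return isinstance(module, ToyQuantizedLinear)


model = nn.Sequential(nn.Linear(4, 4), ToyQuantizedLinear(4, 4))
print(check_if_quantized_param(model, "0.weight"))  # False
print(check_if_quantized_param(model, "1.weight"))  # True
```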
@@ -140,7 +139,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":

     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
-            # We force the `dtype` to be bfloat16, this is a requirement from `bitsandbytes`
+            # We force the `dtype` to be bfloat16, this is a requirement from `nunchaku`
             logger.info(
                 "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
                 "requirements of `nunchaku` to enable model loading in 4-bit. "
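The last hunk's dtype-override logic, as a standalone sketch: when no `torch_dtype` is passed, bfloat16 is forced because nunchaku's 4-bit loading requires it. The bare-function form and stdlib logger below are simplifying assumptions, not the actual diffusers code.

```python
# Standalone sketch of update_torch_dtype (simplified from the method above;
# the free-function form and logging.getLogger are illustrative assumptions).
import logging

import torch

logger = logging.getLogger(__name__)


def update_torch_dtype(torch_dtype):
    if torch_dtype is None:
        # Mirror the diff: force bfloat16, a requirement from `nunchaku`.
        logger.info(
            "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
            "requirements of `nunchaku` to enable model loading in 4-bit.",
            torch_dtype,
        )
        torch_dtype = torch.bfloat16
    return torch_dtype


assert update_torch_dtype(None) is torch.bfloat16
assert update_torch_dtype(torch.float16) is torch.float16
```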