
Commit 7022169

Commit message: up

1 parent 8e1ea00

1 file changed: 4 additions, 5 deletions

src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py

@@ -29,9 +29,9 @@
 logger = logging.get_logger(__name__)
 
 
-class QuantoQuantizer(DiffusersQuantizer):
+class NunChakuQuantizer(DiffusersQuantizer):
     r"""
-    Diffusers Quantizer for Optimum Quanto
+    Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku)
     """
 
     use_keep_in_fp32_modules = True
@@ -68,7 +68,7 @@ def validate_environment(self, *args, **kwargs):
         # device_map = kwargs.get("device_map", None)
         # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
         #     raise ValueError(
-        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
+        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the nunchaku backend"
         #     )
 
     def check_if_quantized_param(
@@ -79,7 +79,6 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        # Quanto imports diffusers internally. This is here to prevent circular imports
         from nunchaku.models.linear import SVDQW4A4Linear
 
         module, tensor_name = get_module_from_name(model, param_name)
@@ -140,7 +139,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
 
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
-            # We force the `dtype` to be bfloat16, this is a requirement from `bitsandbytes`
+            # We force the `dtype` to be bfloat16, this is a requirement from `nunchaku`
             logger.info(
                 "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
                 "requirements of `nunchaku` to enable model loading in 4-bit. "
