
Commit 5d08150
Commit message: up
1 parent: 9e0caa7

3 files changed: +13 −3 lines changed

src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py

Lines changed: 5 additions & 0 deletions
@@ -164,6 +164,11 @@ def _process_model_before_weight_loading(
             self.modules_to_not_convert = [self.modules_to_not_convert]
 
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        # Purge `None`.
+        # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+        # in case of diffusion transformer models. For language models and others alike, `lm_head`
+        # and tied modules are usually kept in FP32.
+        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
         model = replace_with_nunchaku_linear(
             model,
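
The added filter matters because `modules_to_not_convert` can pick up `None` entries (for example via `keep_in_fp32_modules`), and the substring checks run later over module names cannot handle `None`. A minimal standalone sketch of the behaviour, with illustrative values rather than real model state:

# Standalone sketch of the `None` purge added above; the values are illustrative only.
modules_to_not_convert = ["proj_out"]
keep_in_fp32_modules = [None]  # e.g. no module was flagged to stay in FP32 for this model

modules_to_not_convert.extend(keep_in_fp32_modules)

# Without the purge, a later membership check such as
#   any(key in module_name for key in modules_to_not_convert)
# raises "TypeError: 'in <string>' requires string as left operand, not NoneType".
modules_to_not_convert = [module for module in modules_to_not_convert if module is not None]

print(modules_to_not_convert)  # ['proj_out']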

src/diffusers/quantizers/nunchaku/utils.py

Lines changed: 5 additions & 3 deletions
@@ -5,7 +5,6 @@
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
-
 
 
 logger = logging.get_logger(__name__)

@@ -40,7 +39,7 @@ def _replace_with_nunchaku_linear(
                     out_features,
                     rank=quantization_config.rank,
                     bias=module.bias is not None,
-                    dtype=model.dtype,
+                    torch_dtype=module.weight.dtype,
                 )
                 has_been_replaced = True
                 # Store the module class in case we need to transpose the weight later

@@ -50,6 +49,7 @@ def _replace_with_nunchaku_linear(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_nunchaku_linear(
                 module,
+                svdq_linear_cls,
                 modules_to_not_convert,
                 current_key_name,
                 quantization_config,

@@ -64,7 +64,9 @@ def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key
     if is_nunchaku_available():
         from nunchaku.models.linear import SVDQW4A4Linear
 
-    model, _ = _replace_with_nunchaku_linear(model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config)
+    model, _ = _replace_with_nunchaku_linear(
+        model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
+    )
 
     has_been_replaced = any(
         isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
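
Two behavioural fixes land in this file: the replacement layer's dtype now comes from the wrapped module's own weight (`torch_dtype=module.weight.dtype`) instead of `model.dtype`, and the recursive call forwards `svdq_linear_cls`, so nested submodules receive arguments in the same order as the top-level call. The sketch below illustrates only the recursion pattern under simplified assumptions; the module construction is a stand-in, not the diffusers implementation:

# Simplified sketch of the recursive-replacement pattern; `linear_cls` is any callable
# that builds the replacement layer (in the real code, nunchaku's SVDQW4A4Linear).
import torch.nn as nn

def _replace_linear_sketch(model, linear_cls, modules_to_not_convert=(), prefix=""):
    """Recursively swap nn.Linear children for `linear_cls` (simplified stand-in)."""
    has_been_replaced = False
    for name, module in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(module, nn.Linear) and not any(key in full_name for key in modules_to_not_convert):
            # The real code also passes rank, bias, and torch_dtype=module.weight.dtype here.
            setattr(model, name, linear_cls(module.in_features, module.out_features))
            has_been_replaced = True
        elif len(list(module.children())) > 0:
            # The fix made above: `linear_cls` must be forwarded, otherwise every
            # positional argument in the nested call shifts by one slot.
            _, child_replaced = _replace_linear_sketch(module, linear_cls, modules_to_not_convert, full_name)
            has_been_replaced = has_been_replaced or child_replaced
    return model, has_been_replaced

In the real path `linear_cls` is nunchaku's `SVDQW4A4Linear`, whose constructor takes the extra quantization arguments shown in the diff.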

src/diffusers/quantizers/quantization_config.py

Lines changed: 3 additions & 0 deletions
@@ -750,6 +750,7 @@ def __init__(
     ):
         self.quant_method = QuantizationMethod.NUNCHAKU
         self.precision = precision
+        self.rank = rank
         self.group_size = self.group_size_map[precision]
         self.modules_to_not_convert = modules_to_not_convert
 

@@ -763,6 +764,8 @@ def post_init(self):
         if self.precision not in accpeted_precision:
             raise ValueError(f"Only supported precision in {accpeted_precision} but found {self.precision}")
 
+        # TODO: should there be a check for rank?
+
     # Copied from diffusers.quantizers.bitsandbytes.quantization_config.BitsandBytesConfig.to_diff_dict with BitsandBytesConfig->NunchakuConfig
     def to_diff_dict(self) -> Dict[str, Any]:
         """
