
Commit 9e0caa7 ("up")

committed
1 parent 269813f
3 files changed: +37 -9 lines
src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py

Lines changed: 7 additions & 4 deletions
@@ -169,13 +169,16 @@ def _process_model_before_weight_loading(
             model,
             modules_to_not_convert=self.modules_to_not_convert,
             quantization_config=self.quantization_config,
-            pre_quantized=self.pre_quantized,
         )
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model, **kwargs):
         return model
 
-    # @property
-    # def is_serializable(self):
-    #     return True
+    @property
+    def is_serializable(self):
+        return False
+
+    @property
+    def is_trainable(self):
+        return False
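
The new `is_serializable` and `is_trainable` properties advertise backend capabilities that callers can check before saving or fine-tuning. A minimal standalone sketch of how such flags are typically consumed; the `ToyNunchakuStyleQuantizer` and `save_quantized_model` names are illustrative only, not part of this commit or the diffusers API:

class ToyNunchakuStyleQuantizer:
    @property
    def is_serializable(self):
        # Mirrors the flag added above: re-saving quantized checkpoints
        # is reported as unsupported by this backend.
        return False

    @property
    def is_trainable(self):
        # Likewise, the quantized modules are treated as inference-only.
        return False


def save_quantized_model(model, quantizer, path):
    # Hypothetical caller: bail out early when the backend cannot serialize.
    if not quantizer.is_serializable:
        raise ValueError("This quantization backend does not support serialization.")
    # ... write the checkpoint to `path` ...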

src/diffusers/quantizers/nunchaku/utils.py

Lines changed: 7 additions & 5 deletions
@@ -5,16 +5,15 @@
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
-
-if is_nunchaku_available():
-    from nunchaku.models.linear import SVDQW4A4Linear
+
 
 
 logger = logging.get_logger(__name__)
 
 
 def _replace_with_nunchaku_linear(
     model,
+    svdq_linear_cls,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
@@ -36,7 +35,7 @@ def _replace_with_nunchaku_linear(
         out_features = module.out_features
 
         if quantization_config.precision in ["int4", "nvfp4"]:
-            model._modules[name] = SVDQW4A4Linear(
+            model._modules[name] = svdq_linear_cls(
                 in_features,
                 out_features,
                 rank=quantization_config.rank,
@@ -62,7 +61,10 @@ def _replace_with_nunchaku_linear(
 
 
 def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
-    model, _ = _replace_with_nunchaku_linear(model, modules_to_not_convert, current_key_name, quantization_config)
+    if is_nunchaku_available():
+        from nunchaku.models.linear import SVDQW4A4Linear
+
+    model, _ = _replace_with_nunchaku_linear(model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config)
 
     has_been_replaced = any(
         isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
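
This refactor passes the replacement linear class into `_replace_with_nunchaku_linear` as an argument, so the optional `nunchaku` import is deferred to the public `replace_with_nunchaku_linear` entry point. A rough standalone sketch of that injection pattern; `ToyQuantLinear` and `replace_linear` are illustrative stand-ins, not the real classes:

import torch.nn as nn


class ToyQuantLinear(nn.Linear):
    """Placeholder for a quantized linear layer such as SVDQW4A4Linear."""


def _replace_linear(model, quant_linear_cls, modules_to_not_convert=()):
    # Recursive helper: it never imports the optional dependency itself,
    # it only uses whatever class it was handed.
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            model._modules[name] = quant_linear_cls(module.in_features, module.out_features)
        else:
            _replace_linear(module, quant_linear_cls, modules_to_not_convert)
    return model


def replace_linear(model, modules_to_not_convert=()):
    # Public entry point: the optional import would live here, and the
    # resolved class is injected into the helper.
    return _replace_linear(model, ToyQuantLinear, modules_to_not_convert)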

src/diffusers/quantizers/quantization_config.py

Lines changed: 23 additions & 0 deletions
@@ -762,3 +762,26 @@ def post_init(self):
         accpeted_precision = ["int4", "nvfp4"]
         if self.precision not in accpeted_precision:
             raise ValueError(f"Only supported precision in {accpeted_precision} but found {self.precision}")
+
+    # Copied from diffusers.quantizers.bitsandbytes.quantization_config.BitsandBytesConfig.to_diff_dict with BitsandBytesConfig->NunchakuConfig
+    def to_diff_dict(self) -> Dict[str, Any]:
+        """
+        Removes all attributes from config which correspond to the default config attributes for better readability and
+        serializes to a Python dictionary.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
+        """
+        config_dict = self.to_dict()
+
+        # get the default config dict
+        default_config_dict = NunchakuConfig().to_dict()
+
+        serializable_config_dict = {}
+
+        # only serialize values that differ from the default config
+        for key, value in config_dict.items():
+            if value != default_config_dict[key]:
+                serializable_config_dict[key] = value
+
+        return serializable_config_dict
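
`to_diff_dict` keeps serialized configs compact by constructing a default instance and emitting only the fields that differ from it. A rough standalone illustration of the same idea, using a made-up `ToyConfig` rather than `NunchakuConfig`:

from dataclasses import dataclass, asdict


@dataclass
class ToyConfig:
    precision: str = "int4"
    rank: int = 32


def to_diff_dict(config: ToyConfig) -> dict:
    # Compare against a default-constructed config and keep only the
    # fields the user actually overrode.
    default = asdict(ToyConfig())
    return {k: v for k, v in asdict(config).items() if v != default[k]}


# Example: to_diff_dict(ToyConfig(rank=64)) -> {"rank": 64}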
