Skip to content

Commit c9ff4d1

Browse files
updated the interface as per comments
1 parent 085923e commit c9ff4d1

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

keras/src/models/model.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from keras.src.api_export import keras_export
99
from keras.src.layers.layer import Layer
1010
from keras.src.models.variable_mapping import map_saveable_variables
11+
from keras.src.quantizers.gptq_config import GPTQConfig
1112
from keras.src.saving import saving_api
1213
from keras.src.trainers import trainer as base_trainer
1314
from keras.src.utils import summary_utils
@@ -420,7 +421,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
420421
**kwargs,
421422
)
422423

423-
def quantize(self, mode, **kwargs):
424+
def quantize(self, mode, config=None, **kwargs):
424425
"""Quantize the weights of the model.
425426
426427
Note that the model must be built first before calling this method.
@@ -434,19 +435,22 @@ def quantize(self, mode, **kwargs):
434435
from keras.src.dtype_policies import QUANTIZATION_MODES
435436

436437
if mode == "gptq":
437-
from keras.src.quantizers.gptq_config import GPTQConfig
438-
439-
config = kwargs.get("quant_config")
440438
if not isinstance(config, GPTQConfig):
441439
raise TypeError(
442-
"When using 'gptq' mode, you must pass a `quant_config` "
443-
"keyword argument of type `keras.quantizers.GPTQConfig`."
440+
"When using 'gptq' mode, you must pass a `config` "
441+
"argument of type `keras.quantizers.GPTQConfig`."
444442
)
445-
446-
# The config object's own quantize method drives the process.
443+
# The config object's own quantize method drives the process
447444
config.quantize(self)
448445
return
449446

447+
# For all other modes, verify that a config object was not passed.
448+
if config is not None:
449+
raise ValueError(
450+
f"The `config` argument is only supported for 'gptq' mode, "
451+
f"but received mode='{mode}'."
452+
)
453+
450454
type_check = kwargs.pop("type_check", True)
451455
if kwargs:
452456
raise ValueError(

keras/src/models/model_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1334,7 +1334,7 @@ def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
13341334
final_config = {**base_config, **config_kwargs}
13351335
gptq_config = GPTQConfig(**final_config)
13361336

1337-
model.quantize("gptq", quant_config=gptq_config)
1337+
model.quantize("gptq", config=gptq_config)
13381338

13391339
# Assertions and verification
13401340
quantized_weights = target_layer.kernel

0 commit comments

Comments
 (0)