8
8
from keras .src .api_export import keras_export
9
9
from keras .src .layers .layer import Layer
10
10
from keras .src .models .variable_mapping import map_saveable_variables
11
+ from keras .src .quantizers .gptq_config import GPTQConfig
11
12
from keras .src .saving import saving_api
12
13
from keras .src .trainers import trainer as base_trainer
13
14
from keras .src .utils import summary_utils
@@ -420,7 +421,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
420
421
** kwargs ,
421
422
)
422
423
423
- def quantize (self , mode , ** kwargs ):
424
+ def quantize (self , mode , config = None , ** kwargs ):
424
425
"""Quantize the weights of the model.
425
426
426
427
Note that the model must be built first before calling this method.
@@ -434,19 +435,22 @@ def quantize(self, mode, **kwargs):
434
435
from keras .src .dtype_policies import QUANTIZATION_MODES
435
436
436
437
if mode == "gptq" :
437
- from keras .src .quantizers .gptq_config import GPTQConfig
438
-
439
- config = kwargs .get ("quant_config" )
440
438
if not isinstance (config , GPTQConfig ):
441
439
raise TypeError (
442
- "When using 'gptq' mode, you must pass a `quant_config ` "
443
- "keyword argument of type `keras.quantizers.GPTQConfig`."
440
+ "When using 'gptq' mode, you must pass a `config ` "
441
+ "argument of type `keras.quantizers.GPTQConfig`."
444
442
)
445
-
446
- # The config object's own quantize method drives the process.
443
+ # The config object's own quantize method drives the process
447
444
config .quantize (self )
448
445
return
449
446
447
+ # For all other modes, verify that a config object was not passed.
448
+ if config is not None :
449
+ raise ValueError (
450
+ f"The `config` argument is only supported for 'gptq' mode, "
451
+ f"but received mode='{ mode } '."
452
+ )
453
+
450
454
type_check = kwargs .pop ("type_check" , True )
451
455
if kwargs :
452
456
raise ValueError (
0 commit comments