Skip to content

Commit 364d4da

Browse files
reworke on review comments
1 parent 7ad29c6 commit 364d4da

File tree

10 files changed

+9
-17
lines changed

10 files changed

+9
-17
lines changed

keras/api/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
from keras.src.ops.function import Function as Function
6161
from keras.src.ops.operation import Operation as Operation
6262
from keras.src.optimizers.optimizer import Optimizer as Optimizer
63-
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
6463
from keras.src.quantizers.quantizers import Quantizer as Quantizer
6564
from keras.src.regularizers.regularizers import Regularizer as Regularizer
6665
from keras.src.version import __version__ as __version__

keras/api/_tf_keras/keras/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
from keras.src.ops.function import Function as Function
5959
from keras.src.ops.operation import Operation as Operation
6060
from keras.src.optimizers.optimizer import Optimizer as Optimizer
61-
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
6261
from keras.src.quantizers.quantizers import Quantizer as Quantizer
6362
from keras.src.regularizers.regularizers import Regularizer as Regularizer
6463
from keras.src.version import __version__ as __version__

keras/api/_tf_keras/keras/quantizers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from keras.src.quantizers import deserialize as deserialize
88
from keras.src.quantizers import get as get
99
from keras.src.quantizers import serialize as serialize
10-
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
1110
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
1211
from keras.src.quantizers.quantizers import Quantizer as Quantizer
1312
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/api/quantizers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from keras.src.quantizers import deserialize as deserialize
88
from keras.src.quantizers import get as get
99
from keras.src.quantizers import serialize as serialize
10-
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
1110
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
1211
from keras.src.quantizers.quantizers import Quantizer as Quantizer
1312
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/src/models/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,8 @@ def quantize(self, mode, config=None, **kwargs):
438438
if not isinstance(config, GPTQConfig):
439439
raise TypeError(
440440
"When using 'gptq' mode, you must pass a `config` "
441-
"argument of type `keras.quantizers.GPTQConfig`."
441+
"argument of type "
442+
"`keras.quantizers.gptq_config.GPTQConfig`."
442443
)
443444
# The config object's own quantize method drives the process
444445
config.quantize(self)

keras/src/models/model_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
from absl.testing import parameterized
99

10+
from keras.quantizers import GPTQConfig
1011
from keras.src import backend
1112
from keras.src import layers
1213
from keras.src import losses
@@ -17,7 +18,6 @@
1718
from keras.src.models.functional import Functional
1819
from keras.src.models.model import Model
1920
from keras.src.models.model import model_from_json
20-
from keras.src.quantizers.gptq_config import GPTQConfig
2121

2222

2323
def _get_model():

keras/src/quantizers/gptq_config.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from absl import logging
22

3-
from keras.src.api_export import keras_export
43
from keras.src.quantizers.gptq_core import quantize_model
54

65

7-
@keras_export(["keras.GPTQConfig", "keras.quantizers.GPTQConfig"])
86
class GPTQConfig:
97
"""Configuration class for the GPTQ algorithm.
108
@@ -57,7 +55,6 @@ def __init__(
5755
self.group_size = group_size
5856
self.symmetric = symmetric
5957
self.act_order = act_order
60-
self.quantization_method = "gptq"
6158

6259
def quantize(self, model):
6360
"""

keras/src/quantizers/gptq_core.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from keras.src.layers import EinsumDense
1010
from keras.src.layers import Embedding
1111
from keras.src.quantizers.gptq import GPTQ
12-
from keras.src.quantizers.gptq_quant import GPTQQuant
12+
from keras.src.quantizers.gptq_quant import GPTQQuantization
1313

1414

1515
def get_dataloader(tokenizer, seqlen, dataset, nsamples=128):
@@ -271,7 +271,7 @@ def hook(*args, **kwargs):
271271
inp_reshaped = ops.reshape(layer_inputs, (-1, num_features))
272272
gptq_object.update_hessian_with_batch(inp_reshaped)
273273

274-
quantizer = GPTQQuant(
274+
quantizer = GPTQQuantization(
275275
wbits,
276276
perchannel=True,
277277
symmetric=symmetric,
@@ -331,5 +331,3 @@ def quantize_model(model, config):
331331
config.act_order,
332332
config.wbits,
333333
)
334-
335-
return

keras/src/quantizers/gptq_quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def dequantize(x, scale, zero, maxq):
1515
return ops.multiply(scale, dequantized_x)
1616

1717

18-
class GPTQQuant:
19-
"""Initializes the GPTQQuant state.
18+
class GPTQQuantization:
19+
"""Initializes the GPTQQuantization state.
2020
2121
Args:
2222
shape (int, optional): This argument is currently unused.

keras/src/quantizers/gptq_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from keras.src import ops
66
from keras.src import testing
77
from keras.src.quantizers.gptq import GPTQ
8-
from keras.src.quantizers.gptq_quant import GPTQQuant
8+
from keras.src.quantizers.gptq_quant import GPTQQuantization
99

1010

1111
def _get_mock_layer(layer_type, kernel_shape, rng):
@@ -69,7 +69,7 @@ def test_full_quantization_process(self):
6969
original_weights = np.copy(ops.convert_to_numpy(mock_layer.kernel))
7070

7171
gptq_instance = GPTQ(mock_layer)
72-
gptq_instance.quantizer = GPTQQuant(wbits=4, symmetric=False)
72+
gptq_instance.quantizer = GPTQQuantization(wbits=4, symmetric=False)
7373
calibration_data = rng.standard_normal(size=(128, 16)).astype("float32")
7474
gptq_instance.update_hessian_with_batch(calibration_data)
7575
gptq_instance.quantize_and_correct_block()

0 commit comments

Comments
 (0)