Skip to content

Commit bd57ff1

Browse files
Reworked based on review comments
1 parent a16d669 commit bd57ff1

File tree

8 files changed

+17
-10
lines changed

8 files changed

+17
-10
lines changed

keras/api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@
6060
from keras.src.ops.function import Function as Function
6161
from keras.src.ops.operation import Operation as Operation
6262
from keras.src.optimizers.optimizer import Optimizer as Optimizer
63+
from keras.src.quantizers.gptqconfig import GPTQConfig as GPTQConfig
6364
from keras.src.quantizers.quantizers import Quantizer as Quantizer
6465
from keras.src.regularizers.regularizers import Regularizer as Regularizer
6566
from keras.src.version import __version__ as __version__

keras/api/_tf_keras/keras/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,7 @@
5858
from keras.src.ops.function import Function as Function
5959
from keras.src.ops.operation import Operation as Operation
6060
from keras.src.optimizers.optimizer import Optimizer as Optimizer
61+
from keras.src.quantizers.gptqconfig import GPTQConfig as GPTQConfig
6162
from keras.src.quantizers.quantizers import Quantizer as Quantizer
6263
from keras.src.regularizers.regularizers import Regularizer as Regularizer
6364
from keras.src.version import __version__ as __version__

keras/api/_tf_keras/keras/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from keras.src.quantizers import deserialize as deserialize
88
from keras.src.quantizers import get as get
99
from keras.src.quantizers import serialize as serialize
10+
from keras.src.quantizers.gptqconfig import GPTQConfig as GPTQConfig
1011
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
1112
from keras.src.quantizers.quantizers import Quantizer as Quantizer
1213
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/api/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from keras.src.quantizers import deserialize as deserialize
88
from keras.src.quantizers import get as get
99
from keras.src.quantizers import serialize as serialize
10+
from keras.src.quantizers.gptqconfig import GPTQConfig as GPTQConfig
1011
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
1112
from keras.src.quantizers.quantizers import Quantizer as Quantizer
1213
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/src/quantizers/gptq.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,7 @@
11
from keras.src import ops
22
from keras.src.layers import Dense
33
from keras.src.layers import EinsumDense
4-
5-
from .gptqquant import quantize
4+
from keras.src.quantizers.gptqquant import dequantize
65

76

87
class GPTQ:
@@ -201,7 +200,7 @@ def quantize_and_correct_block(
201200
)
202201

203202
# Quantize the current weight column
204-
q = quantize(
203+
q = dequantize(
205204
ops.expand_dims(w, 1),
206205
self.quantizer.scale,
207206
self.quantizer.zero,

keras/src/quantizers/gptqconfig.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
from absl import logging
22

3-
from .gptqutils import quantize_model
3+
from keras.src.api_export import keras_export
4+
from keras.src.quantizers.gptqutils import quantize_model
45

56

7+
@keras_export(["keras.GPTQConfig", "keras.quantizers.GPTQConfig"])
68
class GPTQConfig:
79
"""
810
Configuration class for the GPTQ (Generative Pre-trained Transformer

keras/src/quantizers/gptqquant.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from keras.src import ops
22

33

4-
def quantize(x, scale, zero, maxq):
4+
def dequantize(x, scale, zero, maxq):
55
"""The core quantization function with correct broadcasting."""
66
# Ensure scale is broadcastable with the input tensor x
77
if scale.shape != x.shape:
@@ -12,9 +12,12 @@ def quantize(x, scale, zero, maxq):
1212
zero = ops.broadcast_to(zero, x.shape)
1313

1414
scale = ops.where(ops.equal(scale, 0), 1e-8, scale)
15-
q = ops.round(x / scale) + zero
15+
quantized_x = ops.divide(x, scale)
16+
quantized_x = ops.round(quantized_x)
17+
q = ops.add(quantized_x, zero)
1618
q = ops.clip(q, 0, maxq)
17-
return scale * (q - zero)
19+
dequantized_x = ops.subtract(q, zero)
20+
return ops.multiply(scale, dequantized_x)
1821

1922

2023
class GPTQQuant:

keras/src/quantizers/gptqutils.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,8 @@
1111
from keras.src.layers import Dense
1212
from keras.src.layers import EinsumDense
1313
from keras.src.layers import Embedding
14-
15-
from .gptq import GPTQ
16-
from .gptqquant import GPTQQuant
14+
from keras.src.quantizers.gptq import GPTQ
15+
from keras.src.quantizers.gptqquant import GPTQQuant
1716

1817

1918
def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):

0 commit comments

Comments
 (0)