Skip to content

Commit 8cef1c7

Browse files
Fix comments and add additional test file
1 parent 248570d commit 8cef1c7

File tree

6 files changed

+226
-73
lines changed

6 files changed

+226
-73
lines changed

keras/src/models/model_test.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,7 +1298,7 @@ def _run_gptq_test_on_dataset(self, dataset):
12981298
test on a given dataset."""
12991299

13001300
model = _get_model_with_dense_attention()
1301-
1301+
rng = np.random.default_rng(seed=42)
13021302
# 1. Common setup
13031303
NUM_SAMPLES = 16
13041304
SEQUENCE_LENGTH = 128
@@ -1328,25 +1328,26 @@ def _run_gptq_test_on_dataset(self, dataset):
13281328
wbits=W_BITS,
13291329
nsamples=NUM_SAMPLES,
13301330
seqlen=SEQUENCE_LENGTH,
1331-
groupsize=GROUP_SIZE,
1331+
group_size=GROUP_SIZE,
13321332
)
13331333
model.quantize("gptq", quant_config=gptq_config)
13341334

13351335
# 4. Assertions and verification
13361336
quantized_weights = target_layer.kernel.numpy()
13371337

13381338
# Assert that the weights have been changed
1339-
self.assertFalse(
1340-
np.allclose(original_weights, quantized_weights),
1341-
"Weights were not changed by the GPTQ process for "
1339+
self.assertNotAllClose(
1340+
original_weights,
1341+
quantized_weights,
1342+
msg="Weights were not changed by the GPTQ process for "
13421343
"dataset: {dataset}",
13431344
)
13441345

13451346
# Verify the quantized model can still make a prediction
1346-
dummy_input = np.random.randint(
1347-
0, VOCAB_SIZE, size=(1, SEQUENCE_LENGTH)
1347+
dummy_sample = rng.integers(
1348+
low=0, high=VOCAB_SIZE, size=(1, SEQUENCE_LENGTH)
13481349
)
1349-
_ = model.predict(dummy_input)
1350+
_ = model.predict(dummy_sample)
13501351

13511352
def test_quantize_gptq_on_different_datasets(self):
13521353
"""Tests GPTQ with various dataset types (string list, generator)."""

keras/src/quantizers/gptq.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
class GPTQ:
88
def __init__(self, layer):
99
self.original_layer = layer
10-
self.kernel_shape = layer.kernel.shape
1110
self.nsamples = 0
1211
self.quantizer = None
1312

@@ -16,13 +15,15 @@ def __init__(self, layer):
1615
isinstance(layer, EinsumDense) and layer.kernel.ndim == 2
1716
):
1817
# For a standard Dense layer, the dimensions are straightforward.
18+
self.kernel_shape = layer.kernel.shape
1919
self.rows = self.kernel_shape[0] # Input features
2020
self.columns = self.kernel_shape[1] # Output features
2121
self.layer = layer # The layer itself can be used directly.
2222

2323
# Handle 3D EinsumDense layers (typically from attention blocks).
2424
elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
2525
# For EinsumDense, we determine the effective 2D dimensions.
26+
self.kernel_shape = layer.kernel.shape
2627
shape = list(self.kernel_shape)
2728
try:
2829
d_model_dim_index = shape.index(max(shape))
@@ -53,10 +54,7 @@ def __init__(self, layer):
5354

5455
else:
5556
# Raise an error if the layer is not supported.
56-
raise TypeError(
57-
f"Unsupported layer type or kernel shape for GPTQ: "
58-
f"{type(layer)} with kernel ndim {layer.kernel.ndim}"
59-
)
57+
raise TypeError(f"Unsupported layer type for GPTQ: {type(layer)}")
6058
self.H = ops.zeros((self.rows, self.rows), dtype="float32")
6159

6260
def update_hessian_with_batch(self, inp):
@@ -81,6 +79,17 @@ def update_hessian_with_batch(self, inp):
8179
not match the dimensions of the pre-initialized Hessian matrix
8280
`self.H`.
8381
"""
82+
if inp is None:
83+
raise ValueError("Input tensor 'inp' cannot be None.")
84+
85+
if len(inp.shape) < 2:
86+
raise ValueError(
87+
f"Input tensor 'inp' must have a rank of at least 2 "
88+
f"(e.g., [batch, features]), but got rank {len(inp.shape)}."
89+
)
90+
if ops.size(inp) == 0:
91+
raise ValueError("Input tensor 'inp' cannot be empty.")
92+
8493
if len(inp.shape) > 2:
8594
inp = ops.reshape(inp, (-1, inp.shape[-1]))
8695
inp = ops.cast(inp, "float32")
@@ -103,7 +112,7 @@ def update_hessian_with_batch(self, inp):
103112
self.nsamples += inp.shape[0]
104113

105114
def quantize_and_correct_block(
106-
self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False
115+
self, blocksize=128, percdamp=0.01, group_size=-1, actorder=False
107116
):
108117
"""
109118
Performs GPTQ quantization and correction on the layer's weights.
@@ -143,7 +152,7 @@ def quantize_and_correct_block(
143152
percdamp (float, optional): The percentage of damping to add to the
144153
Hessian's diagonal. A value of 0.01 is recommended.
145154
Defaults to 0.01.
146-
groupsize (int, optional): The number of weights that share the
155+
group_size (int, optional): The number of weights that share the
147156
same quantization parameters (scale and zero-point).
148157
A value of -1 indicates per-channel quantization.
149158
actorder (bool, optional): If True, reorders weight columns based
@@ -189,10 +198,10 @@ def quantize_and_correct_block(
189198
w = W1[:, i]
190199
d = Hinv1[i, i]
191200

192-
if groupsize != -1:
193-
if (i1 + i) % groupsize == 0:
201+
if group_size != -1:
202+
if (i1 + i) % group_size == 0:
194203
self.quantizer.find_params(
195-
W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
204+
W[:, (i1 + i) : (i1 + i + group_size)], weight=True
196205
)
197206
else:
198207
self.quantizer.find_params(

keras/src/quantizers/gptq_config.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,45 +6,45 @@
66

77
@keras_export(["keras.GPTQConfig", "keras.quantizers.GPTQConfig"])
88
class GPTQConfig:
9-
"""
10-
Configuration class for the GPTQ (Generative Pre-trained Transformer
11-
Quantization) algorithm.
9+
"""Configuration class for the GPTQ algorithm.
1210
1311
This class holds all the parameters needed to apply the GPTQ method
14-
to a model. Its attributes are based on the original command-line
15-
arguments from the research repository's `opt.py` script.
12+
to a model.
1613
1714
Args:
18-
dataset (str): Path to the calibration dataset.
19-
wbits (int, optional): The number of bits to quantize the weights to.
15+
dataset: The calibration dataset. It can be an iterable that yields
16+
strings or pre-tokenized numerical tensors (e.g., a list of
17+
strings, a generator, or a NumPy array). This data is used to
18+
analyze the model's activations.
19+
tokenizer: A `keras_nlp.Tokenizer` instance (or a similar callable)
20+
that is used to process the `dataset` if it contains strings.
21+
wbits (int, optional): The number of bits to quantize weights to.
2022
Defaults to 4.
2123
nsamples (int, optional): The number of calibration data samples to
22-
use. Defaults to 128.
23-
seqlen (int, optional): The sequence length to use for calibration.
24-
Defaults to 512.
25-
percdamp (float, optional): The percentage of Hessian damping to use.
26-
Defaults to 0.01.
27-
groupsize (int, optional): The size of the group of weights to
28-
quantize together.A groupsize of
29-
-1 means quantization is done per-column.
30-
Defaults to 128.
31-
symmetric (bool, optional): If True, uses symmetric quantization.
32-
If False,uses asymmetric quantization.
33-
Defaults to False.
34-
act_order (bool, optional): If True, quantizes columns in order of
35-
decreasing activation size.
36-
Defaults to False.
24+
use from the dataset. Defaults to 128.
25+
seqlen (int, optional): The sequence length to use for each calibration
26+
sample. Defaults to 512.
27+
percdamp (float, optional): The percentage of Hessian damping to use for
28+
stabilization during inverse calculation. Defaults to 0.01.
29+
group_size (int, optional): The size of weight groups to quantize
30+
together. A `group_size` of -1 indicates per-channel quantization.
31+
Defaults to 128.
32+
symmetric (bool, optional): If `True`, uses symmetric quantization.
33+
If `False`, uses asymmetric quantization. Defaults to `False`.
34+
act_order (bool, optional): If `True`, reorders weight columns based on
35+
activation magnitude, which can improve quantization accuracy.
36+
Defaults to `False`.
3737
"""
3838

3939
def __init__(
4040
self,
4141
dataset,
42-
tokenizer: str,
42+
tokenizer,
4343
wbits: int = 4,
4444
nsamples: int = 128,
4545
seqlen: int = 512,
4646
percdamp: float = 0.01,
47-
groupsize: int = 128,
47+
group_size: int = 128,
4848
symmetric: bool = False,
4949
act_order: bool = False,
5050
):
@@ -54,7 +54,7 @@ def __init__(
5454
self.seqlen = seqlen
5555
self.percdamp = percdamp
5656
self.wbits = wbits
57-
self.groupsize = groupsize
57+
self.group_size = group_size
5858
self.symmetric = symmetric
5959
self.act_order = act_order
6060
self.quantization_method = "gptq"

keras/src/quantizers/gptq_test.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import numpy as np
2+
import pytest
3+
4+
from keras.src import layers
5+
from keras.src import ops
6+
from keras.src import testing
7+
from keras.src.quantizers.gptq import GPTQ
8+
from keras.src.quantizers.gptqquant import GPTQQuant
9+
10+
11+
def _get_mock_layer(layer_type, kernel_shape, rng):
12+
if layer_type == "Dense":
13+
layer = layers.Dense(units=kernel_shape[1])
14+
layer.build(input_shape=(None, kernel_shape[0]))
15+
elif layer_type == "EinsumDense":
16+
output_shape = (kernel_shape[1], kernel_shape[2])
17+
layer = layers.EinsumDense(
18+
equation="...h,hio->...io", output_shape=output_shape
19+
)
20+
dummy_input = rng.standard_normal(size=(1, 1, kernel_shape[0]))
21+
layer(dummy_input)
22+
layer.kernel.assign(
23+
rng.standard_normal(size=kernel_shape).astype("float32")
24+
)
25+
else:
26+
layer = layers.Layer()
27+
return layer
28+
29+
30+
@pytest.mark.requires_trainable_backend
31+
class GPTQTest(testing.TestCase):
32+
def test_initialization_with_dense_layer(self):
33+
rng = np.random.default_rng(seed=42)
34+
35+
mock_layer = _get_mock_layer("Dense", kernel_shape=(64, 128), rng=rng)
36+
37+
gptq_instance = GPTQ(mock_layer)
38+
self.assertEqual(gptq_instance.rows, 64)
39+
self.assertEqual(gptq_instance.columns, 128)
40+
self.assertEqual(gptq_instance.H.shape, (64, 64))
41+
42+
def test_initialization_with_einsumdense_3d(self):
43+
rng = np.random.default_rng(seed=42)
44+
mock_layer = _get_mock_layer(
45+
"EinsumDense", kernel_shape=(64, 4, 32), rng=rng
46+
)
47+
gptq_instance = GPTQ(mock_layer)
48+
self.assertEqual(gptq_instance.rows, 64)
49+
self.assertEqual(gptq_instance.columns, 4 * 32)
50+
self.assertEqual(gptq_instance.H.shape, (64, 64))
51+
52+
def test_update_hessian(self):
53+
rng = np.random.default_rng(seed=42)
54+
mock_layer = _get_mock_layer("Dense", kernel_shape=(16, 32), rng=rng)
55+
gptq_instance = GPTQ(mock_layer)
56+
batch1 = rng.standard_normal(size=(8, 16)).astype("float32")
57+
gptq_instance.update_hessian_with_batch(batch1)
58+
self.assertEqual(gptq_instance.nsamples, 8)
59+
H1 = np.copy(ops.convert_to_numpy(gptq_instance.H))
60+
batch2 = rng.standard_normal(size=(4, 16)).astype("float32")
61+
gptq_instance.update_hessian_with_batch(batch2)
62+
self.assertEqual(gptq_instance.nsamples, 12)
63+
H2 = np.copy(ops.convert_to_numpy(gptq_instance.H))
64+
self.assertFalse(np.allclose(H1, H2))
65+
66+
def test_full_quantization_process(self):
67+
rng = np.random.default_rng(seed=42)
68+
mock_layer = _get_mock_layer("Dense", kernel_shape=(16, 32), rng=rng)
69+
original_weights = np.copy(ops.convert_to_numpy(mock_layer.kernel))
70+
71+
gptq_instance = GPTQ(mock_layer)
72+
gptq_instance.quantizer = GPTQQuant()
73+
gptq_instance.quantizer.configure(wbits=4, symmetric=False)
74+
75+
calibration_data = rng.standard_normal(size=(128, 16)).astype("float32")
76+
gptq_instance.update_hessian_with_batch(calibration_data)
77+
gptq_instance.quantize_and_correct_block()
78+
79+
quantized_weights = ops.convert_to_numpy(mock_layer.kernel)
80+
self.assertFalse(np.allclose(original_weights, quantized_weights))
81+
82+
gptq_instance.free()
83+
self.assertIsNone(gptq_instance.H)
84+
85+
def test_unsupported_layer_error(self):
86+
rng = np.random.default_rng(seed=42)
87+
unsupported_layer = _get_mock_layer(
88+
"Unsupported", kernel_shape=None, rng=rng
89+
)
90+
with self.assertRaisesRegex(TypeError, "Unsupported layer type"):
91+
GPTQ(unsupported_layer)
92+
93+
def test_update_hessian_invalid_input(self):
94+
rng = np.random.default_rng(seed=42)
95+
mock_layer = _get_mock_layer("Dense", kernel_shape=(16, 32), rng=rng)
96+
gptq_instance = GPTQ(mock_layer)
97+
with self.assertRaisesRegex(ValueError, "cannot be None"):
98+
gptq_instance.update_hessian_with_batch(None)
99+
with self.assertRaisesRegex(ValueError, "cannot be empty"):
100+
gptq_instance.update_hessian_with_batch(np.empty((0, 16)))
101+
with self.assertRaisesRegex(ValueError, "match input features"):
102+
bad_input = rng.standard_normal(size=(8, 99))
103+
gptq_instance.update_hessian_with_batch(bad_input)

0 commit comments

Comments
 (0)