Commit 66e61c5
updated the review comments
1 parent 7a90148 commit 66e61c5

5 files changed: +187 -96 lines changed

keras/src/models/model.py
Lines changed: 1 addition & 8 deletions

@@ -434,14 +434,7 @@ def quantize(self, mode, **kwargs):
         from keras.src.dtype_policies import QUANTIZATION_MODES
 
         if mode == "gptq":
-            try:
-                from keras.src.quantizers.gptqconfig import GPTQConfig
-            except ImportError:
-                raise ImportError(
-                    "To use 'gptq' mode, please ensure the necessary "
-                    "quantization modules are correctly placed in"
-                    "keras/src/quantizers."
-                )
+            from keras.src.quantizers.gptqconfig import GPTQConfig
 
             config = kwargs.get("quant_config")
             if not isinstance(config, GPTQConfig):
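
With the try/except removed, `GPTQConfig` is now imported unconditionally whenever `quantize` is called with `mode="gptq"`. For orientation, a minimal usage sketch of that entry point follows; it mirrors the test updated below, and the `model` object and stand-in tokenizer are illustrative assumptions, not part of this diff.

    # Illustrative sketch only: `model` is assumed to be an already-built
    # Keras model, and the tokenizer is a stand-in like the one in the test.
    import numpy as np

    from keras.src.quantizers.gptqconfig import GPTQConfig

    VOCAB_SIZE = 1000
    mock_tokenizer = lambda text: np.array([ord(c) % VOCAB_SIZE for c in text])
    mock_tokenizer.tokenize = mock_tokenizer

    gptq_config = GPTQConfig(
        dataset=["some calibration text ..."],
        tokenizer=mock_tokenizer,
        wbits=4,
        nsamples=16,
        seqlen=128,
        groupsize=32,
    )
    model.quantize("gptq", quant_config=gptq_config)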

keras/src/models/model_test.py
Lines changed: 60 additions & 74 deletions

@@ -1,4 +1,3 @@
-import logging
 import os
 import pickle
 from collections import namedtuple
@@ -19,9 +18,6 @@
 from keras.src.models.model import model_from_json
 from keras.src.quantizers.gptqconfig import GPTQConfig
 
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-
 
 def _get_model():
     input_a = Input(shape=(3,), batch_size=2, name="input_a")
@@ -1295,81 +1291,67 @@ def call(self, inputs):
         return model
 
 
-def _run_gptq_test_on_dataset(test_case, dataset):
-    """Helper function to run a full GPTQ quantization
-    test on a given dataset."""
-    model = _get_model_with_dense_attention()
+@pytest.mark.requires_trainable_backend
+class ModelQuantizationTest(testing.TestCase):
+    def _run_gptq_test_on_dataset(self, dataset):
+        """Helper function to run a full GPTQ quantization
+        test on a given dataset."""
 
-    # --- 1. Common Setup ---
-    NUM_SAMPLES = 16
-    SEQUENCE_LENGTH = 128
-    VOCAB_SIZE = 1000
-    W_BITS = 4
-    GROUP_SIZE = 32
+        model = _get_model_with_dense_attention()
 
-    mock_tokenizer = lambda text: np.array([ord(c) % VOCAB_SIZE for c in text])
-    mock_tokenizer.tokenize = mock_tokenizer
+        # 1. Common setup
+        NUM_SAMPLES = 16
+        SEQUENCE_LENGTH = 128
+        VOCAB_SIZE = 1000
+        W_BITS = 4
+        GROUP_SIZE = 32
 
-    # --- 2. Find Target Layer and Get Original Weights ---
-    target_layer = None
-    for layer in model.layers:
-        if hasattr(layer, "ffn") and hasattr(layer.ffn, "layers"):
-            dense_layer_in_ffn = next(
-                (
-                    ffn_layer
-                    for ffn_layer in layer.ffn.layers
-                    if isinstance(ffn_layer, layers.Dense)
-                ),
-                None,
-            )
-            if dense_layer_in_ffn:
-                target_layer = dense_layer_in_ffn
-                break
+        mock_tokenizer = lambda text: np.array(
+            [ord(c) % VOCAB_SIZE for c in text]
+        )
+        mock_tokenizer.tokenize = mock_tokenizer
 
-    test_case.assertIsNotNone(
-        target_layer,
-        "Test setup failed: No Dense layer was found inside an 'ffn' block.",
-    )
-    original_weights = np.copy(target_layer.kernel.numpy())
-
-    # --- 3. Configure and Run Quantization ---
-    gptq_config = GPTQConfig(
-        dataset=dataset,
-        tokenizer=mock_tokenizer,
-        wbits=W_BITS,
-        nsamples=NUM_SAMPLES,
-        seqlen=SEQUENCE_LENGTH,
-        groupsize=GROUP_SIZE,
-    )
-    model.quantize("gptq", quant_config=gptq_config)
+        # 2. Find target layer and get original weights
+        target_layer = model.layers[2].ffn.layers[0]
 
-    # --- 4. Assertions and Verification ---
-    quantized_weights = target_layer.kernel.numpy()
+        self.assertIsNotNone(
+            target_layer,
+            "Test setup failed: No Dense layer was found inside "
+            "an 'ffn' block.",
+        )
+        original_weights = np.copy(target_layer.kernel.numpy())
+
+        # 3. Configure and run quantization
+        gptq_config = GPTQConfig(
+            dataset=dataset,
+            tokenizer=mock_tokenizer,
+            wbits=W_BITS,
+            nsamples=NUM_SAMPLES,
+            seqlen=SEQUENCE_LENGTH,
+            groupsize=GROUP_SIZE,
+        )
+        model.quantize("gptq", quant_config=gptq_config)
 
-    # Assert that the weights have been changed
-    test_case.assertFalse(
-        np.allclose(original_weights, quantized_weights),
-        f"Weights were not changed by the GPTQ process for dataset: {dataset}",
-    )
+        # 4. Assertions and verification
+        quantized_weights = target_layer.kernel.numpy()
+
+        # Assert that the weights have been changed
+        self.assertFalse(
+            np.allclose(original_weights, quantized_weights),
+            "Weights were not changed by the GPTQ process for "
+            f"dataset: {dataset}",
+        )
 
-    # Verify the quantized model can still make a prediction
-    try:
+        # Verify the quantized model can still make a prediction
         dummy_input = np.random.randint(
             0, VOCAB_SIZE, size=(1, SEQUENCE_LENGTH)
         )
         _ = model.predict(dummy_input)
-    except Exception as e:
-        test_case.fail(
-            "Prediction failed for the quantized model with dataset: "
-            f"{dataset}. Error: {e}"
-        )
-
 
-@pytest.mark.requires_trainable_backend
-class ModelQuantizationTest(testing.TestCase):
-    def test_quantize_gptq_with_dense_attention(self):
-        """Tests GPTQ with an in-memory list of strings as the dataset."""
+    def test_quantize_gptq_on_different_datasets(self):
+        """Tests GPTQ with various dataset types (string list, generator)."""
 
+        # Define the datasets to be tested
        long_text = """auto-gptq is an easy-to-use model quantization library
         with user-friendly apis, based on GPTQ algorithm. The goal is to
         quantize pre-trained models to 4-bit or even 3-bit precision with
@@ -1382,12 +1364,16 @@ def test_quantize_gptq_with_dense_attention(self):
         models in resource-constrained environments where every bit of memory
         and every millisecond of latency counts."""
 
-        string_dataset = [long_text]
-        _run_gptq_test_on_dataset(self, string_dataset)
+        datasets_to_test = {
+            "string_dataset": [long_text],
+            "generator_dataset": dummy_dataset_generator(
+                nsamples=16, seqlen=128, vocab_size=1000
+            ),
+        }
 
-    def test_quantize_gptq_with_data_gen(self):
-        """Tests GPTQ with a Python generator as the dataset."""
-        generator_dataset = dummy_dataset_generator(
-            nsamples=16, seqlen=128, vocab_size=1000
-        )
-        _run_gptq_test_on_dataset(self, generator_dataset)
+        # Loop through the datasets and run each as a sub-test
+        for dataset_name, dataset in datasets_to_test.items():
+            # 'with self.subTest(...)' ensures that failures are reported
+            # for each specific dataset without stopping the whole test.
+            with self.subTest(dataset_type=dataset_name):
+                self._run_gptq_test_on_dataset(dataset)
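
The combined test above feeds a Python generator through `dummy_dataset_generator`, a helper that is not part of this diff. A hypothetical sketch of what such a generator might look like, assuming it yields random token-id batches of shape (1, seqlen):

    # Hypothetical sketch of the helper referenced above; the real
    # implementation lives elsewhere in the test suite and may differ.
    import numpy as np

    def dummy_dataset_generator(nsamples, seqlen, vocab_size):
        # Yield one random batch of token ids per calibration sample.
        for _ in range(nsamples):
            yield np.random.randint(0, vocab_size, size=(1, seqlen))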

keras/src/quantizers/gptq.py
Lines changed: 76 additions & 2 deletions

@@ -61,6 +61,27 @@ def __init__(self, layer):
         self.H = ops.zeros((self.rows, self.rows), dtype="float32")
 
     def update_hessian_with_batch(self, inp):
+        """
+        Updates the running average of the Hessian matrix with a new batch.
+
+        This method computes the Hessian matrix for a given batch of input
+        activations and updates the accumulated Hessian (`self.H`) using a
+        numerically stable running average. This allows the Hessian to be
+        computed over a large dataset without loading all samples into memory
+        at once.
+
+        The input tensor is first reshaped into a 2D matrix [num_samples,
+        num_features] before the Hessian is calculated.
+
+        Args:
+            inp: A 2D or higher-dimensional tensor of input activations from a
+                calibration batch.
+
+        Raises:
+            ValueError: If the feature dimension of the input tensor `inp` does
+                not match the dimensions of the pre-initialized Hessian matrix
+                `self.H`.
+        """
         if len(inp.shape) > 2:
             inp = ops.reshape(inp, (-1, inp.shape[-1]))
         inp = ops.cast(inp, "float32")
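
The new docstring describes the Hessian as a running average accumulated batch by batch. A minimal NumPy sketch of that accumulation scheme, assuming the common GPTQ formulation H = (2 / n) * sum(x xᵀ) over the n calibration rows seen so far (illustrative only, not the Keras implementation):

    # Illustrative running-average Hessian accumulator (NumPy, not keras.ops).
    import numpy as np

    class HessianAccumulator:
        def __init__(self, num_features):
            self.H = np.zeros((num_features, num_features), dtype="float32")
            self.nsamples = 0

        def update_hessian_with_batch(self, inp):
            # Flatten leading dimensions so each row is one activation sample.
            inp = inp.reshape(-1, inp.shape[-1]).astype("float32")
            if inp.shape[-1] != self.H.shape[0]:
                raise ValueError("Feature dim does not match Hessian size.")
            n_new = inp.shape[0]
            n_total = self.nsamples + n_new
            # Rescale the old average, then add the new batch's contribution,
            # so H stays the average of 2 * x @ x.T over all samples seen.
            self.H *= self.nsamples / n_total
            self.H += (2.0 / n_total) * inp.T @ inp
            self.nsamples = n_total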
@@ -85,6 +106,51 @@ def update_hessian_with_batch(self, inp):
     def quantize_and_correct_block(
         self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False
     ):
+        """
+        Performs GPTQ quantization and correction on the layer's weights.
+
+        This method implements the core logic of the Optimal Brain
+        Quantization (OBQ) method, as applied by GPTQ, to quantize the weights
+        of a single layer. It iteratively quantizes blocks of weights and
+        corrects for the quantization error by updating the remaining weights.
+
+        The algorithm follows these main steps:
+        1. **Initialization**: It optionally reorders the weight columns based
+           on activation magnitudes (`actorder=True`) to protect more salient
+           weights.
+        2. **Hessian Modification**: The Hessian matrix `H`, pre-computed from
+           calibration data, is dampened to ensure its invertibility and
+           stability.
+        3. **Iterative Quantization**: The function iterates through the
+           weight columns in blocks (`blocksize`). In each iteration, it:
+           a. Quantizes one column (`w`).
+           b. Calculates the quantization error (`err`).
+           c. Updates the remaining weights in the *current* block by
+              distributing the error, using the inverse Hessian (`Hinv`).
+        4. **Block-wise Correction**: After a block is quantized, the total
+           error from that block is propagated to the *next* block of weights
+           to be processed.
+        5. **Finalization**: The quantized weights (`Q`) are reordered back if
+           `actorder` was used, and the layer's weights are updated.
+
+        This implementation is based on the official GPTQ paper and repository.
+        For more details, see:
+        - Paper: https://arxiv.org/abs/2210.17323
+        - Original Code: https://github.com/IST-DASLab/gptq
+
+        Args:
+            blocksize (int, optional): The size of the weight block to process
+                at a time. Defaults to 128.
+            percdamp (float, optional): The percentage of dampening to add to
+                the Hessian's diagonal. A value of 0.01 is recommended.
+                Defaults to 0.01.
+            groupsize (int, optional): The number of weights that share the
+                same quantization parameters (scale and zero-point).
+                A value of -1 indicates per-channel quantization.
+            actorder (bool, optional): If True, reorders weight columns based
+                on their activation's second-order information.
+        """
+
         W = ops.transpose(ops.cast(self.layer.kernel, "float32"))
         H = ops.cast(self.H, "float32")
 
@@ -94,26 +160,32 @@ def quantize_and_correct_block(
             H = ops.take(ops.take(H, perm, axis=0), perm, axis=1)
             invperm = ops.argsort(perm)
 
+        # Dampen the Hessian for stability
         diag_H = ops.diagonal(H)
         dead = ops.equal(diag_H, 0.0)
         diag_H = ops.where(dead, 1.0, diag_H)
         H = H + ops.diag(ops.where(dead, 1.0, ops.zeros_like(diag_H)))
+
+        # Add dampening factor to the Hessian diagonal
         damp = percdamp * ops.mean(diag_H)
         diag_H = diag_H + damp
         H = (H - ops.diag(ops.diagonal(H))) + ops.diag(diag_H)
 
+        # Compute the inverse Hessian, which is used for error correction
         Hinv = ops.linalg.inv(H)
         Q = ops.zeros_like(W)
 
         for i1 in range(0, self.rows, blocksize):
             i2 = min(i1 + blocksize, self.rows)
             count = i2 - i1
-
+            # Extract the current block of weights and its corresponding
+            # Hessian
             W1 = W[:, i1:i2]
             Q1 = ops.zeros_like(W1)
             Err1 = ops.zeros_like(W1)
             Hinv1 = Hinv[i1:i2, i1:i2]
 
+            # Process one column at a time within the block
             for i in range(count):
                 w = W1[:, i]
                 d = Hinv1[i, i]
@@ -128,6 +200,7 @@ def quantize_and_correct_block(
                     ops.expand_dims(w, 1), weight=True
                 )
 
+                # Quantize the current weight column
                 q = quantize(
                     ops.expand_dims(w, 1),
                     self.quantizer.scale,
@@ -148,11 +221,11 @@ def quantize_and_correct_block(
                 )
 
                 # Efficiently update the remaining part of the W1 tensor.
-                # This is equivalent to W1[:, i + 1 :] -= update
                 slice_to_update = W1[:, i + 1 :]
                 updated_slice = slice_to_update - update
                 W1 = ops.slice_update(W1, (0, i + 1), updated_slice)
 
+            # Update the full quantized matrix Q with the processed block
             Q = ops.concatenate([Q[:, :i1], Q1, Q[:, i2:]], axis=1)
 
             if i2 < self.rows:
@@ -169,6 +242,7 @@ def quantize_and_correct_block(
         if isinstance(self.original_layer, EinsumDense):
             Q = ops.reshape(Q, self.kernel_shape)
 
+        # Set the new quantized weights in the original layer
         new_weights = [ops.convert_to_numpy(Q)]
         if self.original_layer.bias is not None:
             new_weights.append(ops.convert_to_numpy(self.original_layer.bias))
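
To complement the docstring added in this commit, here is a compact NumPy sketch of the column-by-column quantize-and-correct loop it describes, simplified to a plain round-to-grid quantizer with no blocking, grouping, or activation ordering; it is illustrative only and not the Keras implementation.

    # Simplified GPTQ-style loop: quantize one column at a time, then push the
    # quantization error onto the not-yet-quantized columns via the inverse
    # Hessian (illustrative; assumes a fixed symmetric quantization scale).
    import numpy as np

    def quantize_and_correct(W, H, scale, percdamp=0.01):
        W = W.astype("float32").copy()         # [out_features, in_features]
        H = H.astype("float32").copy()         # [in_features, in_features]
        damp = percdamp * np.mean(np.diag(H))  # dampen for invertibility
        H[np.diag_indices_from(H)] += damp
        Hinv = np.linalg.inv(H)
        Q = np.zeros_like(W)
        for i in range(W.shape[1]):
            w = W[:, i]
            d = Hinv[i, i]
            q = np.clip(np.round(w / scale), -8, 7) * scale  # e.g. 4-bit grid
            Q[:, i] = q
            err = (w - q) / d
            # Spread the error over the remaining, not-yet-quantized columns.
            W[:, i + 1:] -= np.outer(err, Hinv[i, i + 1:])
        return Q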

keras/src/quantizers/gptqconfig.py
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-import logging
+from absl import logging
 
 from .gptqutils import quantize_model
 
