
Commit baa0682

reworked the comments
1 parent c9ff4d1 commit baa0682

5 files changed: 82 additions & 116 deletions


keras/src/models/model_test.py

Lines changed: 36 additions & 72 deletions
@@ -1,6 +1,7 @@
 import os
 import pickle
 from collections import namedtuple
+from collections.abc import Callable

 import numpy as np
 import pytest
@@ -1250,9 +1251,8 @@ def dummy_dataset_generator(nsamples, seqlen, vocab_size=1000):
         yield rng.integers(low=0, high=vocab_size, size=(1, seqlen))


-# Helper function to build a simple transformer model that uses standard
-# Keras `Dense` layers for its attention projections.
-def _get_model_with_dense_attention():
+# Helper function to build a simple transformer model.
+def get_model_with_dense_attention():
     """Builds a simple transformer model using Dense for attention."""
     vocab_size = 1000
     embed_dim = 32
@@ -1262,8 +1262,6 @@ def _get_model_with_dense_attention():
     class SimpleTransformerBlock(layers.Layer):
         def __init__(self, embed_dim, num_heads, ff_dim, **kwargs):
             super().__init__(**kwargs)
-            # The standard MultiHeadAttention layer uses Dense layers
-            # for its projections.
             self.att = layers.MultiHeadAttention(
                 num_heads=num_heads, key_dim=embed_dim
             )
@@ -1292,22 +1290,44 @@ def call(self, inputs):
     return model


+# Define parameters for the tests
+long_text = """gptq is an easy-to-use model quantization library..."""
+DATASETS = {
+    "string_dataset": [long_text],
+    "generator_dataset": lambda: dummy_dataset_generator(
+        nsamples=16, seqlen=128
+    ),
+}
+CONFIGS = {
+    "default": {},
+    "per_channel": {"group_size": -1},
+    "act_order": {"act_order": True},
+    "symmetric": {"symmetric": True},
+}
+
+
 @pytest.mark.requires_trainable_backend
-class ModelQuantizationTest(testing.TestCase):
+class TestModelQuantization:
     def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
         """Helper function to run a full GPTQ quantization test."""
-        model = _get_model_with_dense_attention()
+        if isinstance(dataset, Callable):
+            dataset = dataset()
+        model = get_model_with_dense_attention()
         rng = np.random.default_rng(seed=42)

-        # 1. Common setup
         NUM_SAMPLES = 16
         SEQUENCE_LENGTH = 128
         VOCAB_SIZE = 1000
         W_BITS = 4

-        # Default config that can be overridden by config_kwargs
+        mock_tokenizer = lambda text: np.array(
+            [ord(c) % VOCAB_SIZE for c in text]
+        )
+        mock_tokenizer.tokenize = mock_tokenizer
+
         base_config = {
             "dataset": dataset,
+            "tokenizer": mock_tokenizer,
             "wbits": W_BITS,
             "nsamples": NUM_SAMPLES,
             "seqlen": SEQUENCE_LENGTH,
@@ -1316,82 +1336,26 @@ def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
             "act_order": False,
         }

-        mock_tokenizer = lambda text: np.array(
-            [ord(c) % VOCAB_SIZE for c in text]
-        )
-        mock_tokenizer.tokenize = mock_tokenizer
-        base_config["tokenizer"] = mock_tokenizer
-
-        # Find target layer and get original weights
         target_layer = model.layers[2].ffn.layers[0]
-        self.assertIsNotNone(
-            target_layer,
-            "Test setup failed: No Dense layer found in 'ffn' block.",
-        )
+        assert target_layer is not None
         original_weights = np.copy(target_layer.kernel)

-        # Configure and run quantization
         final_config = {**base_config, **config_kwargs}
         gptq_config = GPTQConfig(**final_config)

         model.quantize("gptq", config=gptq_config)

-        # Assertions and verification
         quantized_weights = target_layer.kernel

-        self.assertNotAllClose(
-            original_weights,
-            quantized_weights,
-            msg=f"Weights not changed by GPTQ for config: {config_kwargs}",
-        )
+        assert not np.allclose(original_weights, quantized_weights)

         dummy_sample = rng.integers(
             low=0, high=VOCAB_SIZE, size=(1, SEQUENCE_LENGTH)
         )
         _ = model.predict(dummy_sample)

-    def test_quantize_gptq_on_different_datasets(self):
-        """Tests GPTQ with various dataset types (string list, generator)."""
-
-        # Define the datasets to be tested
-        long_text = """gptq is an easy-to-use model quantization library
-        with user-friendly apis, based on GPTQ algorithm. The goal is to
-        quantize pre-trained models to 4-bit or even 3-bit precision with
-        minimal performance degradation.
-        This allows for running larger models on less powerful hardware,
-        reducing memory footprint and increasing inference speed.
-        The process involves calibrating the model on a small dataset
-        to determine the quantization parameters.
-        This technique is particularly useful for deploying large language
-        models in resource-constrained environments where every bit of memory
-        and every millisecond of latency counts."""
-
-        datasets_to_test = {
-            "string_dataset": [long_text],
-            "generator_dataset": dummy_dataset_generator(
-                nsamples=16, seqlen=128, vocab_size=1000
-            ),
-        }
-
-        # Loop through the datasets and run each as a sub-test
-        for dataset_name, dataset in datasets_to_test.items():
-            with self.subTest(dataset_type=dataset_name):
-                self._run_gptq_test_on_dataset(dataset)
-
-    def test_quantize_gptq_with_config_variations(self):
-        """Tests GPTQ with specific config variations."""
-        config_variations = {
-            "per_channel": {"group_size": -1},
-            "act_order": {"act_order": True},
-            "symmetric": {"symmetric": True},
-            "all_options_enabled": {
-                "group_size": -1,
-                "act_order": True,
-                "symmetric": True,
-            },
-        }
-
-        dataset = ["This is the calibration data for the test."]
-        for config_name, config_overrides in config_variations.items():
-            with self.subTest(config_type=config_name):
-                self._run_gptq_test_on_dataset(dataset, **config_overrides)
+    @pytest.mark.parametrize("dataset", DATASETS.values(), ids=DATASETS.keys())
+    @pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
+    def test_quantize_gptq_combinations(self, dataset, config):
+        """Runs GPTQ tests across different datasets and config variations."""
+        self._run_gptq_test_on_dataset(dataset, **config)
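The two stacked `parametrize` decorators replace the earlier `subTest` loops: pytest now generates one test case per (config, dataset) pair, 4 x 2 = 8 cases in total, each individually reported and selectable by id. Note also why the generator dataset is stored behind a `lambda` and only materialized inside the helper (`if isinstance(dataset, Callable): dataset = dataset()`): a bare generator would be exhausted after the first parametrized run. A standalone sketch of the pattern, with illustrative names and values:

import pytest

DATASETS = {"list": [1, 2], "generator": lambda: iter([1, 2])}
CONFIGS = {"default": {}, "symmetric": {"symmetric": True}}

@pytest.mark.parametrize("dataset", DATASETS.values(), ids=DATASETS.keys())
@pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
def test_combinations(dataset, config):
    if callable(dataset):
        dataset = dataset()  # build a fresh iterator for this test case
    # pytest runs this 2 x 2 = 4 times, with ids like
    # test_combinations[symmetric-generator]
    assert list(dataset) == [1, 2] and isinstance(config, dict)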

keras/src/quantizers/gptq.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from keras.src import ops
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.quantizers.gptqquant import dequantize
+from keras.src.quantizers.gptq_quant import dequantize


 class GPTQ:

keras/src/quantizers/gptq_config.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from absl import logging

 from keras.src.api_export import keras_export
-from keras.src.quantizers.gptqutils import quantize_model
+from keras.src.quantizers.gptq_core import quantize_model


 @keras_export(["keras.GPTQConfig", "keras.quantizers.GPTQConfig"])

keras/src/quantizers/gptqutils.py renamed to keras/src/quantizers/gptq_core.py

Lines changed: 35 additions & 21 deletions
@@ -9,7 +9,7 @@
 from keras.src.layers import EinsumDense
 from keras.src.layers import Embedding
 from keras.src.quantizers.gptq import GPTQ
-from keras.src.quantizers.gptqquant import GPTQQuant
+from keras.src.quantizers.gptq_quant import GPTQQuant


 def get_dataloader(tokenizer, seqlen, dataset, nsamples=128):
@@ -20,9 +20,9 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128):

     if isinstance(dataset, str):
         raise TypeError(
-            "The `dataset` argument must be an iterable (e.g., a list or "
-            "generator) of strings or pre-tokenized tensors. Loading "
-            "datasets by name is no longer supported."
+            "The `dataset` argument must be an iterable (e.g., a list of "
+            "strings or a generator). Providing a dataset name as a string "
+            "is not supported. Please pass the loaded dataset directly."
         )

     logging.info("Using pre-made dataset/generator...")
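The reworded error pins down the contract: `dataset` must already be loaded, either as an iterable of strings or of pre-tokenized samples. A hedged sketch of the call-site behavior (the tokenizer is a stand-in, and `dummy_dataset_generator` is borrowed from the tests above for illustration):

# An iterable of strings is accepted:
calib = get_dataloader(tokenizer, seqlen=128, dataset=["some calibration text"])

# So is a generator of pre-tokenized samples:
calib = get_dataloader(
    tokenizer, seqlen=128, dataset=dummy_dataset_generator(nsamples=16, seqlen=128)
)

# A bare dataset name is not loaded for you:
get_dataloader(tokenizer, seqlen=128, dataset="my_dataset_name")  # raises TypeError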
@@ -37,10 +37,9 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128):
         all_tokens = tokenizer.tokenize(full_text)
     else:
         logging.info("(Dataset is pre-tokenized, concatenating...)")
-        concatenated_tokens = ops.concatenate(
-            [ops.reshape(s, [-1]) for s in dataset_list], axis=0
+        all_tokens = np.concatenate(
+            [ops.convert_to_numpy(s).reshape(-1) for s in dataset_list], axis=0
         )
-        all_tokens = ops.convert_to_numpy(concatenated_tokens)

     all_tokens = np.array(all_tokens, dtype=np.int32)

@@ -62,10 +61,10 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128):
         start_index = random.randint(0, len(all_tokens) - seqlen - 1)
         end_index = start_index + seqlen
         sample = all_tokens[start_index:end_index]
-        calibration_samples.append(ops.reshape(sample, (1, seqlen)))
+        calibration_samples.append(np.reshape(sample, (1, seqlen)))

-    final_array = ops.stack(calibration_samples, axis=0)
-    return ops.convert_to_numpy(final_array)
+    final_array = np.stack(calibration_samples, axis=0)
+    return final_array


 def _find_layers_recursive(layer, prefix, found_layers):
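With these two hunks, `get_dataloader` stays in NumPy after the initial per-sample `ops.convert_to_numpy`: pre-tokenized inputs are flattened into one token stream, random `seqlen` windows are cut from it, and `np.stack` of the `(1, seqlen)` windows yields a `(nsamples, 1, seqlen)` array with no final backend round-trip. A quick shape check under those assumptions, with illustrative sizes:

import numpy as np

nsamples, seqlen = 3, 4
all_tokens = np.arange(100, dtype=np.int32)  # the concatenated token stream

rng = np.random.default_rng(0)
calibration_samples = []
for _ in range(nsamples):
    # Cut a random contiguous window and keep the (1, seqlen) batch shape.
    start = rng.integers(0, len(all_tokens) - seqlen - 1)
    window = all_tokens[start : start + seqlen]
    calibration_samples.append(np.reshape(window, (1, seqlen)))

final_array = np.stack(calibration_samples, axis=0)
print(final_array.shape)  # (3, 1, 4)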
@@ -106,9 +105,15 @@ def apply_gptq_layerwise(
 ):
     """Applies GPTQ quantization layer-by-layer to a Keras model.

-    This function performs a sequential, model-agnostic quantization process. It
-    dynamically identifies quantizable layers (e.g., Dense, EinsumDense)
-    within larger "transformer blocks" of a model.
+    This function is designed to work with common transformer architectures,
+    like those provided by KerasNLP and KerasHub. It automatically discovers
+    the model's structure by first looking for the standard KerasNLP format:
+    a `model.backbone` attribute that contains a `transformer_layers` list.
+
+    If a standard backbone is not found, it falls back to a heuristic for
+    custom models, where it assumes the first `keras.layers.Embedding` layer
+    is the input embedding and any subsequent container layers are the
+    transformer blocks to be quantized.

     The core logic operates as follows:
     1. It automatically detects the model's structure, identifying the main
@@ -154,7 +159,17 @@ def apply_gptq_layerwise(
     if hasattr(model, "backbone"):
         logging.info("Detected KerasNLP model structure.")
         backbone = model.backbone
-        transformer_blocks = backbone.transformer_layers
+
+        # Add the check for the 'transformer_layers' attribute.
+        if hasattr(backbone, "transformer_layers"):
+            transformer_blocks = backbone.transformer_layers
+        else:
+            # Raise a specific error if the attribute is missing.
+            raise ValueError(
+                "The model's backbone does not have a 'transformer_layers' "
+                "attribute. Please ensure you are using a standard KerasNLP "
+                "transformer model."
+            )
         # Find the embedding layer by checking for common names or by type.
         if hasattr(backbone, "token_embedding"):
             embedding_layer = backbone.token_embedding
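The effect of the new guard: a model that exposes a `backbone` without the expected `transformer_layers` list used to fail later with an opaque `AttributeError`; it now fails fast with the descriptive `ValueError`. A hypothetical repro (class names invented for illustration):

import keras

class CustomBackbone(keras.layers.Layer):
    """A backbone that lacks the `transformer_layers` attribute."""

class CustomModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.backbone = CustomBackbone()

# Quantizing such a model now stops immediately with:
# ValueError: The model's backbone does not have a 'transformer_layers'
# attribute. Please ensure you are using a standard KerasNLP transformer model.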
@@ -256,13 +271,12 @@ def hook(*args, **kwargs):
             inp_reshaped = ops.reshape(layer_inputs, (-1, num_features))
             gptq_object.update_hessian_with_batch(inp_reshaped)

-        quantizer = GPTQQuant()
-        quantizer.configure(
-            wbits,
-            perchannel=True,
-            symmetric=symmetric,
-            group_size=group_size,
-        )
+        quantizer = GPTQQuant(
+            wbits,
+            perchannel=True,
+            symmetric=symmetric,
+            group_size=group_size,
+        )
         for name, gptq_object in gptq_objects.items():
             logging.info(f"Quantizing {name}...")
             gptq_object.quantizer = quantizer
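Folding `configure()` into the constructor means a `GPTQQuant` can no longer exist in a half-configured state; the single instance built here carries all settings and is shared by every `GPTQ` object in the block. A minimal sketch of the new construction (argument values are illustrative):

from keras.src.quantizers.gptq_quant import GPTQQuant

quantizer = GPTQQuant(4, perchannel=True, symmetric=False, group_size=128)

# Quantization parameters are deferred until calibration runs find_params():
assert quantizer.scale is None and quantizer.zero is None
# maxq is derived from wbits at construction time: 2**4 - 1 = 15.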

keras/src/quantizers/gptqquant.py renamed to keras/src/quantizers/gptq_quant.py

Lines changed: 9 additions & 21 deletions
@@ -2,21 +2,15 @@


 def dequantize(x, scale, zero, maxq):
-    """The core quantization function with correct broadcasting."""
-    # Ensure scale is broadcastable with the input tensor x
-    if scale.shape != x.shape:
-        scale = ops.broadcast_to(scale, x.shape)
-
-    # Ensure zero-point is also broadcastable
-    if zero.shape != x.shape:
-        zero = ops.broadcast_to(zero, x.shape)
-
-    epsilon = 1e-8
+    """The core quantization function."""
+    epsilon = ops.cast(1e-8, dtype=scale.dtype)
     scale = ops.where(ops.equal(scale, 0), epsilon, scale)
+
     quantized_x = ops.divide(x, scale)
     quantized_x = ops.round(quantized_x)
     q = ops.add(quantized_x, zero)
     q = ops.clip(q, 0, maxq)
+
     dequantized_x = ops.subtract(q, zero)
     return ops.multiply(scale, dequantized_x)

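The explicit `broadcast_to` calls could be dropped because element-wise `keras.ops` functions broadcast `scale` and `zero` against `x` under NumPy rules; only the zero-scale guard needed to stay, now with an epsilon cast to `scale`'s dtype. A round-trip sanity check, assuming one scale and zero-point per column (values illustrative):

import numpy as np
from keras import ops
from keras.src.quantizers.gptq_quant import dequantize

x = np.array([[0.5, -1.0], [1.5, 2.0]], dtype="float32")
scale = np.array([[0.5, 1.0]], dtype="float32")  # broadcasts across rows
zero = np.array([[8.0, 8.0]], dtype="float32")   # mid-point for 4 bits
maxq = 15.0                                      # 2**4 - 1

x_hat = ops.convert_to_numpy(dequantize(x, scale, zero, maxq))
# These inputs sit exactly on the 4-bit grid, so they survive the
# round(x / scale) -> clip -> rescale round trip unchanged:
print(np.allclose(x_hat, x))  # True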
@@ -48,23 +42,17 @@ class GPTQQuant:
             Defaults to -1.
     """

-    def __init__(self):
-        self.scale = None
-        self.zero = None
-        self.maxq = None
-        self.wbits = None
-        self.perchannel = False
-        self.symmetric = False
-        self.group_size = -1
-
-    def configure(self, wbits, perchannel=True, symmetric=False, group_size=-1):
-        """Configures the quantizer settings."""
+    def __init__(self, wbits, perchannel=True, symmetric=False, group_size=-1):
         self.wbits = wbits
         self.maxq = ops.cast((2**wbits) - 1, "float32")
         self.perchannel = perchannel
         self.symmetric = symmetric
         self.group_size = group_size

+        # These are now determined later by `find_params`
+        self.scale = None
+        self.zero = None
+
     def find_params(self, x, weight=False):
         """Finds quantization parameters (scale and zero) for a given tensor."""