Commit 0cbd234

feat(quantization): Add GPTQ n-bit quantization support (#21551)
* feat(quantization): Add GPTQ n-bit quantization support

  This commit integrates the GPTQ (Generative Pre-trained Transformer
  Quantization) algorithm into Keras. Key features include:
  - A new `GPTQConfig` for configuring quantization parameters.
  - Integration with base Keras models via a `model.quantize()` method.
  - Support for multiple datasets (WikiText-2, PTB, C4, and custom
    datasets) and tested models (GPT-2, OPT, Bloom, Gemma3, etc.).
  - Unit tests to verify perplexity and model functionality
    post-quantization.

* Added the dataset dependency to be installed
* Fixed the AI review comments except one
* Fixed the GPTQ algorithm for in-place weight updates
* Updated per the review comments
* Renamed the quant class to gptqquant
* Renamed the quant file to gptqquant
* Reworked some superficial comments
* Reworked per review comments
* Removed the huggingface dependency
* Changed the file name to gptq_config.py
* Fixed comments and added an additional test file
* Added a test to improve coverage
* Removed raw numeric operators (+, -, *, etc.) in favor of keras.ops
* Reworked per the review comments
* Updated the interface per comments
* Reworked the comments
* Fixed a failing test case
* Added a test case to improve coverage
* Added a test case to improve coverage
* Added a test case to improve coverage
* Reworked per review comments
* Reworked per final review comments
* Fixed an issue introduced while addressing review comments
* Fixed minor review comments
* Fixed a failing test case
* Fixed some typos
1 parent a1bcf94 commit 0cbd234
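For orientation, here is a minimal sketch of how the new API introduced by this commit can be invoked, based on the tests in this diff. The GPTQConfig parameter names match those exercised in model_test.py below; `model` and the tokenizer are hypothetical stand-ins, and any defaults are assumptions, not confirmed behavior:

import numpy as np
from keras.quantizers import GPTQConfig

# Any callable mapping text -> token ids can serve as the tokenizer
# (mirroring the mock tokenizer in model_test.py below).
tokenizer = lambda text: np.array([ord(c) % 1000 for c in text])

config = GPTQConfig(
    dataset=["example calibration text"],  # list of strings or a generator
    tokenizer=tokenizer,
    weight_bits=4,            # n-bit weight quantization
    num_samples=16,
    sequence_length=128,
    group_size=32,            # -1 selects per-channel quantization
    symmetric=False,
    activation_order=False,
)

# `model` is a built Keras model (hypothetical here); its weights are
# quantized in place.
model.quantize("gptq", config=config)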

File tree

11 files changed: +1453 -2 lines changed


keras/api/_tf_keras/keras/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@
 from keras.src.quantizers import deserialize as deserialize
 from keras.src.quantizers import get as get
 from keras.src.quantizers import serialize as serialize
+from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
 from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
 from keras.src.quantizers.quantizers import Quantizer as Quantizer
 from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/api/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@
 from keras.src.quantizers import deserialize as deserialize
 from keras.src.quantizers import get as get
 from keras.src.quantizers import serialize as serialize
+from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
 from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
 from keras.src.quantizers.quantizers import Quantizer as Quantizer
 from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/src/dtype_policies/dtype_policy.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8", "int4")
+QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
 
 
 @keras_export(
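Because "gptq" is now a member of QUANTIZATION_MODES, generic mode validation elsewhere in the codebase accepts it. A minimal illustration of such a membership check (the error message here is illustrative, not the library's exact text):

from keras.src.dtype_policies import QUANTIZATION_MODES

mode = "gptq"
if mode not in QUANTIZATION_MODES:
    raise ValueError(f"Invalid quantization mode. Received: mode={mode}")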

keras/src/models/model.py

Lines changed: 19 additions & 1 deletion

@@ -8,6 +8,7 @@
 from keras.src.api_export import keras_export
 from keras.src.layers.layer import Layer
 from keras.src.models.variable_mapping import map_saveable_variables
+from keras.src.quantizers.gptq_config import GPTQConfig
 from keras.src.saving import saving_api
 from keras.src.trainers import trainer as base_trainer
 from keras.src.utils import summary_utils
@@ -420,7 +421,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
             **kwargs,
         )
 
-    def quantize(self, mode, **kwargs):
+    def quantize(self, mode, config=None, **kwargs):
         """Quantize the weights of the model.
 
         Note that the model must be built first before calling this method.
@@ -433,6 +434,23 @@ def quantize(self, mode, **kwargs):
         """
        from keras.src.dtype_policies import QUANTIZATION_MODES
 
+        if mode == "gptq":
+            if not isinstance(config, GPTQConfig):
+                raise ValueError(
+                    "The `config` argument must be of type "
+                    "`keras.quantizers.GPTQConfig`."
+                )
+            # The config object's own quantize method drives the process.
+            config.quantize(self)
+            return
+
+        # For all other modes, verify that a config object was not passed.
+        if config is not None:
+            raise ValueError(
+                f"The `config` argument is only supported for 'gptq' mode, "
+                f"but received mode='{mode}'."
+            )
+
         type_check = kwargs.pop("type_check", True)
         if kwargs:
             raise ValueError(
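A short sketch of the dispatch behavior this hunk introduces, assuming `model` is a built Keras model and `gptq_config` is a GPTQConfig instance built elsewhere:

model.quantize("int8")                      # non-GPTQ modes: no config allowed
model.quantize("gptq", config=gptq_config)  # GPTQ: the config drives the process

# Both of these raise ValueError under the new checks:
model.quantize("gptq", config={"weight_bits": 4})  # not a GPTQConfig instance
model.quantize("int8", config=gptq_config)         # config with a non-gptq mode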

keras/src/models/model_test.py

Lines changed: 177 additions & 0 deletions

@@ -1,6 +1,7 @@
 import os
 import pickle
 from collections import namedtuple
+from collections.abc import Callable
 
 import numpy as np
 import pytest
@@ -9,12 +10,14 @@
 from keras.src import backend
 from keras.src import layers
 from keras.src import losses
+from keras.src import models
 from keras.src import testing
 from keras.src import tree
 from keras.src.layers.core.input_layer import Input
 from keras.src.models.functional import Functional
 from keras.src.models.model import Model
 from keras.src.models.model import model_from_json
+from keras.src.quantizers.gptq_config import GPTQConfig
 
 
 def _get_model():
@@ -1237,3 +1240,177 @@ def test_export_error(self):
             ),
         ):
             model.export(temp_filepath, format="tf_saved_model")
+
+
+def dummy_dataset_generator(num_samples, sequence_length, vocab_size=1000):
+    """A generator that yields random numpy arrays for fast,
+    self-contained tests."""
+    rng = np.random.default_rng(seed=42)
+    for _ in range(num_samples):
+        yield rng.integers(low=0, high=vocab_size, size=(1, sequence_length))
+
+
+def get_model_with_dense_attention():
+    """Builds a simple transformer model using Dense for attention."""
+    vocab_size = 1000
+    embed_dim = 32
+    num_heads = 4
+    ff_dim = 32
+
+    class SimpleTransformerBlock(layers.Layer):
+        def __init__(self, embed_dim, num_heads, ff_dim, **kwargs):
+            super().__init__(**kwargs)
+            self.att = layers.MultiHeadAttention(
+                num_heads=num_heads, key_dim=embed_dim
+            )
+            self.ffn = models.Sequential(
+                [
+                    layers.Dense(ff_dim, activation="relu"),
+                    layers.Dense(embed_dim),
+                ]
+            )
+            self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
+            self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
+
+        def call(self, inputs):
+            attention_output = self.att(inputs, inputs)
+            out1 = self.layernorm1(inputs + attention_output)
+            ffn_output = self.ffn(out1)
+            return self.layernorm2(out1 + ffn_output)
+
+    inputs = layers.Input(shape=(None,), dtype="int32")
+    embedding_layer = layers.Embedding(vocab_size, embed_dim)
+    x = embedding_layer(inputs)
+    transformer_block = SimpleTransformerBlock(embed_dim, num_heads, ff_dim)
+    x = transformer_block(x)
+    outputs = layers.Dense(vocab_size)(x)
+    model = models.Model(inputs=inputs, outputs=outputs)
+    return model
+
+
+# Define parameters for the tests
+long_text = """gptq is an easy-to-use model quantization library..."""
+DATASETS = {
+    "string_dataset": [long_text],
+    "generator_dataset": lambda: dummy_dataset_generator(
+        num_samples=16, sequence_length=128
+    ),
+}
+CONFIGS = {
+    "default": {},
+    "per_channel": {"group_size": -1},
+    "act_order": {"activation_order": True},
+    "symmetric": {"symmetric": True},
+}
+
+
+def _get_simple_model():
+    """Builds a simple sequential model for testing."""
+    return models.Sequential([layers.Dense(10, input_shape=(5,))])
+
+
+quantize_test_cases = [
+    # --- Error Scenarios ---
+    (
+        "gptq",
+        {"weight_bits": 4},  # Invalid config (dict, not GPTQConfig)
+        ValueError,
+        "The `config` argument must be of type",
+        "gptq_with_invalid_config",
+    ),
+    (
+        "int8",
+        GPTQConfig(dataset=["test"], tokenizer=lambda x: x),
+        ValueError,
+        "is only supported for 'gptq' mode",
+        "non_gptq_with_unsupported_config",
+    ),
+    # --- Valid Scenario ---
+    (
+        "int8",
+        None,  # No config, which is correct
+        None,  # No exception expected
+        None,
+        "non_gptq_runs_without_error",
+    ),
+]
+
+
+@pytest.mark.requires_trainable_backend
+class TestModelQuantization:
+    def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
+        """Helper function to run a full GPTQ quantization test."""
+        if isinstance(dataset, Callable):
+            dataset = dataset()
+        model = get_model_with_dense_attention()
+        rng = np.random.default_rng(seed=42)
+
+        NUM_SAMPLES = 16
+        SEQUENCE_LENGTH = 128
+        VOCAB_SIZE = 1000
+        W_BITS = 4
+
+        mock_tokenizer = lambda text: np.array(
+            [ord(c) % VOCAB_SIZE for c in text]
+        )
+        mock_tokenizer.tokenize = mock_tokenizer
+
+        base_config = {
+            "dataset": dataset,
+            "tokenizer": mock_tokenizer,
+            "weight_bits": W_BITS,
+            "num_samples": NUM_SAMPLES,
+            "sequence_length": SEQUENCE_LENGTH,
+            "group_size": 32,
+            "symmetric": False,
+            "activation_order": False,
+        }
+
+        target_layer = model.layers[2].ffn.layers[0]
+        assert target_layer is not None
+        original_weights = np.copy(target_layer.kernel)
+
+        final_config = {**base_config, **config_kwargs}
+        gptq_config = GPTQConfig(**final_config)
+
+        model.quantize("gptq", config=gptq_config)
+
+        quantized_weights = target_layer.kernel
+
+        assert not np.allclose(original_weights, quantized_weights)
+
+        dummy_sample = rng.integers(
+            low=0, high=VOCAB_SIZE, size=(1, SEQUENCE_LENGTH)
+        )
+        _ = model.predict(dummy_sample)
+
+    @pytest.mark.parametrize("dataset", DATASETS.values(), ids=DATASETS.keys())
+    @pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
+    def test_quantize_gptq_combinations(self, dataset, config):
+        """Runs GPTQ tests across different datasets and config variations."""
+        self._run_gptq_test_on_dataset(dataset, **config)
+
+    @pytest.mark.parametrize(
+        "mode, config, expected_exception, match_message, test_id",
+        quantize_test_cases,
+        ids=[case[-1] for case in quantize_test_cases],
+    )
+    def test_quantize_scenarios(
+        self, mode, config, expected_exception, match_message, test_id
+    ):
+        """
+        Tests various scenarios for the model.quantize() method, including
+        error handling and valid calls.
+        """
+        model = _get_simple_model()
+
+        if expected_exception:
+            # Test for cases where an error is expected
+            with pytest.raises(expected_exception, match=match_message):
+                model.quantize(mode, config=config)
+        else:
+            # Test for valid cases where no error should occur
+            try:
+                model.quantize(mode, config=config)
+            except (ValueError, TypeError) as e:
+                pytest.fail(f"Test case '{test_id}' failed unexpectedly: {e}")
