
Commit 3afe488
feat(quantization): Add GPTQ n-bit quantization support

This commit integrates the GPTQ (Generative Pre-trained Transformer Quantization) algorithm into Keras. Key features include:

- A new `GPTQConfig` for configuring quantization parameters.
- Integration with base Keras models via a `model.quantize()` method.
- Support for multiple datasets (WikiText-2, PTB, C4, and custom datasets) and tested models (GPT-2, OPT, Bloom, Gemma 3, etc.).
- Unit tests to verify perplexity and model functionality post-quantization.
1 parent 0c6c363 commit 3afe488
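
A minimal end-to-end sketch of the new API, assembled from the code in this commit (the `GPTQConfig` arguments and the `quant_config` keyword come from the tests below; `model`, `my_tokenizer`, and the calibration text are placeholders):

from keras.src.quantizers.gptqconfig import GPTQConfig

config = GPTQConfig(
    dataset=["calibration text ..."],  # a list of strings, a generator, or a dataset id
    tokenizer=my_tokenizer,  # placeholder: any callable mapping text to token ids
    wbits=4,  # target weight precision in bits
    nsamples=16,  # number of calibration samples
    seqlen=128,  # calibration sequence length
    groupsize=32,  # group size for per-group quantization
)
model.quantize("gptq", quant_config=config)  # quantizes weights in place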

File tree

7 files changed: +1007 −1 lines changed

keras/src/dtype_policies/dtype_policy.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8", "int4")
+QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
 
 
 @keras_export(
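
Registering "gptq" in this tuple is what lets `Model.quantize` accept the new mode. A minimal sketch of the kind of validation this tuple presumably drives (the exact call site is outside this diff's context):

if mode not in QUANTIZATION_MODES:
    raise ValueError(
        "Invalid quantization mode. "
        f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
    )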

keras/src/models/model.py

Lines changed: 22 additions & 0 deletions
@@ -433,6 +433,28 @@ def quantize(self, mode, **kwargs):
         """
         from keras.src.dtype_policies import QUANTIZATION_MODES
 
+        if mode == "gptq":
+            try:
+                from keras.src.quantizers.gptqconfig import GPTQConfig
+            except ImportError:
+                raise ImportError(
+                    "To use 'gptq' mode, please ensure the necessary "
+                    "quantization modules are correctly placed in "
+                    "keras/src/quantizers."
+                )
+
+            config = kwargs.get("quant_config")
+            if not isinstance(config, GPTQConfig):
+                raise TypeError(
+                    "When using 'gptq' mode, you must pass a `quant_config` "
+                    "keyword argument of type `keras.quantizers.GPTQConfig`."
+                )
+
+            # The config object's own quantize method drives the process.
+            config.quantize(self)
+            return
+
         type_check = kwargs.pop("type_check", True)
         if kwargs:
             raise ValueError(
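
One consequence of the guard above, shown as a small sketch: calling the new mode without a config fails fast rather than partway through quantization.

# Sketch: "gptq" mode requires a GPTQConfig; anything else raises TypeError.
try:
    model.quantize("gptq")  # no quant_config supplied
except TypeError as err:
    print(err)  # "When using 'gptq' mode, you must pass a `quant_config` ..."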

keras/src/models/model_test.py

Lines changed: 250 additions & 0 deletions
@@ -1,20 +1,117 @@
+import io
+import logging
 import os
 import pickle
+import tarfile
 from collections import namedtuple
 
 import numpy as np
 import pytest
+import requests
 from absl.testing import parameterized
+from datasets import load_dataset
 
 from keras.src import backend
 from keras.src import layers
 from keras.src import losses
+from keras.src import models
 from keras.src import testing
 from keras.src import tree
 from keras.src.layers.core.input_layer import Input
 from keras.src.models.functional import Functional
 from keras.src.models.model import Model
 from keras.src.models.model import model_from_json
+from keras.src.quantizers.gptqconfig import GPTQConfig
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+
+
+def get_dataset_text(dataset_identifier: str, nsamples=1000) -> str:
+    """
+    Loads a specified dataset and extracts its text content into a
+    single string.
+    """
+    DATASET_CONFIGS = {
+        "wikitext2": {
+            "name": "wikitext",
+            "config": "wikitext-2-raw-v1",
+            "split": "test",
+            "text_column": "text",
+        },
+        "ptb": {
+            "name": "ptb_text_only",
+            "config": "penn_treebank",
+            "split": "validation",
+            "text_column": "sentence",
+        },
+        "c4": {
+            "name": "allenai/c4",
+            "config": "en",
+            "split": "validation",  # Use validation for C4's test split
+            "text_column": "text",
+        },
+    }
+
+    if dataset_identifier not in DATASET_CONFIGS:
+        raise ValueError(
+            f"Unknown dataset identifier '{dataset_identifier}'. "
+            f"Available options are: {list(DATASET_CONFIGS.keys())}"
+        )
+
+    config = DATASET_CONFIGS[dataset_identifier]
+
+    if dataset_identifier == "ptb":
+        url = "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz"
+        try:
+            # 1. Download the archive into memory
+            response = requests.get(url)
+            response.raise_for_status()
+
+            # 2. Extract only the test file from the in-memory archive
+            with tarfile.open(
+                fileobj=io.BytesIO(response.content), mode="r:gz"
+            ) as tar:
+                test_path = "./simple-examples/data/ptb.test.txt"
+                test_bytes = tar.extractfile(test_path).read()
+
+            # 3. Decode the bytes and join into a single string
+            test_lines = test_bytes.decode("utf-8").strip().split("\n")
+            all_text = "\n\n".join(test_lines)
+
+            print("✅ Successfully processed PTB test data.")
+            return all_text
+
+        except Exception as e:
+            print(f"Failed to download or process PTB data: {e!r}")
+            raise e
+
+    load_kwargs = {"name": config["config"]}
+
+    if dataset_identifier == "c4":
+        load_kwargs["streaming"] = True
+
+    print(f"Loading dataset '{config['name']}'...")
+
+    test_data = load_dataset(
+        config["name"], split=config["split"], **load_kwargs
+    )
+
+    if dataset_identifier == "c4":
+        print(f" -> Limiting C4 to the first {nsamples} documents for speed.")
+        test_data = test_data.take(nsamples)
+
+    all_text = "\n\n".join(
+        row[config["text_column"]]
+        for row in test_data
+        if row.get(config["text_column"])
+    )
+
+    print(f"Successfully loaded and processed {dataset_identifier}.")
+    return all_text
 
 
 def _get_model():
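
As a usage sketch of the helper above (hedged: it downloads real data from the Hugging Face hub, so it only suits slow/integration runs), it can be called directly with one of its three identifiers:

# Sketch: pull the WikiText-2 test split as one string for GPTQ calibration.
calibration_text = get_dataset_text("wikitext2")
print(f"{len(calibration_text)} characters loaded")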
@@ -1237,3 +1334,156 @@ def test_export_error(self):
             ),
         ):
             model.export(temp_filepath, format="tf_saved_model")
+
+
+# Helper function to generate dummy data for quick testing.
+def dummy_dataset_generator(nsamples, seqlen, vocab_size=1000):
+    """A generator that yields random numpy arrays for fast,
+    self-contained tests."""
+    for _ in range(nsamples):
+        yield np.random.randint(0, vocab_size, size=(1, seqlen))
+
+
+# Helper function to build a simple transformer model that uses standard
+# Keras `Dense` layers for its attention projections.
+def _get_model_with_dense_attention():
+    """Builds a simple transformer model using Dense for attention."""
+    vocab_size = 1000
+    embed_dim = 32
+    num_heads = 4
+    ff_dim = 32
+
+    class SimpleTransformerBlock(layers.Layer):
+        def __init__(self, embed_dim, num_heads, ff_dim, **kwargs):
+            super().__init__(**kwargs)
+            # The standard MultiHeadAttention layer uses Dense layers
+            # for its projections.
+            self.att = layers.MultiHeadAttention(
+                num_heads=num_heads, key_dim=embed_dim
+            )
+            self.ffn = models.Sequential(
+                [
+                    layers.Dense(ff_dim, activation="relu"),
+                    layers.Dense(embed_dim),
+                ]
+            )
+            self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
+            self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
+
+        def call(self, inputs):
+            attention_output = self.att(inputs, inputs)
+            out1 = self.layernorm1(inputs + attention_output)
+            ffn_output = self.ffn(out1)
+            return self.layernorm2(out1 + ffn_output)
+
+    inputs = layers.Input(shape=(None,), dtype="int32")
+    embedding_layer = layers.Embedding(vocab_size, embed_dim)
+    x = embedding_layer(inputs)
+    transformer_block = SimpleTransformerBlock(embed_dim, num_heads, ff_dim)
+    x = transformer_block(x)
+    outputs = layers.Dense(vocab_size)(x)
+    model = models.Model(inputs=inputs, outputs=outputs)
+    return model
+
+
+def _run_gptq_test_on_dataset(test_case, dataset):
+    """Helper function to run a full GPTQ quantization test on a
+    given dataset."""
+    model = _get_model_with_dense_attention()
+
+    # --- 1. Common Setup ---
+    NUM_SAMPLES = 16
+    SEQUENCE_LENGTH = 128
+    VOCAB_SIZE = 1000
+    W_BITS = 4
+    GROUP_SIZE = 32
+
+    mock_tokenizer = lambda text: np.array([ord(c) % VOCAB_SIZE for c in text])
+    mock_tokenizer.tokenize = mock_tokenizer
+
+    # --- 2. Find Target Layer and Get Original Weights ---
+    target_layer = None
+    for layer in model.layers:
+        if hasattr(layer, "ffn") and hasattr(layer.ffn, "layers"):
+            dense_layer_in_ffn = next(
+                (
+                    ffn_layer
+                    for ffn_layer in layer.ffn.layers
+                    if isinstance(ffn_layer, layers.Dense)
+                ),
+                None,
+            )
+            if dense_layer_in_ffn:
+                target_layer = dense_layer_in_ffn
+                break
+
+    test_case.assertIsNotNone(
+        target_layer,
+        "Test setup failed: No Dense layer was found inside an 'ffn' block.",
+    )
+    original_weights = np.copy(target_layer.kernel.numpy())
+
+    # --- 3. Configure and Run Quantization ---
+    gptq_config = GPTQConfig(
+        dataset=dataset,
+        tokenizer=mock_tokenizer,
+        wbits=W_BITS,
+        nsamples=NUM_SAMPLES,
+        seqlen=SEQUENCE_LENGTH,
+        groupsize=GROUP_SIZE,
+    )
+    model.quantize("gptq", quant_config=gptq_config)
+
+    # --- 4. Assertions and Verification ---
+    quantized_weights = target_layer.kernel.numpy()
+
+    # Assert that the weights have been changed
+    test_case.assertFalse(
+        np.allclose(original_weights, quantized_weights),
+        f"Weights were not changed by the GPTQ process for dataset: {dataset}",
+    )
+
+    # Verify the quantized model can still make a prediction
+    try:
+        dummy_input = np.random.randint(
+            0, VOCAB_SIZE, size=(1, SEQUENCE_LENGTH)
+        )
+        _ = model.predict(dummy_input)
+    except Exception as e:
+        test_case.fail(
+            "Prediction failed for the quantized model with dataset: "
+            f"{dataset}. Error: {e}"
+        )
+
+
+@pytest.mark.requires_trainable_backend
+class ModelQuantizationTest(testing.TestCase):
+    def test_quantize_gptq_with_dense_attention(self):
+        """Tests GPTQ with an in-memory list of strings as the dataset."""
+
+        long_text = """auto-gptq is an easy-to-use model quantization library
+        with user-friendly apis, based on GPTQ algorithm. The goal is to
+        quantize pre-trained models to 4-bit or even 3-bit precision with
+        minimal performance degradation.
+        This allows for running larger models on less powerful hardware,
+        reducing memory footprint and increasing inference speed.
+        The process involves calibrating the model on a small dataset
+        to determine the quantization parameters.
+        This technique is particularly useful for deploying large language
+        models in resource-constrained environments where every bit of memory
+        and every millisecond of latency counts."""
+
+        string_dataset = [long_text]
+        _run_gptq_test_on_dataset(self, string_dataset)
+
+    def test_quantize_gptq_with_data_gen(self):
+        """Tests GPTQ with a Python generator as the dataset."""
+        generator_dataset = dummy_dataset_generator(
+            nsamples=16, seqlen=128, vocab_size=1000
+        )
+        _run_gptq_test_on_dataset(self, generator_dataset)
+
+    @pytest.mark.slow
+    def test_quantize_gptq_with_wikitext2(self):
+        """Tests GPTQ with the 'wikitext2' dataset identifier."""
+        _run_gptq_test_on_dataset(self, "wikitext2")