feat(quantization): Add GPTQ n-bit quantization support #21551

Open · wants to merge 21 commits into base: master

Commits (21)
9ccec06
feat(quantization): Add GPTQ n-bit quantization support
amitsrivastava78 Jul 24, 2025
558e5c0
added dataset to be installed
amitsrivastava78 Aug 6, 2025
4b387ad
Fix the AI comments except one
amitsrivastava78 Aug 6, 2025
bd4919b
Fixed gptq algo for inline weights update
amitsrivastava78 Aug 6, 2025
0113829
updated the review comments
amitsrivastava78 Aug 6, 2025
872c75f
Renamed the quant to gptqquant class
amitsrivastava78 Aug 7, 2025
826f923
Renamed the quant file to gptqqnat
amitsrivastava78 Aug 7, 2025
dfb327c
Reworked some superficial comments
amitsrivastava78 Aug 7, 2025
39b8326
Reworked on review comments
amitsrivastava78 Aug 7, 2025
2c7173f
Removed the huggingfce dependency
amitsrivastava78 Aug 8, 2025
02b1df3
changed the file name to gptq_config.py
amitsrivastava78 Aug 8, 2025
7313b5f
fix comments and added additional test file
amitsrivastava78 Aug 8, 2025
463b252
added test to improve converage
amitsrivastava78 Aug 8, 2025
bf3de61
removed numerics like +,-,* etc and used keras.ops
amitsrivastava78 Aug 8, 2025
0c370c7
reworked on the review comments
amitsrivastava78 Aug 8, 2025
044e4ec
updated the interface as per comments
amitsrivastava78 Aug 11, 2025
52b6346
reworked the comments
amitsrivastava78 Aug 12, 2025
7cf0c5d
fixed failing test case
amitsrivastava78 Aug 12, 2025
80e5cf5
Added test case to improve the coverage
amitsrivastava78 Aug 12, 2025
176eb63
Added test case to improve the coverage
amitsrivastava78 Aug 12, 2025
37370e0
Added test case to improve the coverage
amitsrivastava78 Aug 12, 2025
1 change: 1 addition & 0 deletions keras/api/__init__.py
@@ -60,6 +60,7 @@
from keras.src.ops.function import Function as Function
from keras.src.ops.operation import Operation as Operation
from keras.src.optimizers.optimizer import Optimizer as Optimizer
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
from keras.src.quantizers.quantizers import Quantizer as Quantizer
from keras.src.regularizers.regularizers import Regularizer as Regularizer
from keras.src.version import __version__ as __version__
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/__init__.py
@@ -58,6 +58,7 @@
from keras.src.ops.function import Function as Function
from keras.src.ops.operation import Operation as Operation
from keras.src.optimizers.optimizer import Optimizer as Optimizer
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
from keras.src.quantizers.quantizers import Quantizer as Quantizer
from keras.src.regularizers.regularizers import Regularizer as Regularizer
from keras.src.version import __version__ as __version__
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -7,6 +7,7 @@
from keras.src.quantizers import deserialize as deserialize
from keras.src.quantizers import get as get
from keras.src.quantizers import serialize as serialize
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
from keras.src.quantizers.quantizers import Quantizer as Quantizer
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
1 change: 1 addition & 0 deletions keras/api/quantizers/__init__.py
@@ -7,6 +7,7 @@
from keras.src.quantizers import deserialize as deserialize
from keras.src.quantizers import get as get
from keras.src.quantizers import serialize as serialize
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
from keras.src.quantizers.quantizers import Quantizer as Quantizer
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
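
Together, these four export diffs make the new config class reachable from the public `keras.quantizers` namespace (and its `_tf_keras` mirror). A minimal usage sketch; the constructor arguments mirror those exercised by the tests below and are illustrative rather than a complete signature:

```python
import numpy as np

from keras.quantizers import GPTQConfig

# Illustrative stand-in tokenizer: any callable that maps text to token ids
# is enough for calibration (the PR's own tests use the same trick).
tokenizer = lambda text: np.array([ord(c) % 1000 for c in text])

config = GPTQConfig(
    dataset=["some calibration text"],  # or a generator of token arrays
    tokenizer=tokenizer,
    wbits=4,  # target weight bit-width
    nsamples=16,  # number of calibration samples
    seqlen=128,  # calibration sequence length
    group_size=32,  # quantization group size; -1 means per-channel
)
```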
2 changes: 1 addition & 1 deletion keras/src/dtype_policies/dtype_policy.py
@@ -3,7 +3,7 @@
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state

-QUANTIZATION_MODES = ("int8", "float8", "int4")
+QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")


@keras_export(
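
This one-line change matters because quantization entry points validate the requested mode against this tuple; registering "gptq" is what makes it a legal mode string. A simplified sketch of that gating (illustrative, not the literal Keras implementation):

```python
QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")


def check_mode(mode):
    # Simplified sketch: reject any mode that is not registered above.
    if mode not in QUANTIZATION_MODES:
        raise ValueError(
            "Invalid quantization mode. "
            f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
        )


check_mode("gptq")  # passes after this change; would raise before it
```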
20 changes: 19 additions & 1 deletion keras/src/models/model.py
@@ -8,6 +8,7 @@
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer
from keras.src.models.variable_mapping import map_saveable_variables
from keras.src.quantizers.gptq_config import GPTQConfig
from keras.src.saving import saving_api
from keras.src.trainers import trainer as base_trainer
from keras.src.utils import summary_utils
@@ -420,7 +421,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
            **kwargs,
        )

-    def quantize(self, mode, **kwargs):
+    def quantize(self, mode, config=None, **kwargs):
        """Quantize the weights of the model.

        Note that the model must be built first before calling this method.
@@ -433,6 +434,23 @@ def quantize(self, mode, **kwargs):
        """
        from keras.src.dtype_policies import QUANTIZATION_MODES

        if mode == "gptq":
            if not isinstance(config, GPTQConfig):
                raise TypeError(
                    "When using 'gptq' mode, you must pass a `config` "
                    "argument of type `keras.quantizers.GPTQConfig`."
                )
            # The config object's own quantize method drives the process
            config.quantize(self)
            return

        # For all other modes, verify that a config object was not passed.
        if config is not None:
            raise ValueError(
                f"The `config` argument is only supported for 'gptq' mode, "
                f"but received mode='{mode}'."
            )

        type_check = kwargs.pop("type_check", True)
        if kwargs:
            raise ValueError(
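
The net effect on the public API: "gptq" delegates the whole pass to the config object, while every other mode keeps its old behavior and must not receive a `config`. A hedged sketch of both paths (it assumes `get_model_with_dense_attention`, the test helper defined later in this PR):

```python
import numpy as np

import keras
from keras.quantizers import GPTQConfig

tokenizer = lambda text: np.array([ord(c) % 1000 for c in text])

# GPTQ path: the config drives calibration and in-place weight updates.
model = get_model_with_dense_attention()  # test helper from model_test.py below
model.quantize(
    "gptq",
    config=GPTQConfig(
        dataset=["some calibration text"], tokenizer=tokenizer, wbits=4
    ),
)

# Non-GPTQ paths are unchanged, but must not receive a config:
other = keras.Sequential([keras.Input(shape=(5,)), keras.layers.Dense(10)])
other.quantize("int8")  # fine
# other.quantize("int8", config=...)  -> ValueError
# model.quantize("gptq")              -> TypeError: GPTQConfig required
```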
179 changes: 179 additions & 0 deletions keras/src/models/model_test.py
@@ -1,6 +1,7 @@
import os
import pickle
from collections import namedtuple
from collections.abc import Callable

import numpy as np
import pytest
@@ -9,12 +10,14 @@
from keras.src import backend
from keras.src import layers
from keras.src import losses
from keras.src import models
from keras.src import testing
from keras.src import tree
from keras.src.layers.core.input_layer import Input
from keras.src.models.functional import Functional
from keras.src.models.model import Model
from keras.src.models.model import model_from_json
from keras.src.quantizers.gptq_config import GPTQConfig


def _get_model():
@@ -1237,3 +1240,179 @@ def test_export_error(self):
            ),
        ):
            model.export(temp_filepath, format="tf_saved_model")


# Helper function to generate dummy data for quick testing.
def dummy_dataset_generator(nsamples, seqlen, vocab_size=1000):
    """A generator that yields random numpy arrays for fast,
    self-contained tests."""
    rng = np.random.default_rng(seed=42)
    for _ in range(nsamples):
        yield rng.integers(low=0, high=vocab_size, size=(1, seqlen))


# Helper function to build a simple transformer model.
def get_model_with_dense_attention():
    """Builds a simple transformer model using Dense for attention."""
    vocab_size = 1000
    embed_dim = 32
    num_heads = 4
    ff_dim = 32

    class SimpleTransformerBlock(layers.Layer):
        def __init__(self, embed_dim, num_heads, ff_dim, **kwargs):
            super().__init__(**kwargs)
            self.att = layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim
            )
            self.ffn = models.Sequential(
                [
                    layers.Dense(ff_dim, activation="relu"),
                    layers.Dense(embed_dim),
                ]
            )
            self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
            self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)

        def call(self, inputs):
            attention_output = self.att(inputs, inputs)
            out1 = self.layernorm1(inputs + attention_output)
            ffn_output = self.ffn(out1)
            return self.layernorm2(out1 + ffn_output)

    inputs = layers.Input(shape=(None,), dtype="int32")
    embedding_layer = layers.Embedding(vocab_size, embed_dim)
    x = embedding_layer(inputs)
    transformer_block = SimpleTransformerBlock(embed_dim, num_heads, ff_dim)
    x = transformer_block(x)
    outputs = layers.Dense(vocab_size)(x)
    model = models.Model(inputs=inputs, outputs=outputs)
    return model


# Define parameters for the tests
long_text = """gptq is an easy-to-use model quantization library..."""
DATASETS = {
    "string_dataset": [long_text],
    "generator_dataset": lambda: dummy_dataset_generator(
        nsamples=16, seqlen=128
    ),
}
CONFIGS = {
    "default": {},
    "per_channel": {"group_size": -1},
    "act_order": {"act_order": True},
    "symmetric": {"symmetric": True},
}


def _get_simple_model():
    """Builds a simple sequential model for testing."""
    return models.Sequential([layers.Dense(10, input_shape=(5,))])


quantize_test_cases = [
    # --- Error Scenarios ---
    (
        "gptq",
        {"wbits": 4},  # Invalid config (dict, not GPTQConfig)
        TypeError,
        "must pass a `config` argument of type",
        "gptq_with_invalid_config",
    ),
    (
        "int8",
        GPTQConfig(dataset=["test"], tokenizer=lambda x: x),
        ValueError,
        "is only supported for 'gptq' mode",
        "non_gptq_with_unsupported_config",
    ),
    # --- Valid Scenario ---
    (
        "int8",
        None,  # No config, which is correct
        None,  # No exception expected
        None,
        "non_gptq_runs_without_error",
    ),
]


@pytest.mark.requires_trainable_backend
class TestModelQuantization:
    def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
        """Helper function to run a full GPTQ quantization test."""
        if isinstance(dataset, Callable):
            dataset = dataset()
        model = get_model_with_dense_attention()
        rng = np.random.default_rng(seed=42)

        NUM_SAMPLES = 16
        SEQUENCE_LENGTH = 128
        VOCAB_SIZE = 1000
        W_BITS = 4

        mock_tokenizer = lambda text: np.array(
            [ord(c) % VOCAB_SIZE for c in text]
        )
        mock_tokenizer.tokenize = mock_tokenizer

        base_config = {
            "dataset": dataset,
            "tokenizer": mock_tokenizer,
            "wbits": W_BITS,
            "nsamples": NUM_SAMPLES,
            "seqlen": SEQUENCE_LENGTH,
            "group_size": 32,
            "symmetric": False,
            "act_order": False,
        }

        target_layer = model.layers[2].ffn.layers[0]
        assert target_layer is not None
        original_weights = np.copy(target_layer.kernel)

        final_config = {**base_config, **config_kwargs}
        gptq_config = GPTQConfig(**final_config)

        model.quantize("gptq", config=gptq_config)

        quantized_weights = target_layer.kernel

        assert not np.allclose(original_weights, quantized_weights)

        dummy_sample = rng.integers(
            low=0, high=VOCAB_SIZE, size=(1, SEQUENCE_LENGTH)
        )
        _ = model.predict(dummy_sample)

    @pytest.mark.parametrize("dataset", DATASETS.values(), ids=DATASETS.keys())
    @pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
    def test_quantize_gptq_combinations(self, dataset, config):
        """Runs GPTQ tests across different datasets and config variations."""
        self._run_gptq_test_on_dataset(dataset, **config)

    @pytest.mark.parametrize(
        "mode, config, expected_exception, match_message, test_id",
        quantize_test_cases,
        ids=[case[-1] for case in quantize_test_cases],
    )
    def test_quantize_scenarios(
        self, mode, config, expected_exception, match_message, test_id
    ):
        """
        Tests various scenarios for the model.quantize() method, including
        error handling and valid calls.
        """
        model = _get_simple_model()

        if expected_exception:
            # Test for cases where an error is expected
            with pytest.raises(expected_exception, match=match_message):
                model.quantize(mode, config=config)
        else:
            # Test for valid cases where no error should occur
            try:
                model.quantize(mode, config=config)
Collaborator:
> Per my PR level comment, can you add model.save, then reload and verify it's still quantized.
            except (ValueError, TypeError) as e:
                pytest.fail(f"Test case '{test_id}' failed unexpectedly: {e}")
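
For reference, a hedged sketch of the round-trip check the reviewer is asking for. The test name is hypothetical, and it assumes the standard `.keras` save/load path preserves quantized variables and dtype policies:

```python
import os

import numpy as np

import keras


def test_quantized_model_survives_reload(tmp_path):
    """Hypothetical sketch: quantize, save, reload, verify still quantized."""
    model = keras.Sequential(
        [keras.Input(shape=(5,)), keras.layers.Dense(10)]
    )
    model.quantize("int8")

    x = np.random.default_rng(seed=0).normal(size=(2, 5)).astype("float32")
    expected = model.predict(x)

    filepath = os.path.join(str(tmp_path), "quantized.keras")
    model.save(filepath)
    reloaded = keras.saving.load_model(filepath)

    # Assumption: an int8-quantized layer keeps a quantized dtype policy
    # (e.g. "int8_from_float32") after the round trip.
    assert "int8" in reloaded.layers[0].dtype_policy.name
    np.testing.assert_allclose(reloaded.predict(x), expected, atol=1e-5)
```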