diff --git a/keras/api/__init__.py b/keras/api/__init__.py index dee6cea5bb19..01ff61783001 100644 --- a/keras/api/__init__.py +++ b/keras/api/__init__.py @@ -60,6 +60,7 @@ from keras.src.ops.function import Function as Function from keras.src.ops.operation import Operation as Operation from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.regularizers.regularizers import Regularizer as Regularizer from keras.src.version import __version__ as __version__ diff --git a/keras/api/_tf_keras/keras/__init__.py b/keras/api/_tf_keras/keras/__init__.py index 67d4738a0f3c..7a94d4cf1696 100644 --- a/keras/api/_tf_keras/keras/__init__.py +++ b/keras/api/_tf_keras/keras/__init__.py @@ -58,6 +58,7 @@ from keras.src.ops.function import Function as Function from keras.src.ops.operation import Operation as Operation from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.regularizers.regularizers import Regularizer as Regularizer from keras.src.version import __version__ as __version__ diff --git a/keras/api/_tf_keras/keras/quantizers/__init__.py b/keras/api/_tf_keras/keras/quantizers/__init__.py index 48cf9b034560..299e467ac1bb 100644 --- a/keras/api/_tf_keras/keras/quantizers/__init__.py +++ b/keras/api/_tf_keras/keras/quantizers/__init__.py @@ -7,6 +7,7 @@ from keras.src.quantizers import deserialize as deserialize from keras.src.quantizers import get as get from keras.src.quantizers import serialize as serialize +from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize diff --git a/keras/api/quantizers/__init__.py b/keras/api/quantizers/__init__.py index 48cf9b034560..299e467ac1bb 100644 --- a/keras/api/quantizers/__init__.py +++ b/keras/api/quantizers/__init__.py @@ -7,6 +7,7 @@ from keras.src.quantizers import deserialize as deserialize from keras.src.quantizers import get as get from keras.src.quantizers import serialize as serialize +from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize diff --git a/keras/src/dtype_policies/dtype_policy.py b/keras/src/dtype_policies/dtype_policy.py index 8182a1e45aa4..975b6a2930df 100644 --- a/keras/src/dtype_policies/dtype_policy.py +++ b/keras/src/dtype_policies/dtype_policy.py @@ -3,7 +3,7 @@ from keras.src.api_export import keras_export from keras.src.backend.common import global_state -QUANTIZATION_MODES = ("int8", "float8", "int4") +QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq") @keras_export( diff --git a/keras/src/models/model.py b/keras/src/models/model.py index f75fc2efba9c..8fc889546616 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -8,6 +8,7 @@ from keras.src.api_export import keras_export from keras.src.layers.layer import Layer from keras.src.models.variable_mapping import map_saveable_variables +from 
keras.src.quantizers.gptq_config import GPTQConfig from keras.src.saving import saving_api from keras.src.trainers import trainer as base_trainer from keras.src.utils import summary_utils @@ -420,7 +421,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs): **kwargs, ) - def quantize(self, mode, **kwargs): + def quantize(self, mode, config=None, **kwargs): """Quantize the weights of the model. Note that the model must be built first before calling this method. @@ -433,6 +434,23 @@ def quantize(self, mode, **kwargs): """ from keras.src.dtype_policies import QUANTIZATION_MODES + if mode == "gptq": + if not isinstance(config, GPTQConfig): + raise TypeError( + "When using 'gptq' mode, you must pass a `config` " + "argument of type `keras.quantizers.GPTQConfig`." + ) + # The config object's own quantize method drives the process + config.quantize(self) + return + + # For all other modes, verify that a config object was not passed. + if config is not None: + raise ValueError( + f"The `config` argument is only supported for 'gptq' mode, " + f"but received mode='{mode}'." + ) + type_check = kwargs.pop("type_check", True) if kwargs: raise ValueError( diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 6ed7d3c6543e..4b7be8626567 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1,6 +1,7 @@ import os import pickle from collections import namedtuple +from collections.abc import Callable import numpy as np import pytest @@ -9,12 +10,14 @@ from keras.src import backend from keras.src import layers from keras.src import losses +from keras.src import models from keras.src import testing from keras.src import tree from keras.src.layers.core.input_layer import Input from keras.src.models.functional import Functional from keras.src.models.model import Model from keras.src.models.model import model_from_json +from keras.src.quantizers.gptq_config import GPTQConfig def _get_model(): @@ -1237,3 +1240,179 @@ def test_export_error(self): ), ): model.export(temp_filepath, format="tf_saved_model") + + +# Helper function to generate dummy data for quick testing. +def dummy_dataset_generator(nsamples, seqlen, vocab_size=1000): + """A generator that yields random numpy arrays for fast, + self-contained tests.""" + rng = np.random.default_rng(seed=42) + for _ in range(nsamples): + yield rng.integers(low=0, high=vocab_size, size=(1, seqlen)) + + +# Helper function to build a simple transformer model. 
+def get_model_with_dense_attention(): + """Builds a simple transformer model using Dense for attention.""" + vocab_size = 1000 + embed_dim = 32 + num_heads = 4 + ff_dim = 32 + + class SimpleTransformerBlock(layers.Layer): + def __init__(self, embed_dim, num_heads, ff_dim, **kwargs): + super().__init__(**kwargs) + self.att = layers.MultiHeadAttention( + num_heads=num_heads, key_dim=embed_dim + ) + self.ffn = models.Sequential( + [ + layers.Dense(ff_dim, activation="relu"), + layers.Dense(embed_dim), + ] + ) + self.layernorm1 = layers.LayerNormalization(epsilon=1e-6) + self.layernorm2 = layers.LayerNormalization(epsilon=1e-6) + + def call(self, inputs): + attention_output = self.att(inputs, inputs) + out1 = self.layernorm1(inputs + attention_output) + ffn_output = self.ffn(out1) + return self.layernorm2(out1 + ffn_output) + + inputs = layers.Input(shape=(None,), dtype="int32") + embedding_layer = layers.Embedding(vocab_size, embed_dim) + x = embedding_layer(inputs) + transformer_block = SimpleTransformerBlock(embed_dim, num_heads, ff_dim) + x = transformer_block(x) + outputs = layers.Dense(vocab_size)(x) + model = models.Model(inputs=inputs, outputs=outputs) + return model + + +# Define parameters for the tests +long_text = """gptq is an easy-to-use model quantization library...""" +DATASETS = { + "string_dataset": [long_text], + "generator_dataset": lambda: dummy_dataset_generator( + nsamples=16, seqlen=128 + ), +} +CONFIGS = { + "default": {}, + "per_channel": {"group_size": -1}, + "act_order": {"act_order": True}, + "symmetric": {"symmetric": True}, +} + + +def _get_simple_model(): + """Builds a simple sequential model for testing.""" + return models.Sequential([layers.Dense(10, input_shape=(5,))]) + + +quantize_test_cases = [ + # --- Error Scenarios --- + ( + "gptq", + {"wbits": 4}, # Invalid config (dict, not GPTQConfig) + TypeError, + "must pass a `config` argument of type", + "gptq_with_invalid_config", + ), + ( + "int8", + GPTQConfig(dataset=["test"], tokenizer=lambda x: x), + ValueError, + "is only supported for 'gptq' mode", + "non_gptq_with_unsupported_config", + ), + # --- Valid Scenario --- + ( + "int8", + None, # No config, which is correct + None, # No exception expected + None, + "non_gptq_runs_without_error", + ), +] + + +@pytest.mark.requires_trainable_backend +class TestModelQuantization: + def _run_gptq_test_on_dataset(self, dataset, **config_kwargs): + """Helper function to run a full GPTQ quantization test.""" + if isinstance(dataset, Callable): + dataset = dataset() + model = get_model_with_dense_attention() + rng = np.random.default_rng(seed=42) + + NUM_SAMPLES = 16 + SEQUENCE_LENGTH = 128 + VOCAB_SIZE = 1000 + W_BITS = 4 + + mock_tokenizer = lambda text: np.array( + [ord(c) % VOCAB_SIZE for c in text] + ) + mock_tokenizer.tokenize = mock_tokenizer + + base_config = { + "dataset": dataset, + "tokenizer": mock_tokenizer, + "wbits": W_BITS, + "nsamples": NUM_SAMPLES, + "seqlen": SEQUENCE_LENGTH, + "group_size": 32, + "symmetric": False, + "act_order": False, + } + + target_layer = model.layers[2].ffn.layers[0] + assert target_layer is not None + original_weights = np.copy(target_layer.kernel) + + final_config = {**base_config, **config_kwargs} + gptq_config = GPTQConfig(**final_config) + + model.quantize("gptq", config=gptq_config) + + quantized_weights = target_layer.kernel + + assert not np.allclose(original_weights, quantized_weights) + + dummy_sample = rng.integers( + low=0, high=VOCAB_SIZE, size=(1, SEQUENCE_LENGTH) + ) + _ = model.predict(dummy_sample) + + 
@pytest.mark.parametrize("dataset", DATASETS.values(), ids=DATASETS.keys()) + @pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys()) + def test_quantize_gptq_combinations(self, dataset, config): + """Runs GPTQ tests across different datasets and config variations.""" + self._run_gptq_test_on_dataset(dataset, **config) + + @pytest.mark.parametrize( + "mode, config, expected_exception, match_message, test_id", + quantize_test_cases, + ids=[case[-1] for case in quantize_test_cases], + ) + def test_quantize_scenarios( + self, mode, config, expected_exception, match_message, test_id + ): + """ + Tests various scenarios for the model.quantize() method, including + error handling and valid calls. + """ + model = _get_simple_model() + + if expected_exception: + # Test for cases where an error is expected + with pytest.raises(expected_exception, match=match_message): + model.quantize(mode, config=config) + else: + # Test for valid cases where no error should occur + try: + model.quantize(mode, config=config) + except (ValueError, TypeError) as e: + pytest.fail(f"Test case '{test_id}' failed unexpectedly: {e}") diff --git a/keras/src/quantizers/gptq.py b/keras/src/quantizers/gptq.py new file mode 100644 index 000000000000..54a48f92674d --- /dev/null +++ b/keras/src/quantizers/gptq.py @@ -0,0 +1,268 @@ +from keras.src import ops +from keras.src.layers import Dense +from keras.src.layers import EinsumDense +from keras.src.quantizers.gptq_quant import dequantize + + +class GPTQ: + def __init__(self, layer): + self.original_layer = layer + self.nsamples = 0 + self.quantizer = None + + # Explicitly handle each supported layer type + if isinstance(layer, Dense) or ( + isinstance(layer, EinsumDense) and layer.kernel.ndim == 2 + ): + # For a standard Dense layer, the dimensions are straightforward. + self.kernel_shape = layer.kernel.shape + self.rows = self.kernel_shape[0] # Input features + self.columns = self.kernel_shape[1] # Output features + self.layer = layer # The layer itself can be used directly. + + # Handle 3D EinsumDense layers (typically from attention blocks). + elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3: + # For EinsumDense, we determine the effective 2D dimensions. + self.kernel_shape = layer.kernel.shape + shape = list(self.kernel_shape) + try: + d_model_dim_index = shape.index(max(shape)) + except ValueError: + raise TypeError( + f"Could not determine hidden dimension from shape {shape}" + ) + + if d_model_dim_index == 0: # QKV projection case + in_features, heads, head_dim = shape + self.rows, self.columns = in_features, heads * head_dim + elif d_model_dim_index in [1, 2]: # Attention Output case + heads, head_dim, out_features = shape + self.rows, self.columns = heads * head_dim, out_features + + # Create a temporary object that holds a reshaped + # 2D version of the kernel. + self.layer = type( + "temp", + (object,), + { + "kernel": ops.reshape( + layer.kernel, (self.rows, self.columns) + ), + "bias": layer.bias, + }, + )() + + else: + # Raise an error if the layer is not supported. + raise TypeError(f"Unsupported layer type for GPTQ: {type(layer)}") + self.H = ops.zeros((self.rows, self.rows), dtype="float32") + + def update_hessian_with_batch(self, inp): + """ + Updates the running average of the Hessian matrix with a new batch. + + This method computes the Hessian matrix for a given batch of input + activations and updates the accumulated Hessian (`self.H`) using a + numerically stable running average. 
This allows the Hessian to be
+        computed over a large dataset without loading all samples into memory
+        at once.
+
+        The input tensor is first reshaped into a 2D matrix [num_samples,
+        num_features] before the Hessian is calculated.
+
+        Args:
+            inp: A 2D or higher-dimensional tensor of input activations from a
+                calibration batch.
+
+        Raises:
+            ValueError: If the feature dimension of the input tensor `inp` does
+                not match the dimensions of the pre-initialized Hessian matrix
+                `self.H`.
+        """
+        if inp is None:
+            raise ValueError("Input tensor 'inp' cannot be None.")
+
+        if len(inp.shape) < 2:
+            raise ValueError(
+                f"Input tensor 'inp' must have a rank of at least 2 "
+                f"(e.g., [batch, features]), but got rank {len(inp.shape)}."
+            )
+        if ops.size(inp) == 0:
+            raise ValueError("Input tensor 'inp' cannot be empty.")
+
+        if len(inp.shape) > 2:
+            inp = ops.reshape(inp, (-1, inp.shape[-1]))
+        inp = ops.cast(inp, "float32")
+
+        if self.H.shape[0] != inp.shape[-1]:
+            raise ValueError(
+                f"Hessian dimensions ({self.H.shape[0]}) do not "
+                f"match input features ({inp.shape[-1]})."
+            )
+
+        current_H = ops.multiply(2, ops.matmul(ops.transpose(inp), inp))
+
+        if self.nsamples == 0:
+            self.H = current_H
+        else:
+            total_samples = ops.add(self.nsamples, inp.shape[0])
+            old_H_weight = ops.divide(self.nsamples, total_samples)
+            current_H_weight = ops.divide(inp.shape[0], total_samples)
+
+            # Update the accumulated Hessian
+            term1 = ops.multiply(self.H, old_H_weight)
+            term2 = ops.multiply(current_H, current_H_weight)
+            self.H = ops.add(term1, term2)
+
+        self.nsamples = ops.add(self.nsamples, inp.shape[0])
+
+    def quantize_and_correct_block(
+        self, blocksize=128, percdamp=0.01, group_size=-1, actorder=False
+    ):
+        """
+        Performs GPTQ quantization and correction on the layer's weights.
+
+        This method implements the core logic of the Optimal Brain Quantization
+        (OBQ) method, as applied by GPTQ, to quantize the weights of a single
+        layer. It iteratively quantizes blocks of weights and corrects for the
+        quantization error by updating the remaining weights.
+
+        The algorithm follows these main steps:
+        1. **Initialization**: It optionally reorders the weight columns based
+           on activation magnitudes (`actorder=True`) to protect more salient
+           weights.
+        2. **Hessian Modification**: The Hessian matrix `H`, pre-computed from
+           calibration data, is dampened to ensure its invertibility and
+           stability.
+        3. **Iterative Quantization**: The function iterates through the
+           weight columns in blocks (`blocksize`). In each iteration, it:
+           a. Quantizes one column (`w`).
+           b. Calculates the quantization error (`err`).
+           c. Updates the remaining weights in the *current* block by
+              distributing the error, using the inverse Hessian (`Hinv`).
+        4. **Block-wise Correction**: After a block is quantized, the total
+           error from that block is propagated to the *next* block of weights
+           to be processed.
+        5. **Finalization**: The quantized weights (`Q`) are reordered back if
+           `actorder` was used, and the layer's weights are updated.
+
+        This implementation is based on the official GPTQ paper and repository.
+        For more details, see:
+        - Paper: https://arxiv.org/abs/2210.17323
+        - Original Code: https://github.com/IST-DASLab/gptq
+
+        Args:
+            blocksize (int, optional): The size of the weight block to process
+                at a time. Defaults to 128.
+            percdamp (float, optional): The percentage of dampening to add to
+                the Hessian's diagonal. A value of 0.01 is recommended.
+                Defaults to 0.01.
+ group_size (int, optional): The number of weights that share the + same quantization parameters (scale and zero-point). + A value of -1 indicates per-channel quantization. + actorder (bool, optional): If True, reorders weight columns based + on their activation's second-order information. + """ + + W = ops.transpose(ops.cast(self.layer.kernel, "float32")) + H = ops.cast(self.H, "float32") + + if actorder: + perm = ops.argsort(-ops.diagonal(H)) + W = ops.take(W, perm, axis=1) + H = ops.take(ops.take(H, perm, axis=0), perm, axis=1) + invperm = ops.argsort(perm) + + # Dampen the Hessian for Stability + diag_H = ops.diagonal(H) + dead = ops.equal(diag_H, 0.0) + diag_H = ops.where(dead, 1.0, diag_H) + H = ops.add(H, ops.diag(ops.where(dead, 1.0, ops.zeros_like(diag_H)))) + + # Add dampening factor to the Hessian diagonal + damp = ops.multiply(percdamp, ops.mean(diag_H)) + diag_H = ops.add(diag_H, damp) + H = ops.add( + ops.subtract(H, ops.diag(ops.diagonal(H))), ops.diag(diag_H) + ) + + # Compute the inverse Hessian, which is used for error correction + Hinv = ops.linalg.inv(H) + Q = ops.zeros_like(W) + + for i1 in range(0, self.rows, blocksize): + i2 = min(i1 + blocksize, self.rows) + count = i2 - i1 + # Extract the current block of weights and its corresponding + # Hessian + W1 = W[:, i1:i2] + Q1 = ops.zeros_like(W1) + Err1 = ops.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + # Process one column at a time within the block + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if group_size != -1: + if (i1 + i) % group_size == 0: + self.quantizer.find_params( + W[:, (i1 + i) : (i1 + i + group_size)], weight=True + ) + else: + self.quantizer.find_params( + ops.expand_dims(w, 1), weight=True + ) + + # Quantize the current weight column + q = dequantize( + ops.expand_dims(w, 1), + self.quantizer.scale, + self.quantizer.zero, + self.quantizer.maxq, + )[:, 0] + + Q1 = ops.slice_update(Q1, (0, i), ops.expand_dims(q, axis=1)) + err = ops.divide(ops.subtract(w, q), d) + Err1 = ops.slice_update( + Err1, (0, i), ops.expand_dims(err, axis=1) + ) + + if i < count - 1: + update = ops.matmul( + ops.expand_dims(err, 1), + ops.expand_dims(Hinv1[i, i + 1 :], 0), + ) + + # Efficiently update the remaining part of the W1 tensor. 
+ slice_to_update = W1[:, i + 1 :] + updated_slice = ops.subtract(slice_to_update, update) + W1 = ops.slice_update(W1, (0, i + 1), updated_slice) + + # Update the full quantized matrix Q with the processed block + Q = ops.concatenate([Q[:, :i1], Q1, Q[:, i2:]], axis=1) + + if i2 < self.rows: + update_total = ops.matmul(Err1, Hinv[i1:i2, i2:]) + W = ops.concatenate( + [W[:, :i2], ops.subtract(W[:, i2:], update_total)], axis=1 + ) + + if actorder: + Q = ops.take(Q, invperm, axis=1) + + Q = ops.transpose(Q) + + if isinstance(self.original_layer, EinsumDense): + Q = ops.reshape(Q, self.kernel_shape) + + # Set the new quantized weights in the original layer + new_weights = [ops.convert_to_numpy(Q)] + if self.original_layer.bias is not None: + new_weights.append(ops.convert_to_numpy(self.original_layer.bias)) + + self.original_layer.set_weights(new_weights) + + def free(self): + self.H = None diff --git a/keras/src/quantizers/gptq_config.py b/keras/src/quantizers/gptq_config.py new file mode 100644 index 000000000000..b6a256cb4bdb --- /dev/null +++ b/keras/src/quantizers/gptq_config.py @@ -0,0 +1,70 @@ +from absl import logging + +from keras.src.api_export import keras_export +from keras.src.quantizers.gptq_core import quantize_model + + +@keras_export(["keras.GPTQConfig", "keras.quantizers.GPTQConfig"]) +class GPTQConfig: + """Configuration class for the GPTQ algorithm. + + This class holds all the parameters needed to apply the GPTQ method + to a model. + + Args: + dataset: The calibration dataset. It can be an iterable that yields + strings or pre-tokenized numerical tensors (e.g., a list of + strings, a generator, or a NumPy array). This data is used to + analyze the model's activations. + tokenizer: A `keras_nlp.Tokenizer` instance (or a similar callable) + that is used to process the `dataset` if it contains strings. + wbits (int, optional): The number of bits to quantize weights to. + Defaults to 4. + nsamples (int, optional): The number of calibration data samples to + use from the dataset. Defaults to 128. + seqlen (int, optional): The sequence length to use for each calibration + sample. Defaults to 512. + percdamp (float, optional): The % of Hessian damping to use for + stabilization during inverse calculation. Defaults to 0.01. + group_size (int, optional): The size of weight groups to quantize + together. A `group_size` of -1 indicates per-channel quantization. + Defaults to 128. + symmetric (bool, optional): If `True`, uses symmetric quantization. + If `False`, uses asymmetric quantization. Defaults to `False`. + act_order (bool, optional): If `True`, reorders weight columns based on + activation magnitude, which can improve quantization accuracy. + Defaults to `False`. + """ + + def __init__( + self, + dataset, + tokenizer, + wbits: int = 4, + nsamples: int = 128, + seqlen: int = 512, + percdamp: float = 0.01, + group_size: int = 128, + symmetric: bool = False, + act_order: bool = False, + ): + self.dataset = dataset + self.tokenizer = tokenizer + self.nsamples = nsamples + self.seqlen = seqlen + self.percdamp = percdamp + self.wbits = wbits + self.group_size = group_size + self.symmetric = symmetric + self.act_order = act_order + self.quantization_method = "gptq" + + def quantize(self, model): + """ + Applies GPTQ quantization to the provided model using this + configuration. + """ + logging.info("Initiating quantization from GPTQConfig...") + # The core logic is now delegated to gptqutils, which will handle + # the dynamic imports and data loading. 
+        quantize_model(model=model, config=self)
diff --git a/keras/src/quantizers/gptq_core.py b/keras/src/quantizers/gptq_core.py
new file mode 100644
index 000000000000..d2e36bb8d283
--- /dev/null
+++ b/keras/src/quantizers/gptq_core.py
@@ -0,0 +1,335 @@
+import random
+
+import numpy as np
+from absl import logging
+
+from keras.src import ops
+from keras.src import utils as keras_utils
+from keras.src.layers import Dense
+from keras.src.layers import EinsumDense
+from keras.src.layers import Embedding
+from keras.src.quantizers.gptq import GPTQ
+from keras.src.quantizers.gptq_quant import GPTQQuant
+
+
+def get_dataloader(tokenizer, seqlen, dataset, nsamples=128):
+    """
+    Prepares and chunks the calibration dataloader, repeating short datasets.
+    """
+    all_tokens = []
+
+    if isinstance(dataset, str):
+        raise TypeError(
+            "The `dataset` argument must be an iterable (e.g., a list of "
+            "strings or a generator). Providing a dataset name as a string "
+            "is not supported. Please pass the loaded dataset directly."
+        )
+
+    logging.info("Using pre-made dataset/generator...")
+    dataset_list = list(dataset)
+
+    if not dataset_list:
+        raise ValueError("Provided dataset is empty.")
+
+    if isinstance(dataset_list[0], str):
+        logging.info("(Dataset contains strings, tokenizing now...)")
+        full_text = "\n\n".join(dataset_list)
+        all_tokens = tokenizer.tokenize(full_text)
+    else:
+        logging.info("(Dataset is pre-tokenized, concatenating...)")
+        all_tokens = np.concatenate(
+            [ops.convert_to_numpy(s).reshape(-1) for s in dataset_list], axis=0
+        )
+
+    all_tokens = np.array(all_tokens, dtype=np.int32)
+
+    # Repeat data if it's too short
+    required_tokens = nsamples * seqlen
+    if len(all_tokens) < required_tokens:
+        logging.warning(
+            f"Dataset is too short ({len(all_tokens)} tokens). "
+            f"Repeating data to generate {nsamples} samples."
+        )
+        repeats = -(-required_tokens // len(all_tokens))  # Ceiling division
+        all_tokens = np.tile(all_tokens, repeats)
+
+    # Chunk the token list into samples
+
+    calibration_samples = []
+    for _ in range(nsamples):
+        # Generate a random starting index
+        start_index = random.randint(0, len(all_tokens) - seqlen - 1)
+        end_index = start_index + seqlen
+        sample = all_tokens[start_index:end_index]
+        calibration_samples.append(np.reshape(sample, (1, seqlen)))
+
+    final_array = np.stack(calibration_samples, axis=0)
+    return final_array
+
+
+def _find_layers_recursive(layer, prefix, found_layers):
+    """
+    Recursively search for Dense and EinsumDense layers and record them.
+    """
+    for sub_layer in layer._layers:
+        # Construct a unique name for the layer based on its hierarchy
+        layer_name = f"{prefix}.{sub_layer.name}"
+        if isinstance(sub_layer, (Dense, EinsumDense)):
+            found_layers[layer_name] = sub_layer
+
+        # Recurse into nested layers that are not the target types
+        elif hasattr(sub_layer, "_layers") and sub_layer._layers:
+            _find_layers_recursive(sub_layer, layer_name, found_layers)
+
+
+def find_layers_in_block(block):
+    """
+    A pluggable, generic function to find all Dense and EinsumDense layers
+    within any transformer block by using a recursive search.
+    """
+    found_layers = {}
+    # Start the recursive search from the block itself
+    _find_layers_recursive(block, "block", found_layers)
+    return found_layers
+
+
+def apply_gptq_layerwise(
+    model,
+    dataloader,
+    nsamples,
+    percdamp,
+    group_size,
+    symmetric,
+    act_order,
+    wbits,
+):
+    """Applies GPTQ quantization layer-by-layer to a Keras model.
+ + This function is designed to work with common transformer architectures, + like those provided by KerasNLP and KerasHub. It automatically discovers + the model's structure by first looking for the standard KerasNLP format: + a `model.backbone` attribute that contains a `transformer_layers` list. + + If a standard backbone is not found, it falls back to a heuristic for + custom models, where it assumes the first `keras.layers.Embedding` layer + is the input embedding and any subsequent container layers are the + transformer blocks to be quantized. + + The core logic operates as follows: + 1. It automatically detects the model's structure, identifying the main + embedding layer and a sequence of transformer blocks. + 2. It processes the model sequentially, one block at a time. For each + block, it uses temporary hooks to capture the input activations of + each target layer during a forward pass with the calibration data. + 3. These captured activations are used to compute the Hessian matrix for + each layer's weights. + 4. The GPTQ algorithm is then applied to each layer to find the optimal + quantized weights that minimize the error introduced. + 5. The output activations from the current block are then used as the + input for the next block, ensuring that quantization errors are + accounted for throughout the model. + + Args: + model: The Keras model instance to be quantized. The function will + attempt to automatically discover its structure. + dataloader: An iterable providing calibration data. Each item should + be a batch of token IDs suitable for the model's embedding layer. + nsamples (int): The number of samples from the dataloader to use for + calibration. + percdamp (float): The percentage of dampening to add to the Hessian + diagonal for stabilization during inverse calculation. A value of + 0.01 is common. + group_size (int): The size of the groups to use for quantization. A + value of 128 means that 128 weights will share the same scaling + factor. Use -1 for per-channel quantization. + symmetric (bool): If True, symmetric quantization is used. Otherwise, + asymmetric quantization is used. + act_order (bool): If True, reorders the weight columns based on + activation magnitude, which can improve quantization accuracy. + wbits (int): The number of bits to use for the quantized weights, + e.g., 4 for 4-bit quantization. + + Raises: + ValueError: If the function cannot automatically find an embedding + layer or any transformer-like blocks to quantize within the model. + """ + logging.info("Starting model quantization...") + embedding_layer = None + transformer_blocks = [] + if hasattr(model, "backbone"): + logging.info("Detected KerasNLP model structure.") + backbone = model.backbone + + # Add the check for the 'transformer_layers' attribute. + if hasattr(backbone, "transformer_layers"): + transformer_blocks = backbone.transformer_layers + else: + # Raise a specific error if the attribute is missing. + raise ValueError( + "The model's backbone does not have a 'transformer_layers' " + "attribute. Please ensure you are using a standard KerasNLP " + "transformer model." + ) + # Find the embedding layer by checking for common names or by type. + if hasattr(backbone, "token_embedding"): + embedding_layer = backbone.token_embedding + elif hasattr(backbone, "embedding"): + embedding_layer = backbone.embedding + else: + raise ValueError( + "Could not automatically find an embedding layer in the model." 
+ ) + + else: + logging.info("Detected custom model structure.") + for layer in model.layers: + # The first Embedding layer found is assumed to be the main one. + if isinstance(layer, Embedding) and embedding_layer is None: + embedding_layer = layer + # A "block" is a container-like layer with its own sub-layers + # that we can quantize. This is a heuristic that works for the + # test. + elif hasattr(layer, "_layers") and layer._layers: + transformer_blocks.append(layer) + + if embedding_layer is None: + raise ValueError( + "Could not automatically find an embedding layer in the model." + ) + if not transformer_blocks: + raise ValueError( + "Could not automatically find any transformer-like blocks to " + "quantize." + ) + + # Initial inputs are the outputs of the token embedding layer + inputs = [ + embedding_layer(ops.convert_to_tensor(batch, dtype="int32")) + for batch in dataloader + ] + progbar = keras_utils.Progbar(target=len(transformer_blocks)) + + for i, block in enumerate(transformer_blocks): + logging.info(f"Quantizing Block {i}") + sub_layers_map = find_layers_in_block(block) + + if not sub_layers_map: + logging.info( + f" No Dense or EinsumDense layers found in block {i}. " + "Skipping." + ) + else: + logging.info(f"Found layers: {list(sub_layers_map.keys())}") + gptq_objects = { + name: GPTQ(layer) for name, layer in sub_layers_map.items() + } + + captured_inputs = {name: [] for name in sub_layers_map.keys()} + original_calls = {} + + def create_hook(name, original_call_func): + """A factory for creating a hook to capture layer inputs.""" + + def hook(*args, **kwargs): + if args: + inp = args[0] + else: + inp = kwargs["inputs"] + captured_inputs[name].append(inp) + return original_call_func(*args, **kwargs) + + return hook + + try: + for name, layer in sub_layers_map.items(): + original_call = layer.call + original_calls[name] = original_call + layer.call = create_hook(name, original_call) + + logging.info(f"Capturing activations for block {i}...") + for j in range(nsamples): + current_input = inputs[j] + if len(current_input.shape) == 2: + current_input = ops.expand_dims(current_input, axis=0) + _ = block(current_input) + + finally: + for name, layer in sub_layers_map.items(): + if name in original_calls: + layer.call = original_calls[name] + + logging.info(f"Building Hessians for block {i}...") + for name, gptq_object in gptq_objects.items(): + layer_inputs = ops.concatenate(captured_inputs[name], axis=0) + + # Explicitly reshape the input tensor to be 2D, with the second + # dimension matching the number of input features expected by + # the layer's kernel. + # This correctly handles inputs of any dimensionality + # (e.g., 3D or 4D). 
+ num_features = gptq_object.rows + inp_reshaped = ops.reshape(layer_inputs, (-1, num_features)) + gptq_object.update_hessian_with_batch(inp_reshaped) + + quantizer = GPTQQuant( + wbits, + perchannel=True, + symmetric=symmetric, + group_size=group_size, + ) + for name, gptq_object in gptq_objects.items(): + logging.info(f"Quantizing {name}...") + gptq_object.quantizer = quantizer + gptq_object.quantize_and_correct_block( + percdamp=percdamp, group_size=group_size, actorder=act_order + ) + gptq_object.free() + + del gptq_objects, captured_inputs, original_calls + + if i < len(transformer_blocks) - 1: + logging.info(f"Generating inputs for block {i + 1}...") + next_block_inputs = [] + for j in range(nsamples): + current_input = inputs[j] + if len(current_input.shape) == 2: + current_input = ops.expand_dims(current_input, axis=0) + output = block(current_input)[0] + next_block_inputs.append(output) + inputs = next_block_inputs + progbar.update(current=i + 1) + + logging.info("Quantization process complete.") + + +def quantize_model(model, config): + """ + Top-level function to quantize a Keras model using GPTQ. + """ + logging.info("Starting GPTQ quantization process...") + + # Load ALL data needed from the generator/source in a single call. + total_samples_to_request = config.nsamples + full_dataloader = get_dataloader( + config.tokenizer, + config.seqlen, + config.dataset, + nsamples=total_samples_to_request, + ) + + # Split the materialized data. This works because full_dataloader + # is now a NumPy array, which can be sliced and reused. + calibration_dataloader = full_dataloader[: config.nsamples] + + apply_gptq_layerwise( + model, + calibration_dataloader, # Use the calibration slice + len(calibration_dataloader), # Use the actual number of samples + config.percdamp, + config.group_size, + config.symmetric, + config.act_order, + config.wbits, + ) + + return diff --git a/keras/src/quantizers/gptq_core_test.py b/keras/src/quantizers/gptq_core_test.py new file mode 100644 index 000000000000..424331c7f903 --- /dev/null +++ b/keras/src/quantizers/gptq_core_test.py @@ -0,0 +1,160 @@ +import pytest +from absl import logging + +from keras.src import layers +from keras.src import models +from keras.src.quantizers import gptq_core +from keras.src.quantizers.gptq_config import GPTQConfig + +VOCAB_SIZE = 100 + + +class MockTokenizer: + """A mock tokenizer that mimics the real API for testing.""" + + def tokenize(self, text): + return [ord(c) % VOCAB_SIZE for c in "".join(text)] + + def __call__(self, text): + return self.tokenize(text) + + +class MockEmptyBlock(layers.Layer): + """A mock block that contains no quantizable layers.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.ln = layers.LayerNormalization() + + def call(self, inputs): + return self.ln(inputs) + + +class MockTransformerBlock(layers.Layer): + """A mock transformer block with a quantizable Dense layer.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dense = layers.Dense(128) + + def call(self, inputs): + return self.dense(inputs) + + +def _get_model_with_backbone( + has_transformer_layers=True, embedding_name="embedding" +): + """Creates a mock KerasNLP-style model with a backbone.""" + + class MockBackbone(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + if has_transformer_layers: + self.transformer_layers = [MockTransformerBlock()] + setattr(self, embedding_name, layers.Embedding(VOCAB_SIZE, 128)) + + class MockModel(models.Model): + def __init__(self, 
**kwargs): + super().__init__(**kwargs) + self.backbone = MockBackbone() + + def call(self, inputs): + return self.backbone(inputs) + + model = MockModel() + model.build(input_shape=(None, 10)) + return model + + +@pytest.mark.requires_trainable_backend +class TestGPTQCore: + def test_get_dataloader_error_scenarios(self): + """Tests error cases for get_dataloader.""" + with pytest.raises(ValueError, match="Provided dataset is empty"): + gptq_core.get_dataloader( + tokenizer=MockTokenizer(), seqlen=10, dataset=[], nsamples=10 + ) + with pytest.raises( + TypeError, + match="Providing a dataset name as a " + "string is not supported. Please pass the " + "loaded dataset directly.", + ): + gptq_core.get_dataloader( + tokenizer=MockTokenizer(), + seqlen=10, + dataset="wikitext2", + nsamples=10, + ) + + def test_apply_gptq_on_multi_block_model(self): + """Tests quantization on a model with multiple blocks.""" + model = models.Sequential( + [ + layers.Embedding(VOCAB_SIZE, 128), + MockTransformerBlock(), + MockTransformerBlock(), + ] + ) + model.build(input_shape=(None, 10)) + config = GPTQConfig( + dataset=["test data"], tokenizer=MockTokenizer(), group_size=32 + ) + try: + model.quantize("gptq", config=config) + except Exception as e: + pytest.fail(f"Multi-block quantization failed unexpectedly: {e}") + + def test_apply_gptq_with_empty_block(self, caplog): + """Tests that a block with no quantizable layers is skipped + correctly.""" + caplog.set_level(logging.INFO) + model = models.Sequential( + [layers.Embedding(VOCAB_SIZE, 10), MockEmptyBlock()] + ) + model.build(input_shape=(None, 10)) + config = GPTQConfig(dataset=["test data"], tokenizer=MockTokenizer()) + model.quantize("gptq", config=config) + assert "No Dense or EinsumDense layers found" in caplog.text + + architecture_test_cases = [ + ( + models.Sequential([layers.Dense(10)]), + "Could not automatically find an embedding layer", + "no_embedding_layer", + ), + ( + models.Sequential( + [layers.Embedding(VOCAB_SIZE, 10), layers.Dense(10)] + ), + "Could not automatically find any transformer-like blocks", + "no_transformer_blocks", + ), + ( + _get_model_with_backbone(has_transformer_layers=False), + "backbone does not have a 'transformer_layers' attribute", + "backbone_no_layers", + ), + ( + _get_model_with_backbone(embedding_name="wrong_name"), + "Could not automatically find an embedding layer in the model", + "backbone_no_embedding", + ), + ] + + @pytest.mark.parametrize( + "model, match_message, test_id", + architecture_test_cases, + ids=[case[-1] for case in architecture_test_cases], + ) + def test_apply_gptq_with_unsupported_architectures( + self, model, match_message, test_id + ): + """Tests that quantize fails correctly for various unsupported + model architectures.""" + if not model.built: + model.build(input_shape=(None, 10)) + + config = GPTQConfig(dataset=["test"], tokenizer=MockTokenizer()) + with pytest.raises(ValueError, match=match_message): + model.quantize("gptq", config=config) diff --git a/keras/src/quantizers/gptq_quant.py b/keras/src/quantizers/gptq_quant.py new file mode 100644 index 000000000000..d3d76a4eb3b2 --- /dev/null +++ b/keras/src/quantizers/gptq_quant.py @@ -0,0 +1,126 @@ +from keras.src import ops + + +def dequantize(x, scale, zero, maxq): + """The core quantization function.""" + epsilon = ops.cast(1e-8, dtype=scale.dtype) + scale = ops.where(ops.equal(scale, 0), epsilon, scale) + + quantized_x = ops.divide(x, scale) + quantized_x = ops.round(quantized_x) + q = ops.add(quantized_x, zero) + q = ops.clip(q, 
0, maxq) + + dequantized_x = ops.subtract(q, zero) + return ops.multiply(scale, dequantized_x) + + +class GPTQQuant: + """Initializes the GPTQQuant state. + + Args: + shape (int, optional): This argument is currently unused. + Defaults to 1. + + Attributes: + scale (tensor, optional): The quantization scaling factor(s). This + is computed during the calibration process. Defaults to `None`. + zero (tensor, optional): The quantization zero-point(s). This is + computed during the calibration process. Defaults to `None`. + maxq (tensor, optional): The maximum integer value for the + quantized weights (e.g., 15 for 4-bit quantization). + Defaults to `None`. + wbits (int, optional): The number of bits to quantize to (e.g., 4). + Defaults to `None`. + perchannel (bool): A flag indicating whether quantization is + applied per-channel (`True`) or per-tensor (`False`). + Defaults to `False`. + symmetric (bool): A flag indicating whether symmetric (`True`) or + asymmetric (`False`) quantization is used. Defaults to `False`. + group_size (int): The size of weight groups for quantization. A + value of -1 indicates that grouping is not used. + Defaults to -1. + """ + + def __init__(self, wbits, perchannel=True, symmetric=False, group_size=-1): + self.wbits = wbits + self.maxq = ops.cast((2**wbits) - 1, "float32") + self.perchannel = perchannel + self.symmetric = symmetric + self.group_size = group_size + + # These are now determined later by `find_params` + self.scale = None + self.zero = None + + def find_params(self, x, weight=False): + """Finds quantization parameters (scale and zero) for a given tensor.""" + + if x is None: + raise ValueError("Input tensor 'x' cannot be None.") + + # For weights, we typically expect at least a 2D tensor. + if weight and len(x.shape) < 2: + raise ValueError( + f"Input weight tensor 'x' must have a rank of at least 2, " + f"but got rank {len(x.shape)}." + ) + + if ops.size(x) == 0: + raise ValueError("Input tensor 'x' cannot be empty.") + + original_shape = x.shape + + if self.perchannel: + if weight: + if self.group_size != -1: + x_reshaped = ops.reshape(x, [-1, self.group_size]) + else: + x_reshaped = ops.reshape(x, [original_shape[0], -1]) + else: # per-tensor + x_reshaped = ops.reshape(x, [1, -1]) + + # Find min/max values + xmin = ops.min(x_reshaped, axis=1) + xmax = ops.max(x_reshaped, axis=1) + + # Apply symmetric quantization logic if enabled + if self.symmetric: + xmax = ops.maximum(ops.abs(xmin), xmax) + xmin = ops.where(ops.less(xmin, 0), -xmax, xmin) + + # Ensure range is not zero to avoid division errors + tmp = ops.equal(xmin, xmax) + xmin = ops.where(tmp, xmin - 1, xmin) + xmax = ops.where(tmp, xmax + 1, xmax) + + # Calculate scale and zero-point + self.scale = (xmax - xmin) / self.maxq + if self.symmetric: + self.zero = ops.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = ops.round(-xmin / self.scale) + + # Ensure scale is non-zero + self.scale = ops.where(ops.less_equal(self.scale, 0), 1e-8, self.scale) + + if weight: + # Per-channel, non-grouped case: simple reshape is correct. 
+ if self.perchannel and self.group_size == -1: + self.scale = ops.reshape(self.scale, [-1, 1]) + self.zero = ops.reshape(self.zero, [-1, 1]) + elif not self.perchannel: + num_rows = original_shape[0] + self.scale = ops.tile( + ops.reshape(self.scale, (1, 1)), (num_rows, 1) + ) + self.zero = ops.tile( + ops.reshape(self.zero, (1, 1)), (num_rows, 1) + ) + if self.perchannel: + self.scale = ops.reshape(self.scale, [-1, 1]) + self.zero = ops.reshape(self.zero, [-1, 1]) + + def ready(self): + """Checks if the quantization parameters have been computed.""" + return self.scale is not None and self.zero is not None diff --git a/keras/src/quantizers/gptq_test.py b/keras/src/quantizers/gptq_test.py new file mode 100644 index 000000000000..568a4af9e6d4 --- /dev/null +++ b/keras/src/quantizers/gptq_test.py @@ -0,0 +1,101 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import ops +from keras.src import testing +from keras.src.quantizers.gptq import GPTQ +from keras.src.quantizers.gptq_quant import GPTQQuant + + +def _get_mock_layer(layer_type, kernel_shape, rng): + if layer_type == "Dense": + layer = layers.Dense(units=kernel_shape[1]) + layer.build(input_shape=(None, kernel_shape[0])) + elif layer_type == "EinsumDense": + output_shape = (kernel_shape[1], kernel_shape[2]) + layer = layers.EinsumDense( + equation="...h,hio->...io", output_shape=output_shape + ) + dummy_input = rng.standard_normal(size=(1, 1, kernel_shape[0])) + layer(dummy_input) + layer.kernel.assign( + rng.standard_normal(size=kernel_shape).astype("float32") + ) + else: + layer = layers.Layer() + return layer + + +@pytest.mark.requires_trainable_backend +class GPTQTest(testing.TestCase): + def test_initialization_with_dense_layer(self): + rng = np.random.default_rng(seed=42) + + mock_layer = _get_mock_layer("Dense", kernel_shape=(64, 128), rng=rng) + + gptq_instance = GPTQ(mock_layer) + self.assertEqual(gptq_instance.rows, 64) + self.assertEqual(gptq_instance.columns, 128) + self.assertEqual(gptq_instance.H.shape, (64, 64)) + + def test_initialization_with_einsumdense_3d(self): + rng = np.random.default_rng(seed=42) + mock_layer = _get_mock_layer( + "EinsumDense", kernel_shape=(64, 4, 32), rng=rng + ) + gptq_instance = GPTQ(mock_layer) + self.assertEqual(gptq_instance.rows, 64) + self.assertEqual(gptq_instance.columns, 4 * 32) + self.assertEqual(gptq_instance.H.shape, (64, 64)) + + def test_update_hessian(self): + rng = np.random.default_rng(seed=42) + mock_layer = _get_mock_layer("Dense", kernel_shape=(16, 32), rng=rng) + gptq_instance = GPTQ(mock_layer) + batch1 = rng.standard_normal(size=(8, 16)).astype("float32") + gptq_instance.update_hessian_with_batch(batch1) + self.assertEqual(gptq_instance.nsamples, 8) + H1 = np.copy(ops.convert_to_numpy(gptq_instance.H)) + batch2 = rng.standard_normal(size=(4, 16)).astype("float32") + gptq_instance.update_hessian_with_batch(batch2) + self.assertEqual(gptq_instance.nsamples, 12) + H2 = np.copy(ops.convert_to_numpy(gptq_instance.H)) + self.assertFalse(np.allclose(H1, H2)) + + def test_full_quantization_process(self): + rng = np.random.default_rng(seed=42) + mock_layer = _get_mock_layer("Dense", kernel_shape=(16, 32), rng=rng) + original_weights = np.copy(ops.convert_to_numpy(mock_layer.kernel)) + + gptq_instance = GPTQ(mock_layer) + gptq_instance.quantizer = GPTQQuant(wbits=4, symmetric=False) + calibration_data = rng.standard_normal(size=(128, 16)).astype("float32") + gptq_instance.update_hessian_with_batch(calibration_data) + 
gptq_instance.quantize_and_correct_block() + + quantized_weights = ops.convert_to_numpy(mock_layer.kernel) + self.assertFalse(np.allclose(original_weights, quantized_weights)) + + gptq_instance.free() + self.assertIsNone(gptq_instance.H) + + def test_unsupported_layer_error(self): + rng = np.random.default_rng(seed=42) + unsupported_layer = _get_mock_layer( + "Unsupported", kernel_shape=None, rng=rng + ) + with self.assertRaisesRegex(TypeError, "Unsupported layer type"): + GPTQ(unsupported_layer) + + def test_update_hessian_invalid_input(self): + rng = np.random.default_rng(seed=42) + mock_layer = _get_mock_layer("Dense", kernel_shape=(16, 32), rng=rng) + gptq_instance = GPTQ(mock_layer) + with self.assertRaisesRegex(ValueError, "cannot be None"): + gptq_instance.update_hessian_with_batch(None) + with self.assertRaisesRegex(ValueError, "cannot be empty"): + gptq_instance.update_hessian_with_batch(np.empty((0, 16))) + with self.assertRaisesRegex(ValueError, "match input features"): + bad_input = rng.standard_normal(size=(8, 99)) + gptq_instance.update_hessian_with_batch(bad_input)
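
Example usage (reviewer note, not part of the diff): a minimal end-to-end sketch of the API this change adds, mirroring the toy setup in `model_test.py` and `gptq_core_test.py`. The tiny model, the tokenizer stand-in, and the sizes below are illustrative assumptions; a real transformer model and tokenizer would be substituted in practice.

import numpy as np

from keras import layers
from keras import models
from keras.quantizers import GPTQConfig

VOCAB_SIZE, SEQLEN, EMBED_DIM = 1000, 64, 32


class TinyBlock(layers.Layer):
    """A minimal "transformer-like" container block.

    The layer discovery in gptq_core only requires a container layer that
    owns Dense/EinsumDense sub-layers.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = layers.Dense(EMBED_DIM)

    def call(self, inputs):
        return self.dense(inputs)


model = models.Sequential(
    [layers.Embedding(VOCAB_SIZE, EMBED_DIM), TinyBlock()]
)
model.build(input_shape=(None, SEQLEN))  # the model must be built first

# Calibration data: an iterable of pre-tokenized samples. Raw strings also
# work, in which case `tokenizer.tokenize(...)` is called on them.
rng = np.random.default_rng(0)
dataset = [rng.integers(0, VOCAB_SIZE, size=(1, SEQLEN)) for _ in range(8)]

# Stand-in tokenizer exposing the `.tokenize` attribute get_dataloader expects.
tokenizer = lambda text: np.array([ord(c) % VOCAB_SIZE for c in text])
tokenizer.tokenize = tokenizer

config = GPTQConfig(
    dataset=dataset,
    tokenizer=tokenizer,
    wbits=4,        # bit-width of the quantized weights
    nsamples=8,     # calibration samples drawn from `dataset`
    seqlen=SEQLEN,  # tokens per calibration sample
    group_size=32,  # weights sharing one scale/zero-point (-1 = per-channel)
)

# "gptq" is the only mode that accepts (and requires) a `config` argument.
model.quantize("gptq", config=config)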