Commit b12b2af

Adds 4-bit packing support to GPTQ Quantization (#21686)
* Updates int4 packing logic to handle both int8 and uint8
* Adds GPTQ int4 packing support to EinsumDense
* Adds tests
* Cleanup
* Fixes bugs
* `kernel` property should return the unpacked kernel
* Fixes dtype assertion for torch
* Addresses reviews
1 parent 49797f2 commit b12b2af
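
For context on the packing scheme this commit introduces: with 4-bit weights, two quantized values are stored per uint8 byte, which halves the size of the packed kernel buffer. Below is a minimal illustrative sketch of that idea in NumPy. It is not the Keras implementation; the actual `quantizers.unpack_int4` used in the diff additionally handles signed int8 inputs, arbitrary axes, and framework tensors, and the low-nibble-first ordering here is an assumption for illustration only.

    import numpy as np

    def pack_uint4(values):
        # Pack unsigned 4-bit values (0..15) two per byte, low nibble first.
        values = np.asarray(values, dtype=np.uint8)
        if values.size % 2:  # pad odd-length input with a zero nibble
            values = np.concatenate([values, np.zeros(1, dtype=np.uint8)])
        low, high = values[0::2], values[1::2]
        return (low | (high << 4)).astype(np.uint8)

    def unpack_uint4(packed, orig_len):
        # Recover the original 4-bit values from the packed bytes.
        packed = np.asarray(packed, dtype=np.uint8)
        return np.stack([packed & 0x0F, packed >> 4], axis=1).reshape(-1)[:orig_len]

    vals = np.array([1, 15, 7, 3, 9], dtype=np.uint8)  # 5 values pack into 3 bytes
    assert np.array_equal(unpack_uint4(pack_uint4(vals), len(vals)), vals)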

9 files changed: +405 −192 lines changed


keras/src/layers/core/dense.py

Lines changed: 71 additions & 66 deletions
@@ -9,11 +9,8 @@
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
-from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
-from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
-from keras.src.quantizers.gptq_config import GPTQConfig
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
 
 
@@ -143,22 +140,47 @@ def build(self, input_shape):
 
     @property
     def kernel(self):
+        from keras.src.quantizers import gptq_core
+
         if not self.built:
             raise AttributeError(
                 "You must build the layer before accessing `kernel`."
             )
-        if (
-            getattr(self, "is_gptq_calibrated", False)
-            and self.quantization_mode == "gptq"
-        ):
-            return self.quantized_kernel
-        kernel = self._kernel
-        if self.quantization_mode == "int4":
-            kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
+
+        mode = self.quantization_mode
+        is_gptq = mode == "gptq"
+        is_int4 = mode == "int4"
+        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_bits = (
+            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
+        )
+
+        # Decide the source tensor first (packed vs already-quantized vs plain
+        # kernel)
+        if is_gptq and calibrated and gptq_bits != 4:
+            # calibrated GPTQ, not 4-bit, no unpacking needed
+            kernel = self.quantized_kernel
+        else:
+            # Start with the stored kernel
+            kernel = getattr(self, "_kernel", None)
+
+        # Handle int4 unpacking cases in one place
+        if is_int4:
+            kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
+        elif is_gptq and calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+
+        # Apply LoRA once at the end.
         if self.lora_enabled:
-            return kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
+            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
                 self.lora_kernel_a, self.lora_kernel_b
             )
+
         return kernel
 
     def call(self, inputs, training=None):
@@ -414,23 +436,33 @@ def _int8_build(self, kernel_shape):
         )
 
     def _gptq_build(self, kernel_shape, config):
+        from keras.src.quantizers import gptq_core
+
         # Ensures the forward pass uses the original high-precision kernel
         # until calibration has been performed.
         self.is_gptq_calibrated = False
         self.kernel_shape = kernel_shape
+
+        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)
+        # For 4-bit weights, we pack two values per byte.
+        units = (
+            (kernel_shape[1] + 1) // 2 if weight_bits == 4 else kernel_shape[1]
+        )
+
         self.quantized_kernel = self.add_weight(
             name="kernel",
-            shape=(kernel_shape[1], kernel_shape[0]),
+            shape=(units, kernel_shape[0]),
             initializer="zeros",
             dtype="uint8",
             trainable=False,
         )
 
-        group_size = self._get_gptq_group_size(config)
-        if group_size == -1:
-            n_groups = 1
-        else:
-            n_groups = math.ceil(self.kernel_shape[0] / group_size)
+        group_size = gptq_core.get_group_size_for_layer(self, config)
+        n_groups = (
+            1
+            if group_size == -1
+            else math.ceil(self.kernel_shape[0] / group_size)
+        )
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
             shape=(self.units, n_groups),
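
As a quick check of the packing arithmetic in `_gptq_build` above (and mirrored by the new test further down): the quantized kernel is stored transposed, and for 4-bit weights its units axis is packed two values per byte. A hypothetical shape computation, not taken from the Keras sources:

    # kernel_shape == (input_dim, units); numbers match the new packing test below
    input_dim, units = 8, 16
    weight_bits = 4
    packed_units = (units + 1) // 2 if weight_bits == 4 else units  # two nibbles per byte
    quantized_kernel_shape = (packed_units, input_dim)              # stored transposed
    assert quantized_kernel_shape == (8, 8)  # 64 uint8 entries vs. 128 float entries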
@@ -453,18 +485,31 @@ def _gptq_build(self, kernel_shape, config):
         )
 
     def _gptq_call(self, inputs, training=False):
+        from keras.src.quantizers import gptq_core
+
         if not self.is_gptq_calibrated:
             W = self._kernel
         else:
+            should_unpack = (
+                gptq_core.get_weight_bits_for_layer(self, config=None) == 4
+            )
             W = (
-                ops.transpose(
-                    dequantize_with_sz_map(
-                        self.quantized_kernel,
-                        self.kernel_scale,
-                        self.kernel_zero,
-                        self.g_idx,
-                    )
-                ),
+                quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+                if should_unpack
+                else self.quantized_kernel
+            )
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
             )
 
         y = ops.matmul(inputs, W)
@@ -875,43 +920,3 @@ def _get_kernel_with_merged_lora(self):
         else:
             kernel_value = requantized_kernel
         return kernel_value, kernel_scale
-
-    def _get_gptq_group_size(self, config):
-        """Determine the group size for GPTQ quantization.
-
-        The group size can be specified either through the `config` argument
-        or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.
-
-        The config argument is usually available when quantizing the layer
-        via the `quantize` method. If the layer was deserialized from a
-        saved model, the group size should be specified in the `dtype_policy`.
-
-        Args:
-            config: An optional configuration object that may contain the
-                `group_size` attribute.
-        Returns:
-            int. The determined group size for GPTQ quantization.
-        Raises:
-            ValueError: If the group size is not specified in either the
-                `config` or the `dtype_policy`.
-        """
-        if config and isinstance(config, GPTQConfig):
-            return config.group_size
-        elif isinstance(self.dtype_policy, GPTQDTypePolicy):
-            return self.dtype_policy.group_size
-        elif isinstance(self.dtype_policy, DTypePolicyMap):
-            policy = self.dtype_policy[self.path]
-            if not isinstance(policy, GPTQDTypePolicy):
-                # This should never happen based on how we set the
-                # quantization mode, but we check just in case.
-                raise ValueError(
-                    "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
-                    f"Got: {type(policy)}"
-                )
-            return policy.group_size
-        else:
-            raise ValueError(
-                "For GPTQ quantization, the group_size must be specified"
-                "either through a `dtype_policy` of type "
-                "`GPTQDTypePolicy` or the `config` argument."
-            )

keras/src/layers/core/dense_test.py

Lines changed: 34 additions & 0 deletions
@@ -898,3 +898,37 @@ def test_legacy_load_own_variables(self):
         self.assertAllClose(layer.kernel_amax_history, float8_store["5"])
         self.assertAllClose(layer.outputs_grad_scale, float8_store["6"])
         self.assertAllClose(layer.outputs_grad_amax_history, float8_store["7"])
+
+    def test_int4_gptq_kernel_returns_unpacked_form(self):
+        """Test that the `kernel` property returns the unpacked int4 GPTQ
+        kernel."""
+        layer = layers.Dense(units=2)
+        layer.build((None, 2))
+        layer.quantize(
+            "gptq",
+            config=GPTQConfig(
+                dataset=None, tokenizer=None, weight_bits=4, group_size=8
+            ),
+        )
+        layer.is_gptq_calibrated = True  # Bypass calibration check
+        packed_kernel = layer.quantized_kernel
+        self.assertAllClose(
+            layer.kernel, quantizers.unpack_int4(packed_kernel, 2)
+        )
+
+    def test_gptq_kernel_packing(self):
+        """Validates that 4-bit GPTQ packing reduces the kernel size."""
+        layer = layers.Dense(units=16, use_bias=False)
+        layer.build((None, 8))
+
+        original_kernel_params = ops.prod(layer._kernel.shape)
+
+        layer.quantize(
+            "gptq",
+            config=GPTQConfig(
+                dataset=None, tokenizer=None, weight_bits=4, group_size=8
+            ),
+        )
+
+        quantized_kernel_params = ops.prod(layer.quantized_kernel.shape)
+        self.assertEqual(quantized_kernel_params, original_kernel_params // 2)
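
For completeness, a minimal usage sketch distilled from the two tests above. The `keras.quantizers.GPTQConfig` import path and the calibration bypass are assumptions for illustration; in practice GPTQ calibration is driven by a real dataset and tokenizer rather than `dataset=None`.

    from keras import layers
    from keras.quantizers import GPTQConfig  # assumed public export path

    layer = layers.Dense(units=16, use_bias=False)
    layer.build((None, 8))
    layer.quantize(
        "gptq",
        config=GPTQConfig(
            dataset=None, tokenizer=None, weight_bits=4, group_size=8
        ),
    )

    # The packed uint8 buffer holds two 4-bit weights per byte, so it has half
    # as many elements as the original float kernel.
    print(layer.quantized_kernel.shape)

    layer.is_gptq_calibrated = True   # bypass calibration, as the tests do
    print(layer.kernel.shape)         # the `kernel` property returns the unpacked form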

0 commit comments
