
Commit 89d953e

Adds int4 Quantization Support (#21435)

* int4 quantization support
* refactor packing utils into quantizers
* generalize int4 packing
* restored pytest skip conditions
* fixes 'tuple' object has no attribute 'rank' error
* fix dtype check to work across backends
* fixed torch compatibility
* fixed jax compatibility
* removes redundant self._orig_input_dim initialization
* improves readability
* W4A8
* added _int4_call stub
* Fix bug in unpack that promoted tensor to fp32
* add missing dtype assertion to quantizer test
* docstring fixes
* docstring fixes
* introduces fastpath for dense unpack
* handle negative axis for pack/unpack
* standardize docs formatting
* fix docstring format
* Reduce duplication in _get_kernel_with_merged_lora
* remove unnecessary cast ops
* removes unused var
1 parent e233825 commit 89d953e

9 files changed: +714, -21 lines

keras/api/_tf_keras/keras/quantizers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@
 from keras.src.quantizers.quantizers import (
     fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
 )
+from keras.src.quantizers.quantizers import pack_int4 as pack_int4
 from keras.src.quantizers.quantizers import (
     quantize_and_dequantize as quantize_and_dequantize,
 )
+from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4

keras/api/quantizers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@
 from keras.src.quantizers.quantizers import (
     fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
 )
+from keras.src.quantizers.quantizers import pack_int4 as pack_int4
 from keras.src.quantizers.quantizers import (
     quantize_and_dequantize as quantize_and_dequantize,
 )
+from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4
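These two helpers become part of the public `keras.quantizers` namespace. A minimal usage sketch, assuming the signatures implied by the `dense.py` changes further down in this diff (`pack_int4` returns a tuple whose first element is the packed tensor; `unpack_int4` takes the packed tensor and the original, unpacked length along the packing axis):

```python
import numpy as np
from keras import ops, quantizers

# int4 values (range [-8, 7]) stored in an int8 tensor with an odd first dim.
w_int4 = ops.convert_to_tensor(
    np.array([[-8, 7], [3, -1], [5, 2]], dtype="int8")
)

packed, *_ = quantizers.pack_int4(w_int4)      # two int4 values per int8 byte
restored = quantizers.unpack_int4(packed, 3)   # 3 = original first-dim length
```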

keras/src/dtype_policies/dtype_policy.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8")
+QUANTIZATION_MODES = ("int8", "float8", "int4")
 
 
 @keras_export(
@@ -350,7 +350,7 @@ def _get_quantized_dtype_policy_by_str(policy):
             f"Received: policy={policy}"
         )
     mode, source_name = split_name
-    if policy.startswith("int8"):
+    if policy.startswith("int8") or policy.startswith("int4"):
         return QuantizedDTypePolicy(mode, source_name)
     elif policy.startswith("float8"):
         return QuantizedFloat8DTypePolicy(mode, source_name)
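With `"int4"` registered in `QUANTIZATION_MODES`, int4 policies resolve through the same `"<mode>_from_<source>"` string convention already used for int8. A hedged sketch of what that enables (the `quantize()` call on a built `Dense` layer is the usual entry point; exact behaviour depends on the rest of this diff):

```python
import keras
from keras import dtype_policies

# String form, mirroring the existing int8 policies.
policy = dtype_policies.get("int4_from_float32")
print(policy.quantization_mode)  # "int4"

# Typical entry point: quantize an existing float layer in place.
layer = keras.layers.Dense(16)
layer.build((None, 8))
layer.quantize("int4")  # packs the kernel and swaps in the int4 policy
```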

keras/src/layers/core/dense.py

Lines changed: 215 additions & 19 deletions
@@ -2,7 +2,6 @@
 
 from keras.src import activations
 from keras.src import constraints
-from keras.src import dtype_policies
 from keras.src import initializers
 from keras.src import ops
 from keras.src import quantizers
@@ -110,9 +109,10 @@ def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
             self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode != "int8":
-            # If the layer is quantized to int8, `self._kernel` will be added
-            # in `self._int8_build`. Therefore, we skip it here.
+        if self.quantization_mode not in ("int8", "int4"):
+            # If the layer is quantized to int8 or int4, `self._kernel` will be
+            # added in `self._int8_build` or `_int4_build`. Therefore, we skip
+            # it here.
             self._kernel = self.add_weight(
                 name="kernel",
                 shape=kernel_shape,
@@ -182,9 +182,22 @@ def enable_lora(
                 "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
+        # Determine the correct input dimension for the LoRA A matrix. When
+        # the layer has been int4-quantized, `self._kernel` stores a *packed*
+        # representation whose first dimension is `ceil(input_dim/2)`. We
+        # saved the true, *unpacked* input dimension in `self._orig_input_dim`
+        # during quantization. Use it if available; otherwise fall back to the
+        # first dimension of `self.kernel`.
+        if self.quantization_mode == "int4" and hasattr(
+            self, "_orig_input_dim"
+        ):
+            input_dim_for_lora = self._orig_input_dim
+        else:
+            input_dim_for_lora = self.kernel.shape[0]
+
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
-            shape=(self.kernel.shape[0], rank),
+            shape=(input_dim_for_lora, rank),
             initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
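Why the fallback matters, as a shape sketch with illustrative numbers: for an odd `input_dim`, the packed kernel's first dimension no longer matches the float kernel, so the LoRA A matrix must be sized from `_orig_input_dim` rather than from the stored kernel.

```python
input_dim, units, rank = 5, 4, 2

packed_rows = (input_dim + 1) // 2   # 3: rows of the stored int8 kernel
lora_a_shape = (input_dim, rank)     # (5, 2), sized from `_orig_input_dim`
lora_b_shape = (rank, units)         # (2, 4)

print(packed_rows, lora_a_shape, lora_b_shape)
```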
@@ -211,7 +224,7 @@ def save_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -237,7 +250,7 @@ def load_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(self.kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
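For both int8 and int4 the per-channel `kernel_scale` is now saved and restored alongside the (packed) kernel. A hedged end-to-end sketch of what that implies, assuming `Model.quantize` accepts `"int4"` once this commit is in and that the `.keras` serialization path works as it does for int8:

```python
import keras

model = keras.Sequential(
    [keras.layers.InputLayer(shape=(8,)), keras.layers.Dense(16)]
)
model.quantize("int4")                  # packs kernels, keeps per-unit scales
model.save("dense_int4.keras")
reloaded = keras.saving.load_model("dense_int4.keras")
```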
@@ -315,6 +328,8 @@ def _check_load_own_variables(self, store):
     def quantized_build(self, kernel_shape, mode):
         if mode == "int8":
             self._int8_build(kernel_shape)
+        elif mode == "int4":
+            self._int4_build(kernel_shape)
         elif mode == "float8":
             self._float8_build()
         else:
@@ -337,6 +352,39 @@ def _int8_build(self, kernel_shape):
             trainable=False,
         )
 
+    def _int4_build(self, kernel_shape):
+        """Build variables for int4 quantization.
+
+        `kernel_shape` is the *original* float32 kernel shape
+        `(input_dim, units)`. We allocate the stored kernel with rows
+        `ceil(input_dim/2)` because two int4 values are packed into a single
+        int8 byte.
+        """
+        # Per-channel int8 quantizer for the last axis (features).
+        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
+            axis=-1,
+        )
+        input_dim, output_dim = kernel_shape
+        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
+
+        # Kernel is stored *packed*: each int8 byte contains two int4 values.
+        self._kernel = self.add_weight(
+            name="kernel",
+            shape=(packed_rows, output_dim),
+            initializer="zeros",
+            dtype="int8",
+            trainable=False,
+        )
+        # One scale per output unit (per-channel).
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units,),
+            initializer="ones",
+            trainable=False,
+        )
+        # Record original input_dim for unpacking at runtime.
+        self._orig_input_dim = input_dim
+
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 
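To make the `ceil(input_dim/2)` row count concrete, here is a standalone NumPy illustration of packing two signed int4 values into one int8 byte and recovering them. It is only a sketch of the idea; the exact bit layout used by `quantizers.pack_int4` may differ.

```python
import numpy as np

def pack_pairs(values):
    """Pack an even-length int8 array of int4 values ([-8, 7]) into bytes."""
    b = values.astype(np.uint8)
    low = b[0::2] & 0x0F           # first value of each pair -> low nibble
    high = (b[1::2] & 0x0F) << 4   # second value -> high nibble
    return (low | high).astype(np.int8)

def unpack_pairs(packed):
    """Recover the two signed int4 values stored in each int8 byte."""
    b = packed.astype(np.uint8)
    low = (b & 0x0F).astype(np.int8)
    high = ((b >> 4) & 0x0F).astype(np.int8)
    low = np.where(low > 7, low - 16, low)      # sign-extend back to [-8, 7]
    high = np.where(high > 7, high - 16, high)
    return np.stack([low, high], axis=1).reshape(-1)

vals = np.array([-8, 7, 3, -1], dtype=np.int8)
packed = pack_pairs(vals)                       # 2 bytes instead of 4
assert np.array_equal(unpack_pairs(packed), vals)
```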
@@ -383,6 +431,16 @@ def _float8_build(self):
     def _int8_call(self, inputs, training=None):
         @ops.custom_gradient
         def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function to handle the int8 quantized weights.
+
+            Automatic differentiation will not know how to handle the int8
+            quantized weights. So a custom gradient function is needed to
+            handle the int8 quantized weights.
+
+            The custom gradient function will use the dequantized kernel to
+            compute the gradient.
+            """
+
             def grad_fn(*args, upstream=None):
                 if upstream is None:
                     (upstream,) = args
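The same pattern is reused by the int4 path below. A minimal, hedged sketch of the `ops.custom_gradient` idiom involved: the decorated function returns both its output and a `grad_fn`, and the gradient is computed against a dequantized (float) view of the kernel. Names are illustrative and the forward pass is simplified relative to the layer code.

```python
from keras import ops

@ops.custom_gradient
def dequantized_matmul(inputs, kernel_q, kernel_scale):
    # Backprop goes through the float (dequantized) view of the kernel.
    float_kernel = ops.divide(ops.cast(kernel_q, "float32"), kernel_scale)

    def grad_fn(*args, upstream=None):
        if upstream is None:
            (upstream,) = args
        # Gradient w.r.t. inputs only; the quantized kernel and the scale
        # are treated as non-differentiable.
        return ops.matmul(upstream, ops.transpose(float_kernel)), None, None

    return ops.matmul(inputs, float_kernel), grad_fn
```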
@@ -415,6 +473,59 @@ def grad_fn(*args, upstream=None):
             x = self.activation(x)
         return x
 
+    def _int4_call(self, inputs, training=None):
+        """Forward pass for int4 quantized Dense layer."""
+
+        @ops.custom_gradient
+        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function for int4 quantized weights.
+
+            Automatic differentiation will not know how to handle the
+            int4 quantized weights. So a custom gradient function is needed
+            to handle the int4 quantized weights.
+
+            The custom gradient function will use the dequantized kernel to
+            compute the gradient.
+            """
+
+            unpacked_kernel = quantizers.unpack_int4(
+                kernel, self._orig_input_dim
+            )
+
+            def grad_fn(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                    kernel_scale,
+                )
+                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
+                return (inputs_grad, None, None)
+
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            x = ops.matmul(inputs, unpacked_kernel)
+            x = ops.cast(x, self.compute_dtype)
+            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            return x, grad_fn
+
+        x = matmul_with_inputs_gradient(
+            inputs,
+            ops.convert_to_tensor(self._kernel),
+            ops.convert_to_tensor(self.kernel_scale),
+        )
+
+        if self.lora_enabled:
+            lora_x = ops.matmul(inputs, self.lora_kernel_a)
+            lora_x = ops.matmul(lora_x, self.lora_kernel_b)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+
+        # Add bias and activation
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        if self.activation is not None:
+            x = self.activation(x)
+        return x
+
     def _float8_call(self, inputs, training=None):
         if self.lora_enabled:
             raise NotImplementedError(
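Numerically, this is the W4A8 path: int8 activations times unpacked int4 weights, followed by a division by the product of the two scales. A rough NumPy sketch of that arithmetic, not the layer code itself; it just mirrors the scale convention `q = round(x * scale)`, `x ≈ q / scale` used above:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 6)).astype("float32")   # activations
w = rng.normal(size=(6, 4)).astype("float32")   # float kernel

# A8: per-sample abs-max quantization of activations to int8.
x_scale = 127.0 / np.max(np.abs(x), axis=-1, keepdims=True)
x_q = np.round(x * x_scale).astype("int8")

# W4: per-output-channel abs-max quantization of weights to int4 ([-8, 7]).
w_scale = 7.0 / np.max(np.abs(w), axis=0, keepdims=True)
w_q = np.clip(np.round(w * w_scale), -8, 7).astype("int8")

# Integer matmul, then rescale by both scales, as in `_int4_call`.
y = (x_q.astype("int32") @ w_q.astype("int32")) / (x_scale * w_scale)

print(np.max(np.abs(y - x @ w)))  # small quantization error
```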
@@ -518,32 +629,117 @@ def quantize(self, mode, type_check=True):
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
-        if mode == "int8":
+            # Build variables for int8 mode
+            self.quantized_build(kernel_shape, mode)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
+        elif mode == "int4":
+            # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
+            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+                self._kernel,
+                axis=0,
+                value_range=(-8, 7),
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
+            del self._kernel
+            # Build variables using the original kernel shape; _int4_build will
+            # compute the packed shape internally.
+            self.quantized_build(kernel_shape, mode)
+            # Assign packed values.
+            self._kernel.assign(packed_kernel_value)
+            self.kernel_scale.assign(kernel_scale)
+        elif mode == "float8":
+            self.quantized_build(kernel_shape, mode)
+        else:
+            raise self._quantization_mode_error(mode)
 
-        # Set new dtype policy
+        # Set new dtype policy only for modes that already have a policy.
         if self.dtype_policy.quantization_mode is None:
+            from keras.src import dtype_policies  # local import to avoid cycle
+
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
     def _get_kernel_with_merged_lora(self):
+        """Returns the kernel with LoRA matrices merged, for serialization.
+
+        This method is called by `save_own_variables` to produce a single
+        kernel tensor that includes the adaptations from LoRA. This is useful
+        for deploying the model or for continuing training after permanently
+        applying the LoRA update.
+
+        If the layer is quantized (`int8` or `int4`), the process is:
+            1. Dequantize the base kernel to float.
+            2. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add
+               it to the dequantized kernel.
+            3. Re-quantize the merged result back to the original quantized
+               type (`int8` or packed `int4`), calculating a new scale factor.
+
+        If the layer is not quantized, this method returns the result of the
+        `kernel` property (which computes the merge in floating-point) and a
+        scale of `None`.
+
+        If LoRA is not enabled, it returns the original kernel and scale
+        without modification.
+
+        Returns:
+            A tuple `(kernel_value, kernel_scale)`:
+            `kernel_value`: The merged kernel. A quantized tensor if
+                quantization is active, otherwise a high precision tensor.
+            `kernel_scale`: The quantization scale for the merged kernel.
+                This is `None` if the layer is not quantized.
+        """
         if self.dtype_policy.quantization_mode is not None:
             kernel_value = self._kernel
             kernel_scale = self.kernel_scale
             if self.lora_enabled:
-                # Dequantize & quantize to merge lora weights into int8 kernel
-                # Note that this is a lossy compression
-                kernel_value = ops.divide(kernel_value, kernel_scale)
-                kernel_value = ops.add(
-                    kernel_value,
-                    (self.lora_alpha / self.lora_rank)
-                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                # Dequantize kernel to float
+                if self.quantization_mode == "int4":
+                    unpacked_kernel = quantizers.unpack_int4(
+                        kernel_value, self._orig_input_dim
+                    )
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, self.compute_dtype),
+                        kernel_scale,
+                    )
+                    quant_range = (-8, 7)
+                elif self.quantization_mode == "int8":
+                    float_kernel = ops.divide(
+                        ops.cast(kernel_value, self.compute_dtype), kernel_scale
+                    )
+                    quant_range = (-127, 127)
+                else:
+                    raise ValueError(
+                        "Unsupported quantization mode: "
+                        f"{self.quantization_mode}"
+                    )
+
+                # Merge LoRA weights in float domain
+                lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+                    self.lora_kernel_a, self.lora_kernel_b
                 )
-                kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                    kernel_value, axis=0, to_numpy=True
+                merged_float_kernel = ops.add(float_kernel, lora_delta)
+
+                # Requantize
+                requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+                    merged_float_kernel,
+                    axis=0,
+                    value_range=quant_range,
+                    dtype="int8",
+                    to_numpy=True,
                )
                 kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
+                # Pack if int4
+                if self.quantization_mode == "int4":
+                    kernel_value, _, _ = quantizers.pack_int4(
+                        requantized_kernel
+                    )
+                else:
+                    kernel_value = requantized_kernel
             return kernel_value, kernel_scale
         return self.kernel, None
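Putting the int4 branch together, here is a hedged, end-to-end sketch of the dequantize, merge, requantize, repack flow using the same quantizer helpers the layer calls. Variable names are illustrative, and `pack_int4`'s extra return values are discarded as in the code above.

```python
import numpy as np
from keras import ops, quantizers

input_dim, units, rank, alpha = 6, 4, 2, 1.0
kernel = np.random.randn(input_dim, units).astype("float32")
lora_a = 0.01 * np.random.randn(input_dim, rank).astype("float32")
lora_b = 0.01 * np.random.randn(rank, units).astype("float32")

# Base int4 kernel, as produced by `quantize("int4")`.
kernel_q, scale = quantizers.abs_max_quantize(
    kernel, axis=0, value_range=(-8, 7), dtype="int8", to_numpy=True
)
scale = ops.squeeze(scale, axis=0)
packed, *_ = quantizers.pack_int4(kernel_q)

# 1. Dequantize the packed kernel back to float.
float_kernel = ops.divide(
    ops.cast(quantizers.unpack_int4(packed, input_dim), "float32"), scale
)
# 2. Merge the LoRA delta in the float domain.
merged = ops.add(float_kernel, (alpha / rank) * ops.matmul(lora_a, lora_b))
# 3. Requantize to int4 with a fresh scale, then repack.
merged_q, new_scale = quantizers.abs_max_quantize(
    merged, axis=0, value_range=(-8, 7), dtype="int8", to_numpy=True
)
merged_packed, *_ = quantizers.pack_int4(merged_q)
new_scale = ops.squeeze(new_scale, axis=0)
```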
