
Commit 826af96

Adds quantization documentation
1 parent bff731e commit 826af96

4 files changed, +321 −0 lines changed

guides/md/quantization/overview.md

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
# Quantization in Keras

**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)<br>
**Date created:** 2025/10/09<br>
**Last modified:** 2025/10/09<br>
**Description:** Overview of quantization in Keras (int8, float8, int4, GPTQ).


<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/quantization/overview.ipynb) <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/quantization/overview.py)

---

## Introduction

Modern large models are often **memory- and bandwidth-bound**: most inference time is spent moving tensors between memory and compute units rather than doing math. Quantization reduces the number of bits used to represent the model's weights and (optionally) activations, which:

* Shrinks model size and VRAM/RAM footprint.
* Increases effective memory bandwidth (fewer bytes per value).
* Can improve throughput and sometimes latency on supporting hardware with low-precision kernels.

Keras provides first-class **post-training quantization (PTQ)** workflows which support pretrained models and expose a uniform API at both the model and layer level.

At a high level, Keras supports:

* Joint weight + activation PTQ in `int4`, `int8`, and `float8`.
* Weight-only PTQ via **GPTQ** (2/3/4/8-bit) to maximize compression with minimal accuracy impact, especially for large language models (LLMs).

> **Terminology**
>
> * *Scale / zero-point:* Quantization maps real values `x` to integers `q` using a scale (and optionally a zero-point). Symmetric schemes use only a scale.
> * *Per-channel vs per-tensor:* A separate scale per output channel (e.g., per hidden unit) usually preserves accuracy better than a single scale for the whole tensor.
> * *Calibration:* A short pass over sample data to estimate activation ranges (e.g., max absolute value).
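
To make these terms concrete, here is a minimal NumPy sketch of symmetric, per-output-channel AbsMax quantization and dequantization. It illustrates the general scheme only; it is not the actual Keras kernel implementation or storage layout.

```python
import numpy as np


def quantize_per_channel_int8(w):
    """Symmetric per-output-channel int8 quantization (illustration only)."""
    # One AbsMax-derived scale per output channel (columns of the kernel).
    scale = np.max(np.abs(w), axis=0, keepdims=True) / 127.0
    scale = np.where(scale == 0.0, 1.0, scale)  # guard against all-zero channels
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    return q.astype(np.float32) * scale


w = np.random.randn(10, 32).astype(np.float32)  # (input_dim, output_dim)
q, scale = quantize_per_channel_int8(w)
w_hat = dequantize(q, scale)
print("max abs reconstruction error:", float(np.max(np.abs(w - w_hat))))
```
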
---

## Quantization Modes

Keras currently focuses on the following numeric formats. Each mode can be applied selectively to layers or to the whole model via the same API.

* **`int8` (8-bit integer)**: **joint weight + activation** PTQ.

  * **How it works:** Values are linearly mapped to 8-bit integers with per-channel scales. Activations are calibrated using dynamic quantization (see note below).
  * **Why use it:** Good accuracy for many architectures; broad hardware support.
  * **What to expect:** ~4x smaller than FP32 parameters (~2x vs FP16) and lower activation bandwidth, with small accuracy loss on many tasks. Throughput gains depend on kernel availability and memory bandwidth.

* **`float8` (FP8: E4M3 / E5M2 variants)**: Low-precision floating-point useful for training and inference on FP8-capable hardware.

  * **How it works:** Values are quantized to FP8 with a dynamic scale. Fused FP8 kernels on supported devices yield speedups.
  * **Why use it:** Mixed-precision training/inference with hardware acceleration while keeping floating-point semantics (since underflow/overflow characteristics differ from int).
  * **What to expect:** Competitive speed and memory reductions where FP8 kernels are available; accuracy varies by model but is usually acceptable for most tasks.

* **`int4`**: Ultra-low-bit **weights** for aggressive compression; activations remain in higher precision (int8).

  * **How it works:** Two signed 4-bit "nibbles" are packed per int8 byte. Keras uses symmetric per-output-channel scales to dequantize efficiently inside matmul.
  * **Why use it:** Significant VRAM/storage savings for LLMs with acceptable accuracy when combined with robust per-channel scaling.
  * **What to expect:** ~8× smaller than FP32 (~4× vs FP16) for weights; throughput gains depend on kernel availability and memory bandwidth. Accuracy deltas are competitive for encoder-only architectures, while decoder-only models may show larger regressions.

* **`GPTQ` (weight-only 2/3/4/8 bits)**: *Second-order, post-training* method minimizing layer output error.

  * **How it works (brief):** For each weight block (group), GPTQ solves a local least-squares problem using a Hessian approximation built from a small calibration set, then quantizes to low bit-width. The result is a packed weight tensor plus per-group parameters (e.g., scales).
  * **Why use it:** Strong accuracy retention at very low bit-widths without retraining; ideal for rapid LLM compression.
  * **What to expect:** Large storage/VRAM savings with small perplexity/accuracy deltas on many decoder-only models when calibrated on task-relevant samples.

> **Implementation notes**
>
> * For `int4`, Keras packs signed 4-bit values (range ≈ [−8, 7]) and stores per-channel scales such as `kernel_scale`. Dequantization happens on the fly, and matmuls use 8-bit (unpacked) kernels. A minimal sketch of the packing idea is shown below.
> * Activation scaling for `int4` / `int8` / `float8` uses **AbsMax calibration** by default (range set by the maximum absolute value observed). Alternative calibration methods (e.g., percentile) may be added in future releases.
> * Per-channel scaling is the default for weights where supported, because it materially improves accuracy at negligible overhead.
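
The packing itself can be illustrated with a short NumPy sketch: two signed 4-bit values share one byte, and unpacking sign-extends them back. This is an illustrative layout only; the actual Keras packing format and kernels may differ.

```python
import numpy as np


def pack_int4(q):
    """Pack pairs of signed 4-bit values (in [-8, 7]) into int8 bytes (illustrative layout)."""
    q = np.asarray(q, dtype=np.int8)
    assert q.shape[0] % 2 == 0, "pad to an even number of rows before packing"
    bits = q.view(np.uint8)          # reinterpret two's-complement bit patterns
    low = bits[0::2] & 0x0F          # even rows go in the low nibble
    high = (bits[1::2] & 0x0F) << 4  # odd rows go in the high nibble
    return (low | high).view(np.int8)


def unpack_int4(packed):
    """Recover the signed 4-bit values from packed int8 bytes."""
    bits = packed.view(np.uint8)
    low = (bits & 0x0F).astype(np.int16)
    high = ((bits >> 4) & 0x0F).astype(np.int16)
    # Sign-extend 4-bit two's complement values back to the range [-8, 7].
    low = np.where(low >= 8, low - 16, low).astype(np.int8)
    high = np.where(high >= 8, high - 16, high).astype(np.int8)
    out = np.empty((packed.shape[0] * 2,) + packed.shape[1:], dtype=np.int8)
    out[0::2], out[1::2] = low, high
    return out


q = np.random.randint(-8, 8, size=(10, 32), dtype=np.int8)  # already-quantized int4 values
packed = pack_int4(q)                                        # half the bytes of int8 storage
assert np.array_equal(unpack_int4(packed), q)
```
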
---

## Quantizing Keras Models

Quantization is applied explicitly after layers or models are built. The API is designed to be predictable: you call quantize, the graph is rewritten, the weights are replaced, and you can immediately run inference or save the model.

Typical workflow:

1. **Build / load your FP model.** Train if needed. Ensure `build()` or a forward pass has materialized weights.
2. **(GPTQ only)** Keras may run a short calibration pass to collect activation ranges (you can pass a small, representative dataset).
3. **Invoke quantization.** Call `model.quantize("<mode>")` or `layer.quantize("<mode>")` with `"int8"`, `"int4"`, `"float8"`, or `"gptq"` (weight-only).
4. **Use or save.** Run inference, or `model.save(...)`. Quantization state (packed weights, scales, metadata) is preserved on save/load.
### Model Quantization

```python
import keras
import numpy as np

# Sample training data
x_train = keras.ops.array(np.random.rand(100, 10))
y_train = keras.ops.array(np.random.rand(100, 1))

# Build the model
model = keras.Sequential([
    keras.layers.Dense(32, activation="relu", input_shape=(10,)),
    keras.layers.Dense(1)
])

# Compile and fit the model
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(x_train, y_train, epochs=1, verbose=0)

# Quantize the model
model.quantize("int8")
```

**What this does:** Quantizes the weights of the supported layers and rewires their forward paths to use the quantized kernels and the associated quantization scales.

**Note**: Throughput gains depend on backend/hardware kernels; in cases where kernels fall back to dequantized matmul, you still get memory savings but smaller speedups.

### Layer-wise Quantization

The Keras quantization framework also lets you quantize individual layers through the same unified API, without having to quantize the entire model.

```python
from keras import layers

input_shape = (10,)
layer = layers.Dense(32, activation="relu", input_shape=input_shape)
layer.build(input_shape)

layer.quantize("int4")  # or "int8", "float8", etc.
```

**When to use layer-wise quantization**

* To keep numerically sensitive blocks (e.g., small residual paths, logits) at higher precision while quantizing large projection layers.
* To mix modes (e.g., attention projections in int4, feed-forward in int8) and measure trade-offs incrementally (see the sketch below).
* Always validate on a small eval set after each step; mixing precisions across residual connections can shift distributions.
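
As an example, here is a minimal sketch of mixing modes on a freshly built (not yet quantized) model. The 1024-unit threshold is an arbitrary illustrative cutoff, not a Keras recommendation.

```python
import keras

# Push large projection layers to int4 and keep smaller / more sensitive
# Dense layers at int8. `model` is assumed to be an unquantized keras.Model.
for layer in model.layers:
    if isinstance(layer, keras.layers.Dense):
        layer.quantize("int4" if layer.units >= 1024 else "int8")
```
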
---

## Layer & model coverage

Keras supports the following core layers in its quantization framework:

* `Dense`
* `EinsumDense`
* `Embedding` (available in KerasHub)
* `ReversibleEmbedding` (available in KerasHub)

Any composite layers that are built from the above (for example, `MultiHeadAttention`, `GroupedQueryAttention`, feed-forward blocks in Transformers) inherit quantization support by construction. This covers the majority of modern encoder-only and decoder-only Transformer architectures.

Since all KerasHub models subclass `keras.Model`, they automatically support the `model.quantize(...)` API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it, without changing your training code.
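
As a sketch of that workflow (assuming `keras_hub` is installed and you have access to the referenced preset; the preset name and file path below are illustrative):

```python
import keras_hub

# Load a pretrained causal LM preset (downloads weights on first use).
lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")

# One call quantizes all supported sublayers in place.
lm.quantize("int8")

# Use or persist the quantized model.
lm.generate("Quantization in Keras is", max_length=64)
lm.save("gemma_2b_en_int8.keras")
```
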

> **Practical guidance**
>
> * For GPTQ, use a calibration set that matches your inference domain (a few hundred to a few thousand tokens is often enough to see strong retention).
> * Measure both **VRAM** and **throughput/latency**: memory savings are immediate; speedups depend on the availability of fused low-precision kernels on your device.

guides/quantization/overview.py

Whitespace-only changes.

guides/quantization_overview.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
"""
Title: Quantization in Keras
Author: Keras Team
Date created: 2025/10/09
Last modified: 2025/10/09
Description: Overview of weight & activation quantization workflows in Keras (int8, float8, int4, GPTQ) and integration patterns.
Accelerator: None
"""
"""
11+
## Introduction
12+
13+
Modern large models are memory- and bandwidth-bound. Quantization reduces model footprint
14+
and improves inference throughput by representing weights and (optionally) activations
15+
with fewer bits than the original floating-point formats.
16+
17+
At a high level, Keras provides first-class support for multiple quantization modes:
18+
19+
* Weight-only post-training quantization (e.g., GPTQ 4-bit) for rapid compression.
20+
* Joint weight & activation post-training quantization (PTQ) in `int8` or `float8`.
21+
* Quantization-aware training (QAT) paths (where supported) to recover accuracy when
22+
naive PTQ would cause excessive degradation.
23+
24+
Two dimensions to keep in mind:
25+
26+
* What is being quantized? (weights only vs weights + activations)
27+
* When is quantization applied? (post-training vs during training via QAT)
28+
"""
29+
30+
"""
31+
## Supported quantization modes
32+
33+
Keras currently focuses on the following numeric formats:
34+
35+
* int8 (8-bit integer) — Common baseline for activation + weight PTQ/QAT.
36+
* float8 (e.g., E4M3 / E5M2 variants) — Mixed-precision friendly low-bit floating formats
37+
enabling smoother accuracy retention vs int8 for some transformer blocks.
38+
* int4 (packed) — Ultra-low-bit weight representation for large language models.
39+
Activations generally remain in a higher precision (e.g., float16 / bfloat16 / int8).
40+
* GPTQ (weight-only 4-bit) — Post-training, block-wise second-order approximated
41+
quantization minimizing output error. Ideal for rapid compression of pretrained LLMs.
42+
43+
Behind the scenes, the `int4` path packs two 4-bit nibbles into each signed int8 byte.
44+
A per-output-channel (e.g., per hidden unit) `kernel_scale` tensor rescales the packed
45+
values back to a floating domain during matmul. Activation scaling (for int8/float8 modes)
46+
uses an AbsMax (max absolute value) calibration by default, but alternative calibration
47+
strategies may be pluggable in future updates.
48+
"""
49+
50+
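
"""
To make the AbsMax idea concrete, the following NumPy snippet scales a float activation
tensor by its maximum absolute value and rounds it to int8. This is an illustration of the
general calibration scheme, not the actual Keras kernel implementation.
"""

import numpy as np


def absmax_quantize_int8(x):
    """Per-tensor AbsMax int8 quantization (illustration only)."""
    # Map the largest observed magnitude to the int8 limit (127).
    scale = float(np.max(np.abs(x))) / 127.0
    scale = scale if scale > 0.0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


activations = np.random.randn(4, 16).astype(np.float32)
q, scale = absmax_quantize_int8(activations)
dequantized = q.astype(np.float32) * scale
print("max abs error:", float(np.max(np.abs(activations - dequantized))))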
"""
51+
## Integration paths: policy-first vs imperative
52+
53+
Keras offers two complementary integration styles so you can adopt quantization with
54+
minimal friction depending on how you construct models.
55+
56+
### 1. Policy-first (declarative) approach
57+
58+
Specify a quantization policy via `dtype` (or `compute_dtype`) when instantiating layers
59+
or updating them later. Examples:
60+
61+
```python
62+
from keras import layers
63+
64+
# Declare an int4 quantization transformation from float32 pretrained weights
65+
dense_int4 = layers.Dense(4096, activation="gelu", dtype="int4_from_float32")
66+
67+
# Mixed path: start from mixed bfloat16 weights, produce int8 quantized variant
68+
dense_int8 = layers.Dense(4096, dtype="int8_from_mixed_bfloat16")
69+
```
70+
71+
Semantics of a policy string:
72+
73+
```
74+
<target_quant_dtype>_from_<source_dtype>
75+
```
76+
77+
The layer will (lazily) quantize its weights the first time it's built or when weights
78+
are assigned.
79+
80+
You can also retroactively apply a policy:
81+
82+
```python
83+
layer.dtype = "int4_from_float32" # triggers (re)quantization on next build/assign
84+
```
85+
86+
### 2. Imperative (post-construction) approach
87+
88+
Build your model/layer normally, then invoke an explicit quantization method:
89+
90+
```python
91+
model = build_large_lm() # returns a compiled or uncompiled keras.Model
92+
model.build(input_shape) # ensure weights are created
93+
model.quantize("int4") # in-place weight transformation (and bookkeeping)
94+
```
95+
96+
For an individual layer:
97+
98+
```python
99+
layer.build(input_shape)
100+
layer.quantize("gptq_int4") # or "int8", "float8", etc.
101+
```
102+
103+
This style is convenient for applying quantization to existing checkpoints without
104+
rewriting code.
105+
"""
106+
107+
"""
108+
## Layer & model coverage
109+
110+
Core dense projection layers are supported:
111+
112+
* `Dense`
113+
* `EinsumDense` (including LoRA-compatible adaptations)
114+
115+
At the model level, KerasHub provides helpers/presets (for popular LLM architectures)
116+
that expose a single `model.quantize(<mode>)` convenience method targeting all eligible
117+
submodules.
118+
119+
Additional layer types (e.g., attention, convolution) may be incrementally covered as
120+
kernel implementations mature.
121+
"""
122+
123+
"""
124+
## Under the hood: packed int4
125+
126+
For `int4` / GPTQ weight-only quantization:
127+
128+
1. Original float weights (e.g., float32 or bfloat16) are partitioned per output channel.
129+
2. A scale (and optionally zero-point / offset for asymmetric schemes) is computed per channel.
130+
3. Each weight is divided by its channel scale, clamped to the 4-bit signed range, and
131+
rounded to nearest integer.
132+
4. Two 4-bit values are packed into one int8 storage byte (low nibble + high nibble).
133+
5. During a forward matmul, packed bytes are unpacked on-the-fly (SIMD / fused kernel),
134+
multiplied by their channel scale, and accumulated in higher precision (e.g., float16).
135+
6. Activation quantization (if enabled) applies AbsMax scaling to produce an 8-bit tensor
136+
fed into integer or mixed-precision GEMM kernels.
137+
138+
Key advantages:
139+
140+
* 4x compression vs float16; often ~8x vs float32.
141+
* Reduced memory bandwidth → higher tokens/sec for LLM inference.
142+
* Channel-wise scaling preserves accuracy better than per-tensor scaling.
143+
144+
Trade-offs:
145+
146+
* Extra decode overhead; mitigated by fused kernels.
147+
* Some accuracy regression vs int8/float8, especially for very small models.
148+
"""
149+
150+
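
"""
The steps above can be illustrated end to end with plain NumPy: quantize a float kernel
per output channel to signed 4-bit values, then emulate the forward matmul by dequantizing
with the channel scales. This is an illustrative reference path only; the real Keras
implementation packs nibbles and relies on fused kernels.
"""

import numpy as np


def quantize_int4_per_channel(kernel):
    """Steps 1-3: per-output-channel AbsMax scale, then round/clamp to [-8, 7]."""
    scale = np.max(np.abs(kernel), axis=0, keepdims=True) / 7.0
    scale = np.where(scale == 0.0, 1.0, scale)
    q = np.clip(np.round(kernel / scale), -8, 7).astype(np.int8)
    return q, scale


def matmul_dequantized(x, q, scale):
    """Step 5 (emulated): dequantize on the fly and accumulate in float32."""
    return x @ (q.astype(np.float32) * scale)


kernel = np.random.randn(64, 128).astype(np.float32)
x = np.random.randn(2, 64).astype(np.float32)

q, scale = quantize_int4_per_channel(kernel)
approx = matmul_dequantized(x, q, scale)
exact = x @ kernel
print("mean abs matmul error:", float(np.mean(np.abs(exact - approx))))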
"""
151+
## Where to go next
152+
153+
* Quickstart with runnable LLM examples (coming soon)
154+
* int4 internals & advanced configuration (link forthcoming)
155+
* Exporting / deployment guide (ONNX / TFLite / Serving)
156+
* GPTQ overview & best practices
157+
* Quantization recipes & troubleshooting
158+
159+
(Links will be activated as the companion guides are published.)
160+
"""
161+
162+
"""
163+
## Summary
164+
165+
You can adopt Keras quantization either declaratively (policy-first `dtype` strings) or
166+
imperatively (`.quantize(mode)`). Start with `int8` / `float8` for balanced accuracy &
167+
performance, and explore `int4` / GPTQ for maximum compression of large language models.
168+
"""

scripts/guides_master.py

Lines changed: 4 additions & 0 deletions
@@ -123,6 +123,10 @@
             "path": "orbax_checkpoint",
             "title": "Orbax Checkpointing in Keras",
         },
+        {
+            "path": "quantization/overview",
+            "title": "Quantization in Keras",
+        },
         # {
         #     "path": "preprocessing_layers",
         #     "title": "Working with preprocessing layers",
