
Commit 6b9cc18

improve formatting and add missing note
1 parent 791043f commit 6b9cc18

3 files changed (+147, -76 lines)

guides/ipynb/quantization_overview.ipynb

Lines changed: 96 additions & 47 deletions
@@ -2,20 +2,24 @@
 "cells": [
 {
 "cell_type": "markdown",
-"id": "35a7da8b",
-"metadata": {},
+"metadata": {
+"colab_type": "text"
+},
 "source": [
 "# Quantization in Keras\n",
-"Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)\n",
-"\n",
-"Date created: 2025/10/09\n",
-"\n",
-"Last modified: 2025/10/09\n",
-"\n",
-"Description: Overview of quantization in Keras (int8, float8, int4, GPTQ).\n",
-"\n",
-"Accelerator: GPU\n",
 "\n",
+"**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)<br>\n",
+"**Date created:** 2025/10/09<br>\n",
+"**Last modified:** 2025/10/09<br>\n",
+"**Description:** Overview of quantization in Keras (int8, float8, int4, GPTQ)."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"colab_type": "text"
+},
+"source": [
 "## Introduction\n",
 "\n",
 "Modern large models are often **memory- and bandwidth-bound**: most inference time is spent moving tensors between memory and compute units rather than doing math. Quantization reduces the number of bits used to represent the model's weights and (optionally) activations, which:\n",
@@ -31,12 +35,19 @@
 "* Joint weight + activation PTQ in `int4`, `int8`, and `float8`.\n",
 "* Weight-only PTQ via **GPTQ** (2/3/4/8-bit) to maximize compression with minimal accuracy impact, especially for large language models (LLMs).\n",
 "\n",
-"**Terminology**\n",
+"### Terminology\n",
+"\n",
 "* *Scale / zero-point:* Quantization maps real values `x` to integers `q` using a scale (and optionally a zero-point). Symmetric schemes use only a scale.\n",
 "* *Per-channel vs per-tensor:* A separate scale per output channel (e.g., per hidden unit) usually preserves accuracy better than a single scale for the whole tensor.\n",
-"* *Calibration:* A short pass over sample data to estimate activation ranges (e.g., max absolute value).\n",
-"\n",
-"\n",
+"* *Calibration:* A short pass over sample data to estimate activation ranges (e.g., max absolute value)."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"colab_type": "text"
+},
+"source": [
 "## Quantization Modes\n",
 "\n",
 "Keras currently focuses on the following numeric formats. Each mode can be applied selectively to layers or to the whole model via the same API.\n",
@@ -67,10 +78,18 @@
 "\n",
 "### Implementation notes\n",
 "\n",
-"* For `int4`, Keras packs signed 4-bit values (range = [-8, 7]) and stores per-channel scales such as `kernel_scale`. Dequantization happens on the fly, and matmuls use 8-bit (unpacked) kernels.\n",
-"* Activation scaling for `int4` / `int8` / `float8` uses **AbsMax calibration** by default (range set by the maximum absolute value observed). Alternative calibration methods (e.g., percentile) may be added in future releases.\n",
-"* Per-channel scaling is the default for weights where supported, because it materially improves accuracy at negligible overhead.\n",
-"\n",
+"* **Dynamic activation quantization**: In the `int4`, `int8` PTQ path, activation scales are computed on-the-fly at runtime (per tensor and per batch) using an AbsMax estimator. This avoids maintaining a separate, fixed set of activation scales from a calibration pass and adapts to varying input ranges.\n",
+"* **4-bit packing**: For `int4`, Keras packs signed 4-bit values (range = [-8, 7]) and stores per-channel scales such as `kernel_scale`. Dequantization happens on the fly, and matmuls use 8-bit (unpacked) kernels.\n",
+"* **Calibration Strategy**: Activation scaling for `int4` / `int8` / `float8` uses **AbsMax calibration** by default (range set by the maximum absolute value observed). Alternative calibration methods (e.g., percentile) may be added in future releases.\n",
+"* Per-channel scaling is the default for weights where supported, because it materially improves accuracy at negligible overhead."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"colab_type": "text"
+},
+"source": [
 "## Quantizing Keras Models\n",
 "\n",
 "Quantization is applied explicitly after layers or models are built. The API is designed to be predictable: you call quantize, the graph is rewritten, the weights are replaced, and you can immediately run inference or save the model.\n",
@@ -87,9 +106,10 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
-"id": "d9944077",
-"metadata": {},
+"execution_count": 0,
+"metadata": {
+"colab_type": "code"
+},
 "outputs": [],
 "source": [
 "import keras\n",
@@ -100,10 +120,13 @@
 "y_train = keras.ops.array(np.random.rand(100, 1))\n",
 "\n",
 "# Build the model.\n",
-"model = keras.Sequential([\n",
-"    keras.layers.Dense(32, activation=\"relu\", input_shape=(10,)),\n",
-"    keras.layers.Dense(1)\n",
-"])\n",
+"model = keras.Sequential(\n",
+"    [\n",
+"        keras.Input(shape=(10,)),\n",
+"        keras.layers.Dense(32, activation=\"relu\"),\n",
+"        keras.layers.Dense(1),\n",
+"    ]\n",
+")\n",
 "\n",
 "# Compile and fit the model.\n",
 "model.compile(optimizer=\"adam\", loss=\"mean_squared_error\")\n",
@@ -115,8 +138,9 @@
 },
 {
 "cell_type": "markdown",
-"id": "a9b1d974",
-"metadata": {},
+"metadata": {
+"colab_type": "text"
+},
 "source": [
 "**What this does:** Quantizes the weights of the supported layers, and re-wires their forward paths to be compatible with the quantized kernels and quantization scales.\n",
 "\n",
@@ -129,31 +153,40 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
-"id": "0df2aa1a",
-"metadata": {},
+"execution_count": 0,
+"metadata": {
+"colab_type": "code"
+},
 "outputs": [],
 "source": [
 "from keras import layers\n",
 "\n",
 "input_shape = (10,)\n",
-"layer = layers.Dense(32, activation=\"relu\", input_shape=input_shape)\n",
+"layer = layers.Dense(32, activation=\"relu\")\n",
 "layer.build(input_shape)\n",
 "\n",
 "layer.quantize(\"int4\") # Or \"int8\", \"float8\", etc."
 ]
 },
 {
 "cell_type": "markdown",
-"id": "249deef4",
-"metadata": {},
+"metadata": {
+"colab_type": "text"
+},
 "source": [
-"**When to use layer-wise quantization**\n",
+"### When to use layer-wise quantization\n",
 "\n",
 "* To keep numerically sensitive blocks (e.g., small residual paths, logits) at higher precision while quantizing large projection layers.\n",
 "* To mix modes (e.g., attention projections in int4, feed-forward in int8) and measure trade-offs incrementally.\n",
-"* Always validate on a small eval set after each step; mixing precisions across residual connections can shift distributions.\n",
-"\n",
+"* Always validate on a small eval set after each step; mixing precisions across residual connections can shift distributions."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"colab_type": "text"
+},
+"source": [
 "## Layer & model coverage\n",
 "\n",
 "Keras supports the following core layers in its quantization framework:\n",
@@ -165,26 +198,42 @@
 "\n",
 "Any composite layers that are built from the above (for example, `MultiHeadAttention`, `GroupedQueryAttention`, feed-forward blocks in Transformers) inherit quantization support by construction. This covers the majority of modern encoder-only and decoder-only Transformer architectures.\n",
 "\n",
-"Since all KerasHub models subclass `keras.Model`, they automatically support the `model.quantize(...)` API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it—without changing your training code.\n",
+"Since all KerasHub models subclass `keras.Model`, they automatically support the `model.quantize(...)` API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it\u2014without changing your training code.\n",
 "\n",
 "## Practical guidance\n",
 "\n",
 "* For GPTQ, use a calibration set that matches your inference domain (a few hundred to a few thousand tokens is often enough to see strong retention).\n",
 "* Measure both **VRAM** and **throughput/latency**: memory savings are immediate; speedups depend on the availability of fused low-precision kernels on your device."
 ]
-},
-{
-"cell_type": "markdown",
-"id": "cce23bb3",
-"metadata": {},
-"source": []
 }
 ],
 "metadata": {
+"accelerator": "GPU",
+"colab": {
+"collapsed_sections": [],
+"name": "quantization_overview",
+"private_outputs": false,
+"provenance": [],
+"toc_visible": true
+},
+"kernelspec": {
+"display_name": "Python 3",
+"language": "python",
+"name": "python3"
+},
 "language_info": {
-"name": "python"
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.7.0"
 }
 },
 "nbformat": 4,
-"nbformat_minor": 5
-}
+"nbformat_minor": 0
+}
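
A quick aside on the "Implementation notes" bullets added in this commit (dynamic AbsMax activation scaling and signed 4-bit weight packing): the NumPy sketch below illustrates the arithmetic they describe. It is not the Keras kernel code; the function names, the per-tensor scale choice, and the even-length assumption in the packer are illustrative assumptions only.

```python
import numpy as np


def absmax_quantize_int8(x):
    # Dynamic AbsMax: the scale is recomputed from the tensor itself (per batch),
    # so no stored calibration statistics are needed.
    scale = np.max(np.abs(x)) / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale  # approximate reconstruction: q * scale


def pack_int4(q):
    # Pack signed 4-bit values (range [-8, 7]) two per byte, as described for
    # int4 weight storage. Assumes an even number of elements.
    nibbles = (q.astype(np.int8) & 0x0F).astype(np.uint8)  # two's-complement nibbles
    return nibbles[0::2] | (nibbles[1::2] << 4)


x = np.random.randn(4, 8).astype("float32")
q, scale = absmax_quantize_int8(x)
print("max abs reconstruction error:", float(np.max(np.abs(q * scale - x))))

w4 = np.random.randint(-8, 8, size=16).astype(np.int8)  # toy int4-range weights
print("packed", w4.size, "int4 values into", pack_int4(w4).size, "bytes")
```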

guides/md/quantization_overview.md

Lines changed: 28 additions & 17 deletions
@@ -5,10 +5,12 @@
 **Last modified:** 2025/10/09<br>
 **Description:** Overview of quantization in Keras (int8, float8, int4, GPTQ).

+
 <img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/quantization_overview.ipynb) <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/quantization_overview.py)

----

+
+---
 ## Introduction

 Modern large models are often **memory- and bandwidth-bound**: most inference time is spent moving tensors between memory and compute units rather than doing math. Quantization reduces the number of bits used to represent the model's weights and (optionally) activations, which:
@@ -24,14 +26,13 @@ At a high level, Keras supports:
 * Joint weight + activation PTQ in `int4`, `int8`, and `float8`.
 * Weight-only PTQ via **GPTQ** (2/3/4/8-bit) to maximize compression with minimal accuracy impact, especially for large language models (LLMs).

-> **Terminology**
->
-> * *Scale / zero-point:* Quantization maps real values `x` to integers `q` using a scale (and optionally a zero-point). Symmetric schemes use only a scale.
-> * *Per-channel vs per-tensor:* A separate scale per output channel (e.g., per hidden unit) usually preserves accuracy better than a single scale for the whole tensor.
-> * *Calibration:* A short pass over sample data to estimate activation ranges (e.g., max absolute value).
+### Terminology

----
+* *Scale / zero-point:* Quantization maps real values `x` to integers `q` using a scale (and optionally a zero-point). Symmetric schemes use only a scale.
+* *Per-channel vs per-tensor:* A separate scale per output channel (e.g., per hidden unit) usually preserves accuracy better than a single scale for the whole tensor.
+* *Calibration:* A short pass over sample data to estimate activation ranges (e.g., max absolute value).

+---
 ## Quantization Modes

 Keras currently focuses on the following numeric formats. Each mode can be applied selectively to layers or to the whole model via the same API.
@@ -62,12 +63,12 @@ Keras currently focuses on the following numeric formats. Each mode can be appli

 ### Implementation notes

-* For `int4`, Keras packs signed 4-bit values (range = [−8, 7]) and stores per-channel scales such as `kernel_scale`. Dequantization happens on the fly, and matmuls use 8-bit (unpacked) kernels.
-* Activation scaling for `int4` / `int8` / `float8` uses **AbsMax calibration** by default (range set by the maximum absolute value observed). Alternative calibration methods (e.g., percentile) may be added in future releases.
+* **Dynamic activation quantization**: In the `int4`, `int8` PTQ path, activation scales are computed on-the-fly at runtime (per tensor and per batch) using an AbsMax estimator. This avoids maintaining a separate, fixed set of activation scales from a calibration pass and adapts to varying input ranges.
+* **4-bit packing**: For `int4`, Keras packs signed 4-bit values (range = [-8, 7]) and stores per-channel scales such as `kernel_scale`. Dequantization happens on the fly, and matmuls use 8-bit (unpacked) kernels.
+* **Calibration Strategy**: Activation scaling for `int4` / `int8` / `float8` uses **AbsMax calibration** by default (range set by the maximum absolute value observed). Alternative calibration methods (e.g., percentile) may be added in future releases.
 * Per-channel scaling is the default for weights where supported, because it materially improves accuracy at negligible overhead.

 ---
-
 ## Quantizing Keras Models

 Quantization is applied explicitly after layers or models are built. The API is designed to be predictable: you call quantize, the graph is rewritten, the weights are replaced, and you can immediately run inference or save the model.
@@ -81,6 +82,7 @@ Typical workflow:

 ### Model Quantization

+
 ```python
 import keras
 import numpy as np
@@ -90,10 +92,13 @@ x_train = keras.ops.array(np.random.rand(100, 10))
 y_train = keras.ops.array(np.random.rand(100, 1))

 # Build the model.
-model = keras.Sequential([
-    keras.layers.Dense(32, activation="relu", input_shape=(10,)),
-    keras.layers.Dense(1)
-])
+model = keras.Sequential(
+    [
+        keras.Input(shape=(10,)),
+        keras.layers.Dense(32, activation="relu"),
+        keras.layers.Dense(1),
+    ]
+)

 # Compile and fit the model.
 model.compile(optimizer="adam", loss="mean_squared_error")
@@ -103,6 +108,13 @@ model.fit(x_train, y_train, epochs=1, verbose=0)
 model.quantize("int8")
 ```

+<div class="k-default-codeblock">
+```
+/Users/jyotindersingh/miniconda3/envs/keras-io-env-3.10/lib/python3.10/site-packages/keras/src/models/model.py:455: UserWarning: Layer InputLayer does not have a `quantize` method implemented.
+warnings.warn(str(e))
+```
+</div>
+
 **What this does:** Quantizes the weights of the supported layers, and re-wires their forward paths to be compatible with the quantized kernels and quantization scales.

 **Note**: Throughput gains depend on backend/hardware kernels; in cases where kernels fall back to dequantized matmul, you still get memory savings but smaller speedups.
@@ -111,11 +123,12 @@ model.quantize("int8")

 The Keras quantization framework allows you to quantize each layer separately, without having to quantize the entire model using the same unified API.

+
 ```python
 from keras import layers

 input_shape = (10,)
-layer = layers.Dense(32, activation="relu", input_shape=input_shape)
+layer = layers.Dense(32, activation="relu")
 layer.build(input_shape)

 layer.quantize("int4") # Or "int8", "float8", etc.
@@ -128,7 +141,6 @@ layer.quantize("int4") # Or "int8", "float8", etc.
 * Always validate on a small eval set after each step; mixing precisions across residual connections can shift distributions.

 ---
-
 ## Layer & model coverage

 Keras supports the following core layers in its quantization framework:
@@ -143,7 +155,6 @@ Any composite layers that are built from the above (for example, `MultiHeadAtten
 Since all KerasHub models subclass `keras.Model`, they automatically support the `model.quantize(...)` API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it—without changing your training code.

 ---
-
 ## Practical guidance

 * For GPTQ, use a calibration set that matches your inference domain (a few hundred to a few thousand tokens is often enough to see strong retention).
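
The "When to use layer-wise quantization" notes above suggest mixing modes (e.g., attention projections in int4, feed-forward in int8) while keeping numerically sensitive blocks such as logits at higher precision. Below is a minimal sketch of what that could look like with the per-layer `quantize()` call shown in this guide; the toy model, the layer names, and the name-based selection rule are assumptions for illustration, not part of the commit.

```python
import keras

# Toy stand-in for a larger network; the names only drive the selection below.
model = keras.Sequential(
    [
        keras.Input(shape=(10,)),
        keras.layers.Dense(64, activation="relu", name="proj_1"),
        keras.layers.Dense(64, activation="relu", name="proj_2"),
        keras.layers.Dense(1, name="logits"),
    ]
)

# Quantize the large projection layers aggressively; leave "logits" in full precision.
for layer in model.layers:
    if layer.name.startswith("proj_"):
        layer.quantize("int4")  # or "int8" for a more conservative trade-off

# Validate on a small eval set after each step, as the guide recommends.
```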
