Commit df45a16

Merge pull request #2419 from AI-Hypercomputer:mohit/quant_doc
PiperOrigin-RevId: 825272158
2 parents d352bc9 + 7189d42

1 file changed: docs/explanations/quantization.md (76 additions, 30 deletions)

@@ -1,5 +1,5 @@
<!--
Copyright 2024-2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

@@ -16,55 +16,78 @@

# Quantization

Quantization in deep learning is the process of reducing the precision of the numbers used to represent a model's weights and/or activations. Instead of using higher-precision floating-point formats like 32-bit floats (`float32`) or 16-bit brain floats (`bfloat16`), quantization maps these values to lower-precision numerical formats, most commonly 8-bit integers (`int8`) or floats (`fp8`).
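
To make this concrete, here is a minimal, illustrative sketch (not MaxText's implementation) of symmetric `int8` quantization with an absolute-maximum ("absmax") scale, the same idea behind the `absmax` calibration method mentioned later in this document:

```python
import jax.numpy as jnp

def absmax_quantize_int8(x):
    # One scale for the whole tensor: the largest magnitude maps to 127.
    scale = jnp.max(jnp.abs(x)) / 127.0
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize(q, scale):
    # Recover an approximation of the original high-precision values.
    return q.astype(jnp.float32) * scale

x = jnp.array([0.02, -1.5, 0.7, 3.2], dtype=jnp.bfloat16)
q, scale = absmax_quantize_int8(x)
x_hat = dequantize(q, scale)  # close to x, but rounded to multiples of the scale
```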


MaxText supports quantization via both the [AQT](https://github.com/google/aqt) and [Qwix](https://github.com/google/qwix) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT).

## Why use quantization?

The drive to use lower-precision formats like `int8` or `fp8` stems from significant performance advantages:

**Faster computation**: Hardware accelerators like TPUs and GPUs often have specialized units for low-precision arithmetic. Operations on lower-precision data like `int8` or `fp8` can be significantly faster than on `bfloat16` or `float32`. For example, matrix multiplications in these formats can often be 2x or more faster on hardware with native low-precision tensor cores.

**Reduced memory footprint**: Storing weights and activations in `int8` or `fp8` requires half as much memory as `bfloat16`. This reduces:
- **HBM usage**: Less memory is needed on the accelerator itself.
- **Communication costs**: Less data needs to be transferred between memory and compute units, or across devices in distributed training, which makes these transfers faster and uses less bandwidth.
- **Power consumption**: Lower-precision operations and reduced memory access use less energy, which is crucial for deploying models on edge devices and for sustainable AI.

The primary trade-off with quantization is between model accuracy and computational performance:

* **Reduced dynamic range and precision**: Lower-precision formats like `int8` or `fp8` can represent a much smaller range of values, and with less precision, than `bfloat16`. This can be problematic for models with wide distributions of weights or activations, potentially clipping large values or losing fine-grained detail (see the sketch after this list).
* **Impact on gradients**: Gradients during backpropagation can have very different, often wider, distributions than weights or activations, making them more sensitive to quantization errors.
* **Convergence issues**: The approximations introduced by quantization can sometimes hinder the model's ability to converge during training.
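
As a toy illustration of the dynamic-range problem (not taken from MaxText): a single outlier inflates the absmax scale, so the remaining small values round away to zero.

```python
import jax.numpy as jnp

x = jnp.array([0.01, 0.02, -0.03, 8.0])    # one large outlier
scale = jnp.max(jnp.abs(x)) / 127.0         # ~0.063, dominated by the outlier
x_hat = jnp.clip(jnp.round(x / scale), -127, 127) * scale
# x_hat ≈ [0.0, 0.0, 0.0, 8.0]: the small values are lost,
# which is why wide weight/activation distributions are hard to quantize.
```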

To overcome these challenges, libraries like Google's Accurate Quantized Training (AQT) and its successor Qwix (used in MaxText) employ a suite of advanced techniques. These methods ensure that models can be trained with low-precision arithmetic without significant loss in accuracy and with stable convergence.

## How Quantized Training (QT) works with Qwix

Quantized Training (QT) incorporates the effects of quantization into the training loop. This allows the model to learn and adapt to the reduced precision of quantized weights and activations.

Here's how it works:

1. **Forward Pass**: During the forward pass, high-precision weights and activations are converted to a lower-precision format. This step simulates the information loss that occurs during quantization. The model then performs its computations using these lower-precision representations before they are converted back to a higher precision for the rest of the network. This process forces the model to become robust to the noise and reduced range of quantized values.

2. **Backward Pass**: Standard backpropagation cannot flow through the non-differentiable quantization operations (like rounding). To solve this, QT uses the **Straight-Through Estimator (STE)**, sketched below. The STE essentially "ignores" the non-differentiable quantization step during the backward pass, passing the gradients through as if the operation were an identity function. This allows the high-precision weights to be updated based on the loss, enabling the model to learn effectively.

By integrating the quantization simulation directly into training, the model learns to minimize the impact of precision loss, resulting in a more accurate quantized model.
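
The following is a minimal JAX sketch of these two ideas, not the Qwix implementation: the forward pass uses fake-quantized values, while `jax.lax.stop_gradient` makes the backward pass treat quantization as the identity function (the straight-through estimator).

```python
import jax
import jax.numpy as jnp

def fake_quant_int8(x):
    # Forward-pass simulation: quantize to int8 steps, then dequantize.
    scale = jnp.max(jnp.abs(x)) / 127.0
    return jnp.clip(jnp.round(x / scale), -127, 127) * scale

def fake_quant_ste(x):
    # Straight-through estimator: the value is the fake-quantized one,
    # but the gradient of (fake_quant(x) - x) is blocked, so d(out)/dx = 1.
    return x + jax.lax.stop_gradient(fake_quant_int8(x) - x)

w = jnp.array([0.5, -1.2, 2.0])
loss_grad = jax.grad(lambda w: jnp.sum(fake_quant_ste(w) ** 2))(w)
# Gradients flow back to the high-precision weights as if quantization were the identity.
```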

## Using Quantization in MaxText

You can enable quantization in MaxText by setting flags in your configuration file (e.g., `base.yml`) or via the command line. MaxText supports two quantization libraries: Qwix (recommended) and AQT.

### Configuration Flags

The primary flags to control quantization are:

* `use_qwix_quantization`: A boolean flag.
  * Set to `True` to enable quantization using the Qwix library.
  * Set to `False` (or omit) to use the AQT library if `quantization` is set.
* `quantization`: A string that specifies the type of quantization to apply. The accepted values depend on whether you are using Qwix or AQT.
* `quantization_calibration_method`: The calibration method for weights and activations (e.g., `"absmax"`). This is mainly for Qwix.

### Qwix Quantization (Recommended)

To use Qwix, you must set `use_qwix_quantization=True`. Qwix is a powerful and non-intrusive library for Quantized Training.

#### `quantization` values for Qwix

Common options for the `quantization` flag when using Qwix include:

* `"int8"`: 8-bit integer quantization.
* `"fp8"`: 8-bit floating-point quantization.
* `"fp8_full"`: FP8 quantization with static scaling.
* `"fp8_gpu"`: FP8 for NVIDIA GPUs.
* `"fp8_nanoo"`: FP8 for AMD MI300/MI325 GPUs.

#### Example command for Qwix

Here is an example of how to run a training job with int8 quantization enabled via Qwix:

```bash
python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs://<my-bucket> dataset_type=synthetic use_qwix_quantization=true quantization='int8'
```

#### The Qwix Interception API

MaxText integrates Qwix using its powerful and non-intrusive Interception API. This approach allows you to enable Quantized Training for your models without modifying the original model source code. You don't need to manually replace `nn.Dense` with `QuantizedDense` or other quantized layer types.

@@ -96,3 +119,26 @@

This rule is then used within a `QtProvider` to quantize the model automatically:

```python
model = qwix.quantize_model(model, qwix.QtProvider(rule))
```
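
For context, `rule` is defined earlier in the document, outside the lines changed by this commit. Purely as an illustrative sketch, and with the caveat that the `QtRule` argument names below are assumptions based on the Qwix documentation rather than MaxText's actual configuration, such a rule might look like:

```python
import jax.numpy as jnp
import qwix

# Hypothetical rule: quantize weights and activations of every matching module to int8.
# Argument names follow the Qwix docs and may differ from what MaxText actually uses.
rule = qwix.QtRule(
    module_path=".*",        # regex matched against module paths in the model
    weight_qtype=jnp.int8,   # quantized dtype for weights
    act_qtype=jnp.int8,      # quantized dtype for activations
)
```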

### AQT Quantization

If `use_qwix_quantization` is `False` or not set, you can still apply quantization using the AQT library by setting the `quantization` flag. You can read more about AQT on this [Google Cloud blog](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e).

#### `quantization` values for AQT

When using AQT, you can pass one of the following values to the `quantization` flag:

- `'int8'` for dynamic-range quantization using 8 bits
- `'int8w'` for weight-only quantization using 8 bits
- `'int4w'` for weight-only quantization using 4 bits
- `'intmp'` for mixed-precision, weight-only quantization based on a config file
- `'fp8'` for 8-bit floating-point GeMMs on NVIDIA GPUs

#### Example command for AQT

```bash
python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs://<my-bucket> dataset_type=synthetic use_qwix_quantization=false quantization='int8'
```

Note that `use_qwix_quantization` is set to `false` here, so the AQT path is used.

For further reading, please refer to the [Qwix Read the Docs website](https://qwix.readthedocs.io/en/latest/get_started.html).
