Commit 05a95b7

Merge branch 'main' into docs-reorg
2 parents e4067bf + df45a16 commit 05a95b7

6 files changed: +200 -131 lines changed

.github/workflows/run_tests_against_package.yml

Lines changed: 5 additions & 1 deletion
@@ -31,6 +31,10 @@ on:
       pytest_marker:
         required: true
         type: string
+      pytest_addopts:
+        required: false
+        type: string
+        default: ''
       is_scheduled_run:
         required: true
         type: string
@@ -107,4 +111,4 @@ jobs:
           fi
           # TODO: Fix the skipped tests and remove the deselect flags
           [ "${{ inputs.total_workers }}" -gt 1 ] && .venv/bin/python3 -m pip install --quiet pytest-split && SPLIT_ARGS="--splits ${{ inputs.total_workers }} --group ${{ inputs.worker_group }}" || SPLIT_ARGS=""
-          .venv/bin/python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0 --deselect "tests/aot_hlo_identical_test.py::AotHloIdenticalTest::test_default_hlo_match" --deselect "tests/tokenizer_test.py::TokenizerTest::test_detokenize" $SPLIT_ARGS
+          .venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0 --deselect "tests/aot_hlo_identical_test.py::AotHloIdenticalTest::test_default_hlo_match" --deselect "tests/tokenizer_test.py::TokenizerTest::test_detokenize" $SPLIT_ARGS

docs/explanations/quantization.md

Lines changed: 78 additions & 30 deletions
@@ -1,5 +1,5 @@
 <!--
-Copyright 2024 Google LLC
+Copyright 2024-2025 Google LLC

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.

@@ -16,53 +16,78 @@

 # Quantization

-MaxText supports quantization via both the [AQT](https://github.com/google/aqt) and [Qwix](https://github.com/google/qwix) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
+Quantization in deep learning is the process of reducing the precision of numbers used to represent a model's weights and/or activations. Instead of using higher-precision floating-point formats like 32-bit floats (`float32`) or 16-bit brain floats (`bfloat16`), quantization maps these values to lower-precision numerical formats, most commonly 8-bit integers (`int8`) or floats (`fp8`).

-## Why quantize?
+MaxText supports quantization via both the [AQT](https://github.com/google/aqt) and [Qwix](https://github.com/google/qwix) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT).

-* **Reduced model size**: Lower precision numbers require less storage, making models easier to store and deploy.
-* **Faster inference**: Operations on lower-precision data are computationally less expensive, which can lead to faster inference times.
-* **Lower memory usage**: Reduced precision for weights and activations decreases the memory footprint, allowing for the deployment of larger models on hardware with limited memory.
+## Why use quantization?

-## Quantizing using AQT
+The drive to use lower-precision formats like `int8` or `fp8` stems from significant performance advantages:

-Jax supports AQT. You can read more about AQT on this [Google Cloud blog](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e).
-You can turn on the quantization by adding the following flag `--quantization` and passing one of the following values:
+**Faster computation**: Hardware accelerators like TPUs and GPUs often have specialized instructions for integer arithmetic. Operations on lower-precision data like `int8` or `fp8` can be significantly faster than on BF16 or FP32. For example, matrix multiplications with these formats can often be 2x or more faster on hardware supporting native lower-precision tensor cores.

-- 'int8' for dynamic range quantization using 8-bits
-- 'int8w' for weights only quantization using 8-bits
-- 'int4w' for weights only quantization using 4-bits
-- 'intmp' for mixed precision weight only quantization based on config file
-- 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs.
+**Reduced memory footprint**: Storing weights and activations in `int8` or `fp8` requires 2x less memory compared to `bfloat16`. This reduces:
+- **HBM usage**: Less memory is needed on the accelerator itself.
+- **Communication costs**: Less data needs to be transferred between memory and compute units, or across devices in distributed training, which makes these transfers faster and consumes less bandwidth.
+- **Reduced power consumption**: Lower precision operations and reduced memory access lead to less energy usage, which is crucial for deploying models on edge devices and for sustainable AI.
+
+The primary trade-off with quantization is between the model accuracy and computational performance:
+
+* Reduced Dynamic Range & Precision: Lower-precision formats like `int8` or `fp8` can represent a much smaller range of values and with less precision than BF16. This can be problematic for models with wide distributions of weights or activations, potentially clipping large values or losing fine-grained details.
+* Impact on Gradients: Gradients during backpropagation can have very different, often wider, distributions than weights or activations, making them more sensitive to quantization errors.
+* Convergence Issues: The approximations introduced by quantization can sometimes hinder the model's ability to converge during training.
+
+To overcome the challenges of quantization, libraries like Google's Accurate Quantized Training (AQT) and its successor Qwix (used in MaxText) employ a suite of advanced techniques. These methods ensure that models can be trained with low-precision arithmetic without significant loss in accuracy and with stable convergence.
+
+## How Quantized Training (QT) works with Qwix
+
+Quantized Training (QT) incorporates the effects of quantization into the training loop. This allows the model to learn and adapt to the reduced precision of quantized weights and activations.
+
+Here's how it works:
+
+1. **Forward Pass**: During the forward pass, high-precision weights and activations are converted to a lower-precision format. This step simulates the information loss that occurs during quantization. The model then performs its computations using these lower-precision representations before they are converted back to a higher precision for the rest of the network. This process forces the model to become robust to the noise and reduced range of quantized values.
+
+2. **Backward Pass**: Standard backpropagation cannot flow through the non-differentiable quantization operations (like rounding). To solve this, QT uses the **Straight-Through Estimator (STE)**. The STE essentially "ignores" the non-differentiable quantization step during the backward pass, passing the gradients through as if the operation was an identity function. This allows the high-precision weights to be updated based on the loss, enabling the model to learn effectively.
+
+By integrating the quantization simulation directly into the training, the model learns to minimize the impact of precision loss, resulting in a more accurate quantized model.
+
+## Using Quantization in MaxText
+
+You can enable quantization in MaxText by setting flags in your configuration file (e.g., `base.yml`) or via the command line. MaxText supports two quantization libraries: Qwix (recommended) and AQT.
+
+### Configuration Flags

-## How QAT works with Qwix
+The primary flags to control quantization are:

-The core idea behind QAT is to insert "fake quantization" operations into the model's computation graph. During the training forward pass, these operations simulate the effect of quantizing weights and activations to a lower precision. For the backward pass, Qwix uses the Straight-Through Estimator (STE) to approximate the gradients, allowing the model to learn effectively despite the non-differentiable nature of quantization.
+* `use_qwix_quantization`: A boolean flag.
+  * Set to `True` to enable quantization using the Qwix library.
+  * Set to `False` (or omit) to use the AQT library if `quantization` is set.
+* `quantization`: A string that specifies the type of quantization to apply. The accepted values depend on whether you are using Qwix or AQT.
+* `quantization_calibration_method`: The calibration method for weights and activations (e.g., `"absmax"`). This is mainly for Qwix.

-## Using Qwix in MaxText
+### Qwix Quantization (Recommended)

-You can enable quantization in MaxText by setting flags in your configuration file (e.g., `base.yml`) or via the command line.
+To use Qwix, you must set `use_qwix_quantization=True`. Qwix is a powerful and non-intrusive library for Quantized Training.

-### Configuration flags
+#### `quantization` values for Qwix

-* `use_qwix_quantization`: Must be set to `True` to enable quantization using the Qwix library.
-* `quantization`: Specifies the type of quantization to apply. Common options include:
-  * `"int8"`: 8-bit integer quantization.
-  * `"fp8"`: 8-bit floating-point quantization.
-  * `"fp8_full"`: FP8 quantization with static scaling.
-  * `"fp8_gpu"`: FP8 for NVIDIA GPUs.
-  * `"fp8_nanoo"`: FP8 for AMD MI300/MI325 GPUs.
-* `quantization_calibration_method`: The calibration method for weights and activations (e.g., `"absmax"`).
+Common options for the `quantization` flag when using Qwix include:

-### Example command
+* `"int8"`: 8-bit integer quantization.
+* `"fp8"`: 8-bit floating-point quantization.
+* `"fp8_full"`: FP8 quantization with static scaling.
+* `"fp8_gpu"`: FP8 for NVIDIA GPUs.
+* `"fp8_nanoo"`: FP8 for AMD MI300/MI325 GPUs.
+
+#### Example command for Qwix

 Here is an example of how to run a training job with int8 quantization enabled via Qwix:

 ```bash
-python3 -m MaxText.train src/MaxText/configs/base.yml ... use_qwix_quantization=True quantization='int8'
+python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs://<my-bucket> dataset_type=synthetic use_qwix_quantization=true quantization='int8'
 ```

-## The Qwix interception API
+#### The Qwix Interception API

 MaxText integrates Qwix using its powerful and non-intrusive Interception API. This approach allows you to enable QAT for your models without modifying the original model source code. You don't need to manually replace `nn.Dense` with `QuantizedDense` or other quantized layer types.

@@ -94,3 +119,26 @@ This rule is then used within a `QtProvider` to quantize the model automatically
 ```python
 model = qwix.quantize_model(model, qwix.QtProvider(rule))
 ```
+
+### AQT Quantization
+
+If `use_qwix_quantization` is `False` or not set, you can still apply quantization using the AQT library by setting the `quantization` flag. You can read more about AQT on this [Google Cloud blog](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e).
+
+#### `quantization` values for AQT
+
+When using AQT, you can pass one of the following values to the `quantization` flag:
+
+- 'int8' for dynamic range quantization using 8-bits
+- 'int8w' for weights only quantization using 8-bits
+- 'int4w' for weights only quantization using 4-bits
+- 'intmp' for mixed precision weight only quantization based on config file
+- 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs.
+
+#### Example command for AQT
+
+```bash
+python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs://<my-bucket> dataset_type=synthetic use_qwix_quantization=false quantization='int8'
+```
+Note that `use_qwix_quantization` is not set to `True`.
+
+For further reading, please refer to the [Qwix Read the Docs website](https://qwix.readthedocs.io/en/latest/get_started.html#).
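As a companion to the new "How Quantized Training (QT) works with Qwix" section added in this commit, the fake-quantization-plus-STE pattern it describes can be sketched in a few lines of JAX. This is a minimal illustrative sketch under stated assumptions (per-tensor int8 with "absmax" calibration); it is not Qwix's or AQT's implementation, and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

def fake_quant_int8(x, scale):
    # Simulate int8 quantization: scale, round, clip to the int8 range,
    # then dequantize back to floating point.
    return jnp.clip(jnp.round(x / scale), -127, 127) * scale

def fake_quant_ste(x, scale):
    # Straight-Through Estimator: the forward value is the fake-quantized x,
    # while stop_gradient makes the backward pass treat the rounding step as
    # the identity, so gradients still reach the high-precision weights.
    return x + jax.lax.stop_gradient(fake_quant_int8(x, scale) - x)

w = jnp.array([0.5, -1.2, 3.3])
scale = jnp.max(jnp.abs(w)) / 127.0  # "absmax" calibration of the scale
loss = lambda w: jnp.sum(fake_quant_ste(w, scale) ** 2)
print(jax.grad(loss)(w))  # gradient evaluated at the quantized values, applied to w
```

In MaxText this machinery is applied automatically by Qwix through the Interception API shown in the diff; the sketch is only meant to make the forward/backward mechanics concrete.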