<!--
 Copyright 2024-2025 Google LLC

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      https://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 -->

# Quantization

Quantization in deep learning is the process of reducing the precision of the numbers used to represent a model's weights and/or activations. Instead of using higher-precision floating-point formats such as 32-bit floats (`float32`) or 16-bit brain floats (`bfloat16`), quantization maps these values to lower-precision numerical formats, most commonly 8-bit integers (`int8`) or 8-bit floats (`fp8`).

MaxText supports quantization via both the [AQT](https://github.com/google/aqt) and [Qwix](https://github.com/google/qwix) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT).

## Why use quantization?

The drive to use lower-precision formats like `int8` or `fp8` stems from significant performance advantages:

**Faster computation**: Hardware accelerators like TPUs and GPUs often have specialized instructions for low-precision arithmetic. Operations on lower-precision data like `int8` or `fp8` can be significantly faster than on `bfloat16` or `float32`; matrix multiplications in these formats, for example, can often run 2x or more faster on hardware with native lower-precision tensor cores.

**Reduced memory footprint**: Storing weights and activations in `int8` or `fp8` requires half the memory of `bfloat16` (see the quick sizing sketch after this list). This reduces:
- **HBM usage**: Less memory is needed on the accelerator itself.
- **Communication costs**: Less data needs to be transferred between memory and compute units, or across devices in distributed training, which makes these transfers faster and consumes less bandwidth.
- **Power consumption**: Lower-precision operations and fewer memory accesses use less energy, which is crucial for deploying models on edge devices and for sustainable AI.

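As a rough, back-of-the-envelope sketch of the savings (the 8-billion-parameter count below is purely a hypothetical example, not a specific MaxText model), weight storage scales directly with the bytes used per parameter:

```python
# Rough weight-memory estimate for a hypothetical 8-billion-parameter model.
num_params = 8e9
for fmt, bytes_per_param in [("float32", 4), ("bfloat16", 2), ("int8 / fp8", 1)]:
    print(f"{fmt}: {num_params * bytes_per_param / 1e9:.0f} GB")
# float32: 32 GB, bfloat16: 16 GB, int8 / fp8: 8 GB
```
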
The primary trade-off with quantization is between model accuracy and computational performance:

* **Reduced dynamic range and precision**: Lower-precision formats like `int8` or `fp8` can represent a much smaller range of values, at coarser precision, than `bfloat16`. This can be problematic for models with wide distributions of weights or activations, potentially clipping large values or losing fine-grained detail (see the sketch after this list).
* **Impact on gradients**: Gradients during backpropagation can have very different, often wider, distributions than weights or activations, making them more sensitive to quantization errors.
* **Convergence issues**: The approximations introduced by quantization can sometimes hinder the model's ability to converge during training.

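To make the first point concrete, here is a minimal, hypothetical illustration in plain JAX (not MaxText or Qwix code) of symmetric "absmax" `int8` quantization: a single per-tensor scale is derived from the largest absolute value, so an outlier forces small values onto a coarse grid:

```python
import jax.numpy as jnp

# A toy tensor with one large outlier and several small values.
x = jnp.array([0.001, 0.02, -0.5, 1.7, -3.2])

scale = jnp.max(jnp.abs(x)) / 127.0                         # per-tensor absmax scale
q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
x_hat = q.astype(jnp.float32) * scale                       # dequantized approximation

print(x_hat)               # the 0.001 entry collapses to 0.0
print(jnp.abs(x - x_hat))  # per-element quantization error
```
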
To overcome these challenges, libraries like Google's Accurate Quantized Training (AQT) and its successor Qwix (used in MaxText) employ a suite of techniques that allow models to be trained with low-precision arithmetic without a significant loss in accuracy and with stable convergence.

## How Quantized Training (QT) works with Qwix

Quantized Training (QT) incorporates the effects of quantization into the training loop. This allows the model to learn and adapt to the reduced precision of quantized weights and activations.

Here's how it works:

1. **Forward pass**: During the forward pass, high-precision weights and activations are converted to a lower-precision format. This step simulates the information loss that occurs during quantization. The model then performs its computations using these lower-precision representations before they are converted back to a higher precision for the rest of the network. This process forces the model to become robust to the noise and reduced range of quantized values.

2. **Backward pass**: Standard backpropagation cannot flow through non-differentiable quantization operations such as rounding. To solve this, QT uses the **Straight-Through Estimator (STE)**, which treats the quantization step as an identity function during the backward pass, letting gradients flow through unchanged. This allows the high-precision weights to be updated based on the loss, enabling the model to learn effectively.

By integrating the quantization simulation directly into training, the model learns to minimize the impact of precision loss, resulting in a more accurate quantized model. A minimal sketch of this mechanism is shown below.

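The following is a small, self-contained JAX sketch of the idea described above; it is an illustrative assumption of how such a step can be written, not Qwix's actual implementation. Weights are "fake-quantized" to `int8` precision in the forward pass, while `jax.lax.stop_gradient` implements the Straight-Through Estimator so that gradients still reach the full-precision weights.

```python
import jax
import jax.numpy as jnp

# Illustrative sketch (not Qwix internals): symmetric int8 "fake quantization"
# with a straight-through estimator (STE) for the backward pass.
def fake_quant_int8(x):
    scale = jnp.max(jnp.abs(x)) / 127.0
    q = jnp.clip(jnp.round(x / scale), -127, 127)   # non-differentiable rounding/clipping
    x_q = q * scale                                 # dequantize for the rest of the network
    # STE: the forward value is x_q, but the backward pass sees an identity,
    # so gradients flow to the full-precision x.
    return x + jax.lax.stop_gradient(x_q - x)

def loss_fn(w, x, y):
    y_pred = x @ fake_quant_int8(w)                 # quantize weights in the forward pass
    return jnp.mean((y_pred - y) ** 2)

# Gradients with respect to the high-precision weights are still well defined.
w = jnp.ones((4, 2))
grads = jax.grad(loss_fn)(w, jnp.ones((3, 4)), jnp.zeros((3, 2)))
```

In MaxText you never write this by hand: Qwix's Interception API (described below) applies an equivalent transformation to the model automatically.
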
## Using Quantization in MaxText

You can enable quantization in MaxText by setting flags in your configuration file (e.g., `base.yml`) or via the command line. MaxText supports two quantization libraries: Qwix (recommended) and AQT.

### Configuration Flags

The primary flags to control quantization are:

* `use_qwix_quantization`: A boolean flag.
  * Set to `True` to enable quantization using the Qwix library.
  * Set to `False` (or omit it) to use the AQT library when `quantization` is set.
* `quantization`: A string that specifies the type of quantization to apply. The accepted values depend on whether you are using Qwix or AQT.
* `quantization_calibration_method`: The calibration method for weights and activations (e.g., `"absmax"`). This is mainly used with Qwix.

### Qwix Quantization (Recommended)

To use Qwix, you must set `use_qwix_quantization=True`. Qwix is a powerful and non-intrusive library for Quantized Training.

#### `quantization` values for Qwix

Common options for the `quantization` flag when using Qwix include:

* `"int8"`: 8-bit integer quantization.
* `"fp8"`: 8-bit floating-point quantization.
* `"fp8_full"`: FP8 quantization with static scaling.
* `"fp8_gpu"`: FP8 for NVIDIA GPUs.
* `"fp8_nanoo"`: FP8 for AMD MI300/MI325 GPUs.

#### Example command for Qwix

Here is an example of how to run a training job with int8 quantization enabled via Qwix:

```bash
python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs://<my-bucket> dataset_type=synthetic use_qwix_quantization=true quantization='int8'
```

#### The Qwix Interception API

MaxText integrates Qwix using its powerful and non-intrusive Interception API. This approach allows you to enable QT for your models without modifying the original model source code. You don't need to manually replace `nn.Dense` with `QuantizedDense` or other quantized layer types.

[...]

This rule is then used within a `QtProvider` to quantize the model automatically:

```python
model = qwix.quantize_model(model, qwix.QtProvider(rule))
```

### AQT Quantization

If `use_qwix_quantization` is `False` or not set, you can still apply quantization using the AQT library by setting the `quantization` flag. You can read more about AQT on this [Google Cloud blog](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e).

#### `quantization` values for AQT

When using AQT, you can pass one of the following values to the `quantization` flag:

- `'int8'` for dynamic-range quantization using 8 bits
- `'int8w'` for weights-only quantization using 8 bits
- `'int4w'` for weights-only quantization using 4 bits
- `'intmp'` for mixed-precision, weights-only quantization based on a config file
- `'fp8'` for 8-bit floating-point GeMMs on NVIDIA GPUs

#### Example command for AQT

```bash
python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs://<my-bucket> dataset_type=synthetic use_qwix_quantization=false quantization='int8'
```

Note that `use_qwix_quantization` is not set to `True`.

For further reading, please refer to the [Qwix Read the Docs website](https://qwix.readthedocs.io/en/latest/get_started.html#).