Commit 52310f6

Clean up fp8 / fp4 recipe handling (#1504)
Some misc. cleanup of #1484

## Summary by CodeRabbit

### Release Notes

* **New Features**
  * Added FP8/FP4 low-precision quantization support for ESM2, Llama3, Mixtral, and Qwen models with per-layer precision control
  * Added quantized model initialization option for faster FP8 deployments
  * Introduced per-layer autocast context management for flexible precision configuration
* **Documentation**
  * Added comprehensive "Running with Low Precision (FP8/FP4)" guides for ESM2 and Llama3 models
* **Tests**
  * Enhanced quantization tests with per-layer precision validation and legacy FP8 pathway coverage
  * Removed redundant quantization test module
* **Chores**
  * Refactored model initialization flows for quantization recipe handling
  * Updated configuration paths across training scripts

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent afcb80e commit 52310f6

50 files changed: +2423 additions, -2128 deletions

.secrets.baseline

Lines changed: 10 additions & 1 deletion
```diff
@@ -142,6 +142,15 @@
       }
     ],
   "results": {
+    "bionemo-recipes/recipes/esm2_native_te/tests/test_train.py": [
+      {
+        "type": "Base64 High Entropy String",
+        "filename": "bionemo-recipes/recipes/esm2_native_te/tests/test_train.py",
+        "hashed_secret": "76fc9c9f16bca9a436a5fcede09cc3593a4bf6f0",
+        "is_verified": false,
+        "line_number": 511
+      }
+    ],
     "pyproject.toml": [
       {
         "type": "Hex High Entropy String",
@@ -152,5 +161,5 @@
       }
     ]
   },
-  "generated_at": "2025-12-29T20:49:21Z"
+  "generated_at": "2026-03-09T20:44:27Z"
 }
```

bionemo-recipes/models/esm2/README.md

Lines changed: 96 additions & 0 deletions
Training recipes are available in the `bionemo-recipes/recipes/` directory:

- **[esm2_accelerate_te](../../recipes/esm2_accelerate_te/)** - Trains the model using HuggingFace
  [Accelerate](https://huggingface.co/docs/accelerate/index).

## Running with Low Precision (FP8/FP4)

The TE-optimized ESM-2 model supports per-layer quantization via two mechanisms: a **config-level**
`layer_precision` list that declares which layers use which precision, and **constructor-level** recipe
objects (`fp8_recipe`, `fp4_recipe`) that control the quantization behavior.

### Configuration: `layer_precision`

`NVEsmConfig.layer_precision` is a list of length `num_hidden_layers` where each element is `"fp8"`,
`"fp4"`, or `None` (BF16 fallback). When set, it controls the `te.autocast` context used for each
transformer layer during both initialization and the forward pass.
```python
from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM

# All layers in FP8
config = NVEsmConfig.from_pretrained(
    "nvidia/esm2_t6_8M_UR50D",
    layer_precision=["fp8"] * 6,
)
```
If you pass an `fp8_recipe` to the model constructor **without** setting `layer_precision`, it
defaults to `["fp8"] * num_hidden_layers` (all layers in FP8). You can also mix precisions, for
example running most layers in FP8 but keeping the first and last layers in BF16:

```python
# First and last layers stay in BF16; the four middle layers run in FP8
layer_precision = [None] + ["fp8"] * 4 + [None]
config = NVEsmConfig.from_pretrained(
    "nvidia/esm2_t6_8M_UR50D",
    layer_precision=layer_precision,
)
```
### Constructor arguments: `fp8_recipe` and `fp4_recipe`

The model classes (`NVEsmModel`, `NVEsmForMaskedLM`, `NVEsmForTokenClassification`) accept
`fp8_recipe` and `fp4_recipe` keyword arguments. These are `transformer_engine.common.recipe.Recipe`
objects that configure the quantization algorithm (e.g., delayed scaling, block scaling, MXFP8).
```python
import transformer_engine.common.recipe as te_recipe

from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM

fp8_recipe = te_recipe.DelayedScaling()

config = NVEsmConfig.from_pretrained(
    "nvidia/esm2_t6_8M_UR50D",
    layer_precision=["fp8"] * 6,
)
model = NVEsmForMaskedLM(config, fp8_recipe=fp8_recipe)
```
For FP4 (NVFP4) quantization, pass an `fp4_recipe` instead and set the corresponding layers to
`"fp4"` in `layer_precision`:

```python
fp4_recipe = te_recipe.NVFP4BlockScaling()

config = NVEsmConfig.from_pretrained(
    "nvidia/esm2_t6_8M_UR50D",
    layer_precision=["fp4"] * 6,
)
model = NVEsmForMaskedLM(config, fp4_recipe=fp4_recipe)
```
You can also mix FP8 and FP4 layers by providing both recipes and a mixed `layer_precision` list.
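As a sketch of such a mixed setup, a small helper can build the `layer_precision` list. The helper below is hypothetical, not part of the model API:

```python
def make_layer_precision(num_layers, fp4_layers=(), bf16_layers=()):
    """Build a layer_precision list: FP8 by default, with selected
    layer indices overridden to FP4 or left in BF16 (None).

    Hypothetical convenience helper; not part of the model API.
    """
    precision = ["fp8"] * num_layers
    for i in fp4_layers:
        precision[i] = "fp4"
    for i in bf16_layers:
        precision[i] = None
    return precision


# Middle two layers in FP4, first and last in BF16, the rest in FP8
layer_precision = make_layer_precision(6, fp4_layers=(2, 3), bf16_layers=(0, 5))
# → [None, "fp8", "fp4", "fp4", "fp8", None]
```

Because this list contains both `"fp8"` and `"fp4"` entries, the model constructor would need both `fp8_recipe` and `fp4_recipe`.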
### Quantized model initialization: `use_quantized_model_init`

When `use_quantized_model_init=True` is set in the config, layers are created inside a
`te.quantized_model_init` context. This tells TransformerEngine to initialize weights directly in
the target quantized format, avoiding a separate quantization step after initialization. This is
primarily useful when loading pre-quantized checkpoints.
```python
config = NVEsmConfig.from_pretrained(
    "nvidia/esm2_t6_8M_UR50D",
    layer_precision=["fp4"] * 6,
    use_quantized_model_init=True,
)
model = NVEsmForMaskedLM(config, fp4_recipe=te_recipe.NVFP4BlockScaling())
```
### Notes

- The `lm_head` (and the `dense` projection in `NVEsmLMHead`) always runs in higher precision
  (`te.autocast(enabled=False)`) regardless of `layer_precision`, to avoid numerical instability in
  the output logits.
- FP8 requires compute capability 9.0+ (Hopper). MXFP8 requires compute capability 10.0+
  (Blackwell).
- If an `fp8_recipe` is provided without `layer_precision`, all layers default to FP8. Providing
  both `fp8_recipe` and `fp4_recipe` without `layer_precision` raises a `RuntimeError`.
- An FP4 layer **requires** an `fp4_recipe`; omitting it raises a `RuntimeError`.
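The defaulting and error rules in the last two bullets can be restated as plain Python. This is a hypothetical re-statement of the documented behavior, not the model's actual implementation:

```python
def resolve_layer_precision(num_layers, layer_precision=None,
                            fp8_recipe=None, fp4_recipe=None):
    """Apply the documented defaulting/validation rules for layer precision.

    Hypothetical sketch of the documented rules; not the model's real code.
    """
    if layer_precision is None:
        if fp8_recipe is not None and fp4_recipe is not None:
            # Ambiguous: no way to know which layers get which recipe
            raise RuntimeError(
                "Both fp8_recipe and fp4_recipe provided without layer_precision"
            )
        if fp8_recipe is not None:
            layer_precision = ["fp8"] * num_layers  # default: all layers FP8
        else:
            # The fp4-only default is not documented; BF16 assumed here
            layer_precision = [None] * num_layers
    if "fp4" in layer_precision and fp4_recipe is None:
        raise RuntimeError("An FP4 layer requires an fp4_recipe")
    return layer_precision
```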
## Converting Between Model Formats

This section explains how to convert between Hugging Face Transformers and Transformer Engine (TE) ESM2 model formats.
