Commit 1c85276

fix: convergence issue by adding use_inductor=False in vllm compilation_config (#1014)
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
1 parent f0588dc commit 1c85276

7 files changed: +400, -2 lines changed

docs/adding-new-models.md

Lines changed: 122 additions & 1 deletion
@@ -190,4 +190,125 @@ uv run --extra mcore tools/model_diagnostics/3.check_hf_model_embeddings_untrain
- Thresholds can be adjusted via flags:
  - `--near-zero-threshold` (default: `1e-10`)
  - `--identical-threshold` (default: `1e-8`)
- If any near-zero or identical rows are reported, the model may run into numerical instability (e.g., inf grad norms) during post-training whenever these problematic tokens are encountered. We have observed this when special tokens are reserved in the tokenizer and embedding but none are encountered during pre-training. It may help to initialize these embeddings similarly to how they were initialized during pre-training.

## [4.vllm_precision_compilation_test.py](https://github.com/NVIDIA-NeMo/RL/blob/main/tools/model_diagnostics/4.vllm_precision_compilation_test.py)

Tests vLLM precision compilation by comparing log probabilities across different compilation modes and configurations. This script helps diagnose numerical precision issues that commonly arise when using different vLLM compilation settings. **Note that this is not a strict pass/fail test** - it is designed to help you understand and investigate numerical discrepancies.

```sh
# Example run
uv run --extra vllm tools/model_diagnostics/4.vllm_precision_compilation_test.py --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

# Typical output shows mixed results:
# Eager and cuda graph mode lps: FAILED - Arrays are different
# ...
# Eager and cuda graph mode lps with torch inductor precision flag: FAILED - Arrays are different
# ...
# Eager and cuda graph mode lps with use_inductor disabled: PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)
```

See example output for model `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`:
```
====================================================================================================
Eager and cuda graph mode lps (prompt lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001

Mismatched elements: 96 / 515 (18.6%)
Max absolute difference among violations: 0.3885002
Max relative difference among violations: 0.20179409
 ACTUAL: array([[-1.424489e+01, -3.924684e-01, -3.135911e+00, -4.258007e-01,
        -3.443364e-04, nan, nan, nan,
        nan, nan, nan, nan,...
 DESIRED: array([[-1.420929e+01, -3.619126e-01, -3.241854e+00, -4.308376e-01,
        -3.047717e-04, nan, nan, nan,
        nan, nan, nan, nan,...
====================================================================================================
====================================================================================================
Eager and cuda graph mode lps (generation lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001

nan location mismatch:
 ACTUAL: array([[-1.231834e+01, -1.411233e-01, -3.764260e-01, ..., nan,
        nan, nan],
        [-8.567932e+00, -1.066314e+01, -4.463661e-01, ..., nan,...
 DESIRED: array([[-1.226752e+01, -1.508305e-01, -4.024158e-01, ..., nan,
        nan, nan],
        [-8.610202e+00, -1.067061e+01, -4.593382e-01, ..., -1.060957e-05,...
====================================================================================================
...
====================================================================================================
Eager and cuda graph mode lps with torch inductor precision flag (prompt lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001

Mismatched elements: 96 / 515 (18.6%)
Max absolute difference among violations: 0.3885002
Max relative difference among violations: 0.20179409
 ACTUAL: array([[-1.424489e+01, -3.924684e-01, -3.135911e+00, -4.258007e-01,
        -3.443364e-04, nan, nan, nan,
        nan, nan, nan, nan,...
 DESIRED: array([[-1.420929e+01, -3.619126e-01, -3.241854e+00, -4.308376e-01,
        -3.047717e-04, nan, nan, nan,
        nan, nan, nan, nan,...
====================================================================================================
====================================================================================================
Eager and cuda graph mode lps with torch inductor precision flag (generation lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001

nan location mismatch:
 ACTUAL: array([[-1.231834e+01, -1.411233e-01, -3.764260e-01, ..., nan,
        nan, nan],
        [-8.567932e+00, -1.066314e+01, -4.463661e-01, ..., nan,...
 DESIRED: array([[-1.226752e+01, -1.508305e-01, -4.024158e-01, ..., nan,
        nan, nan],
        [-8.610202e+00, -1.067061e+01, -4.593382e-01, ..., -1.060957e-05,...
====================================================================================================
...
Eager and cuda graph mode lps with use_inductor disabled (prompt lps): PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)
Eager and cuda graph mode lps with use_inductor disabled (generation lps): PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)
```

**What this script tests:**

The script compares both prompt and generation logprobs under the following setups (a minimal comparison sketch follows the list):

1. **Eager vs CUDA Graph Mode**: Compares log probabilities between eager execution (ground truth) and CUDA graph compilation mode.
   - **⚠️ Commonly fails**: This comparison often shows discrepancies due to compilation optimizations.
2. **Torch Inductor Precision**: Tests with the `TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1` environment variable.
   - **⚠️ May help**: This flag can reduce the numerical differences but typically does not resolve all of them.
3. **Inductor Disabled**: Verifies that disabling Torch Inductor compilation (`use_inductor=False`) maintains output consistency.
   - **✅ Usually works well**: This configuration often produces results very close to eager mode.
   - **Note**: `use_inductor=False` disables Inductor compilation but keeps CUDA graph capture active for compatible operations.
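
The comparison itself boils down to collecting logprobs from engines built with different compilation settings and checking them with `numpy.testing.assert_allclose` at the same tolerances shown above. The sketch below is illustrative only, not the script's actual code: it assumes a recent vLLM where `LLM(...)` accepts `enforce_eager` and a `compilation_config` dict, it only compares prompt logprobs, and the helper name `prompt_logprobs_of` is made up.

```python
# Minimal sketch: compare prompt logprobs between eager execution and
# CUDA-graph mode with inductor disabled (assumed vLLM API, see lead-in).
import numpy as np
from vllm import LLM, SamplingParams

MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
PROMPTS = ["The capital of France is"]
# prompt_logprobs=0 asks vLLM to return the logprob of each actual prompt token.
PARAMS = SamplingParams(max_tokens=8, temperature=0.0, prompt_logprobs=0)


def prompt_logprobs_of(llm: LLM) -> np.ndarray:
    """Return the prompt-token logprobs of the first prompt as a flat array."""
    out = llm.generate(PROMPTS, PARAMS)[0]
    # The first prompt position has no logprob (it is None); skip it.
    return np.array(
        [next(iter(lp.values())).logprob for lp in out.prompt_logprobs if lp]
    )


# Ground truth: eager execution.
ref = prompt_logprobs_of(LLM(model=MODEL, enforce_eager=True))
# NOTE: in practice each engine may need its own process so GPU memory is freed.

# Candidate: CUDA graphs with torch inductor disabled.
test = prompt_logprobs_of(
    LLM(model=MODEL, enforce_eager=False, compilation_config={"use_inductor": False})
)

np.testing.assert_allclose(test, ref, atol=1e-3, rtol=1e-3)
print("PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)")
```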

**Performance vs Accuracy Trade-offs:**

The different compilation modes offer distinct trade-offs between accuracy and performance (the sketch after this list spells them out as engine kwargs):

- **Eager Mode** (`enforce_eager=True`): Highest accuracy (ground truth) but slowest execution.
- **CUDA Graph Mode with Inductor Disabled** (`enforce_eager=False` and `compilation_config={"use_inductor": False}`): Near-eager accuracy with a significant speedup from CUDA graph optimization.
- **CUDA Graph Mode with Inductor Enabled** (`enforce_eager=False` and `compilation_config={"use_inductor": True}`): Potentially the fastest execution thanks to custom Triton kernels (Triton is the current backend of Inductor), but may introduce numerical differences. To improve accuracy, try the torch inductor precision flag: `export TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1`.
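
For quick reference, the three modes can also be written out as engine keyword arguments. This is a hypothetical summary, not code from the script; it assumes a vLLM version whose engine arguments accept `enforce_eager` and a `compilation_config` dict, and `ENGINE_CONFIGS` is an illustrative name.

```python
import os

# The three modes discussed above, expressed as engine kwargs (assumed vLLM API).
ENGINE_CONFIGS = {
    # Ground truth: no CUDA graphs, no inductor; slowest.
    "eager": {"enforce_eager": True},
    # CUDA graphs with inductor disabled: near-eager accuracy with a good speedup.
    "cudagraph_no_inductor": {
        "enforce_eager": False,
        "compilation_config": {"use_inductor": False},
    },
    # CUDA graphs with inductor-generated Triton kernels: potentially fastest,
    # but may introduce numerical differences.
    "cudagraph_inductor": {
        "enforce_eager": False,
        "compilation_config": {"use_inductor": True},
    },
}

# The precision flag is read when kernels are compiled, so for the inductor-enabled
# mode set it in the environment before the engine is created.
os.environ.setdefault("TORCHINDUCTOR_EMULATE_PRECISION_CASTS", "1")

# Usage with the sketch above, e.g.:
#   llm = LLM(model=MODEL, **ENGINE_CONFIGS["cudagraph_no_inductor"])
```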

**Note**: Performance characteristics vary by model. For example, `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` shows similar speed with `use_inductor=True` and `use_inductor=False`, making the accuracy-preserving option preferable.

**Why this matters:**

- **Debugging**: Helps identify which compilation settings cause numerical differences
- **Configuration**: Shows which settings work best for your model
- **Understanding**: Reveals how compilation affects model outputs

**When to use:**

- **Model integration** - understand numerical behavior across vLLM configurations
- **Debugging** - investigate differences between development and production
- **Research** - study compilation strategy impacts on precision

**Interpreting results:**

- **Eager vs CUDA Graph failures are normal** - don't panic if this fails
- **Focus on patterns** - some models are more sensitive than others
- **Use as guidance** - helps choose reliable compilation settings
- **Balance precision vs performance** - choose what works for your use case

examples/configs/grpo_math_1B.yaml

Lines changed: 4 additions & 0 deletions
@@ -180,10 +180,14 @@ policy:
    enable_expert_parallel: false
    gpu_memory_utilization: 0.6
    max_model_len: ${policy.max_total_sequence_length}
+   # when enforce_eager is False, it is optional to set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy,
+   # with the flag, vllm will use the custom CUDA kernels instead of the Triton kernels generated by torch.compile
+   # for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998
    enforce_eager: False
    use_deep_gemm: False
    num_last_layers_in_bf16: 0
    num_first_layers_in_bf16: 0
+   vllm_kwargs: {}
    colocated:
      # true: generation shares training GPUs
      # false: uses dedicated generation resources

examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml

Lines changed: 7 additions & 1 deletion
@@ -104,7 +104,13 @@ policy:
    enable_expert_parallel: false
    gpu_memory_utilization: 0.6
    max_model_len: ${policy.max_total_sequence_length}
-   enforce_eager: True
+   enforce_eager: False
+   vllm_kwargs:
+     compilation_config:
+       # when enforce_eager is False, set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy,
+       # with the flag, vllm will use the custom CUDA kernels instead of the Triton kernels generated by torch.compile
+       # for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998
+       use_inductor: False
    colocated:
      # true: generation shares training GPUs
      # false: uses dedicated generation resources

nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 7 additions & 0 deletions
@@ -36,9 +36,16 @@
  ) # pragma: no cover
  class VllmAsyncGenerationWorker(BaseVllmGenerationWorker):
      def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:
+         from vllm.config import CompilationConfig
          from vllm.engine.arg_utils import AsyncEngineArgs
          from vllm.v1.engine.async_llm import AsyncLLM

+         # (TODO: zhiyul) Remove this workaround after upgrading vLLM where the compilation_config passing issue is resolved.
+         if llm_kwargs.get("compilation_config", None):
+             llm_kwargs["compilation_config"] = CompilationConfig(
+                 **llm_kwargs["compilation_config"]
+             )
+
          self.llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**llm_kwargs))

      async def post_init_async(self):
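
For context, the added workaround simply promotes the plain `compilation_config` dict (as it arrives from `vllm_kwargs` in the YAML configs above) into vLLM's `CompilationConfig` object before the engine arguments are built. A standalone sketch of that conversion, assuming vLLM is installed and ignoring the rest of the engine setup:

```python
# Sketch of the dict -> CompilationConfig promotion done in _create_engine above.
# Available fields depend on your vLLM version; `use_inductor` is the one used here.
from vllm.config import CompilationConfig

llm_kwargs = {
    "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "enforce_eager": False,
    "compilation_config": {"use_inductor": False},  # plain dict from the YAML config
}

if llm_kwargs.get("compilation_config", None):
    # The pinned vLLM version does not handle a bare dict here correctly
    # (see the TODO above), so construct CompilationConfig explicitly.
    llm_kwargs["compilation_config"] = CompilationConfig(**llm_kwargs["compilation_config"])

print(type(llm_kwargs["compilation_config"]).__name__)  # -> "CompilationConfig"
```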

pyrefly.toml

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@ project-includes = [
    "nemo_rl/utils/venvs.py",
    "tools/model_diagnostics/1.max_model_len_respected.py",
    "tools/model_diagnostics/2.long_generation_decode_vs_prefill.py",
+   "tools/model_diagnostics/4.vllm_precision_compilation_test.py",
  ]

  # Disable TypedDict mutation errors since TypedDict objects are regular dicts at runtime

tests/unit/test_recipes_and_test_suites.py

Lines changed: 17 additions & 0 deletions
@@ -37,6 +37,11 @@
      "vlm_grpo": "examples/configs/vlm_grpo_3B.yaml",
  }

+ # Configuration keys that are allowed to be added to base configs during testing
+ # These keys may exist in recipe configs but not in base configs, so we need to
+ # manually add them to avoid merge conflicts during config validation
+ ALLOWED_ADDITIONAL_CONFIG_KEYS = ["policy.generation.vllm_kwargs"]
+

  @pytest.fixture
  def nightly_test_suite():
@@ -262,6 +267,18 @@ def test_all_recipes_can_merge_configs_with_base_config(
      recipe_yaml_path = os.path.join(recipes_dir, recipe_yaml)
      recipe_config = load_config(recipe_yaml_path)
      OmegaConf.set_struct(recipe_config, True)
+
+     # Work around missing keys by manually adding ALLOWED_ADDITIONAL_CONFIG_KEYS to the base config.
+     # This prevents merge conflicts when recipe configs contain keys not present in base configs.
+     for key in ALLOWED_ADDITIONAL_CONFIG_KEYS:
+         if OmegaConf.select(recipe_config, key):
+             OmegaConf.update(
+                 base_config,
+                 key,
+                 OmegaConf.select(recipe_config, key),
+                 force_add=True,
+             )
+
      # This will raise an error if the config can't be merged
      print(f"Merging {recipe_yaml} with {base_yaml}")
      merged_config = OmegaConf.merge(base_config, recipe_config)
