2 changes: 0 additions & 2 deletions docs/guides/dtensor-tp-accuracy.md
@@ -108,8 +108,6 @@ In batch-variant kernels, low-level implementation details—such as parallel re

After aligning `train_micro_batch_size` and `logprob_batch_size` so that the same samples are processed with identical effective batch configurations, the importance-sampling ratio (`probs_ratio`) becomes 1 as expected, and the observed accuracy issues disappear. This confirms that the mismatch was caused by batch-dependent numerical variation rather than a conceptual error in the RL objective or data pipeline.
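
As a toy illustration of this kind of batch-dependent variation (not the actual kernel code, and only a sketch): when partial sums are rounded to `bfloat16` at every step, regrouping the same values changes the result.

```python
import torch

def grouped_sum(x: torch.Tensor, group: int) -> torch.Tensor:
    # Accumulate in bfloat16 so every intermediate add is rounded,
    # mimicking how a low-precision reduction depends on its grouping.
    acc = torch.zeros((), dtype=x.dtype)
    for part in x.split(group):
        acc = acc + part.sum()
    return acc

x = torch.randn(4096).to(torch.bfloat16)
# The same data, reduced with different groupings, is usually not bit-identical.
print(grouped_sum(x, 32).item(), grouped_sum(x, 512).item())
```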

Importantly, this problem is **model-specific** rather than universal. Models such as `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B` and `Qwen/Qwen2.5-1.5B` exhibit clear batch-variant behavior under these settings, whereas `meta-llama/Llama-3.1-8B-Instruct` does not show the same sensitivity, likely due to differences in architecture, kernel implementations, or optimization choices in their respective stacks.

### Recommended Solutions

When using DTensor with TP > 1, or when `probs_ratio != 1` is observed in an on-policy setting, the following mitigation strategies are recommended to restore numerical consistency and stabilize training:
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_70B_megatron.yaml
@@ -12,7 +12,7 @@ policy:
train_global_batch_size: 512
train_micro_batch_size: 1
generation_batch_size: 32 # Only used when generating using HF backend
logprob_batch_size: 4
logprob_batch_size: 1
max_total_sequence_length: 4096
precision: "bfloat16"

@@ -11,7 +11,7 @@ policy:
tokenizer:
name: google/gemma-3-27b-it
train_micro_batch_size: 1
logprob_batch_size: 2
Contributor review comment:
how about making the edit here:
https://github.com/NVIDIA-NeMo/RL/blob/main/examples/configs/grpo_math_1B.yaml#L78

and do something like:

logprob_batch_size: ${.train_micro_batch_size}
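
For context, `${.train_micro_batch_size}` is OmegaConf-style relative interpolation and resolves against the sibling key, so the two values cannot drift apart. A minimal sketch, assuming the configs are loaded with OmegaConf:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "policy": {
            "train_micro_batch_size": 1,
            # Relative interpolation: always mirrors the sibling key above.
            "logprob_batch_size": "${.train_micro_batch_size}",
        }
    }
)
print(cfg.policy.logprob_batch_size)  # -> 1
```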

logprob_batch_size: 1
max_total_sequence_length: 16384
dtensor_cfg:
activation_checkpointing: true
@@ -7,6 +7,7 @@ loss_fn:
policy:
model_name: openai/gpt-oss-20b
train_micro_batch_size: 1
logprob_batch_size: 1
max_total_sequence_length: 4096
megatron_cfg:
enabled: true
@@ -10,7 +10,7 @@ policy:
tokenizer:
name: Qwen/Qwen2.5-32B
train_micro_batch_size: 1
logprob_batch_size: 2
logprob_batch_size: 1
max_total_sequence_length: 16384
dtensor_cfg:
activation_checkpointing: true
@@ -10,7 +10,7 @@ policy:
tokenizer:
name: Qwen/Qwen2.5-32B
train_micro_batch_size: 1
logprob_batch_size: 2
logprob_batch_size: 1
max_total_sequence_length: 16384
dtensor_cfg:
activation_checkpointing: true
@@ -10,7 +10,7 @@ policy:
tokenizer:
name: Qwen/Qwen2.5-7B-Instruct
train_micro_batch_size: 1
logprob_batch_size: 2
logprob_batch_size: 1
max_total_sequence_length: 4096
dtensor_cfg:
tensor_parallel_size: 4
@@ -8,6 +8,7 @@ checkpointing:
policy:
model_name: Qwen/Qwen3-30B-A3B
train_micro_batch_size: 1
logprob_batch_size: 1
max_total_sequence_length: 4096
dtensor_cfg:
enabled: false
30 changes: 17 additions & 13 deletions nemo_rl/algorithms/grpo.py
@@ -295,6 +295,23 @@ def setup(
flush=True,
)

# ==========================
# Loss Function
# ==========================
loss_fn = ClippedPGLossFn(loss_config)

# Validate force_on_policy_ratio
if loss_config.get("force_on_policy_ratio", False):
assert (
grpo_config["num_prompts_per_step"]
* grpo_config["num_generations_per_prompt"]
== policy_config["train_global_batch_size"]
), (
"force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt"
)
os.environ["NRL_IGNORE_TP_ACCURACY_CHECK"] = "1"
print(" ✓ force_on_policy_ratio enabled")

# ==========================
# Cluster
# ==========================
@@ -660,19 +677,6 @@ def initialize_generation_with_policy(
if policy_generation is not None:
policy_generation.prepare_refit_info(state_dict_info)

loss_fn = ClippedPGLossFn(loss_config)

# Validate force_on_policy_ratio
if loss_config.get("force_on_policy_ratio", False):
assert (
grpo_config["num_prompts_per_step"]
* grpo_config["num_generations_per_prompt"]
== policy_config["train_global_batch_size"]
), (
"force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt"
)
print(" ✓ force_on_policy_ratio enabled")

# Calculate total setup time
total_setup_time = time.perf_counter() - setup_start_time
worker_init_timing_metrics["total_setup_time_s"] = total_setup_time
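For reference, the constraint validated in the moved block simply requires that one optimizer global batch covers exactly the rollouts generated in a step, which is what keeps the data strictly on-policy. A worked example with hypothetical values:

```python
# Hypothetical values; they satisfy the force_on_policy_ratio requirement.
num_prompts_per_step = 64
num_generations_per_prompt = 8
train_global_batch_size = 512

# 64 * 8 == 512, so force_on_policy_ratio would be allowed with these settings.
assert num_prompts_per_step * num_generations_per_prompt == train_global_batch_size
```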
19 changes: 19 additions & 0 deletions nemo_rl/models/policy/lm_policy.py
@@ -135,6 +135,25 @@ def __init__(
model_parallel_size = pp_size * cp_size * tp_size
actual_world_size = cluster.world_size()

if (
not bool(os.environ.get("NRL_IGNORE_TP_ACCURACY_CHECK"))
and "logprob_batch_size" in config
and tp_size >= 4
):
sep_line = "\n" + ("-" * 80)
assert config["train_micro_batch_size"] == config["logprob_batch_size"], (
f"{sep_line}\n"
"There is a known batch-variant accuracy issue with TP>=4 for both DTensor and Megatron backend.\n"
"See https://docs.nvidia.com/nemo/rl/latest/guides/dtensor-tp-accuracy.html#root-cause for more details.\n"
"\n"
"Please choose either of the following solutions to avoid this issue:\n"
"1. Set tp_size to 1 or 2. (tensor_parallel_size for DTensor, or tensor_model_parallel_size for Megatron)\n"
"2. Set policy.train_micro_batch_size and policy.logprob_batch_size to be the same value.\n"
"3. Set loss_fn.force_on_policy_ratio=true to force ratio=1.0, this requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt.\n"
"4. Set NRL_IGNORE_TP_ACCURACY_CHECK=1 to bypass this check. (not recommended)"
f"{sep_line}\n"
)

if actual_world_size < model_parallel_size:
raise ValueError(
f"World size ({actual_world_size}) is insufficient for the parallelism configuration. "
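A condensed sketch of when the new guard fires, mirroring the condition added above (the helper name and standalone form are mine, not part of the codebase):

```python
import os

def tp_accuracy_check_fires(cfg: dict, tp_size: int) -> bool:
    # True when the assert added in lm_policy.py would raise.
    return (
        not bool(os.environ.get("NRL_IGNORE_TP_ACCURACY_CHECK"))
        and "logprob_batch_size" in cfg
        and tp_size >= 4
        and cfg["train_micro_batch_size"] != cfg["logprob_batch_size"]
    )

print(tp_accuracy_check_fires({"train_micro_batch_size": 1, "logprob_batch_size": 2}, tp_size=4))  # True
print(tp_accuracy_check_fires({"train_micro_batch_size": 1, "logprob_batch_size": 1}, tp_size=4))  # False
```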
34 changes: 32 additions & 2 deletions tests/unit/algorithms/test_grpo.py
@@ -717,7 +717,22 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node():
},
},
},
"loss_fn": {}, # Config extraction requires this key
"loss_fn": {
"ratio_clip_min": 0.2,
"ratio_clip_max": 0.2,
"ratio_clip_c": None,
"disable_ppo_ratio": False,
"reference_policy_kl_penalty": 0.0,
"reference_policy_kl_type": "k3",
"kl_input_clamp_value": 20.0,
"kl_output_clamp_value": 10.0,
"use_on_policy_kl_approximation": False,
"use_importance_sampling_correction": False,
"truncated_importance_sampling_ratio": None,
"sequence_level_importance_ratios": False,
"token_level_loss": True,
"force_on_policy_ratio": False,
},
"env": {}, # Config extraction requires this key
"grpo": {
"seed": 42,
@@ -775,7 +790,22 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node():
},
},
},
"loss_fn": {}, # Config extraction requires this key
"loss_fn": {
"ratio_clip_min": 0.2,
"ratio_clip_max": 0.2,
"ratio_clip_c": None,
"disable_ppo_ratio": False,
"reference_policy_kl_penalty": 0.0,
"reference_policy_kl_type": "k3",
"kl_input_clamp_value": 20.0,
"kl_output_clamp_value": 10.0,
"use_on_policy_kl_approximation": False,
"use_importance_sampling_correction": False,
"truncated_importance_sampling_ratio": None,
"sequence_level_importance_ratios": False,
"token_level_loss": True,
"force_on_policy_ratio": False,
},
"env": {}, # Config extraction requires this key
"grpo": {
"seed": 42,