Skip to content

Commit 7684dc2

Browse files
terrykongHeyyyyyyG
andauthored
feat: async nemo gym (#1985)
Signed-off-by: Jiaqi Zeng <jiaqiz@nvidia.com> Signed-off-by: Terry Kong <terryk@nvidia.com> Co-authored-by: Jiaqi Zeng <jiaqiz@nvidia.com>
1 parent ac2c774 commit 7684dc2

File tree

9 files changed

+268
-76
lines changed

9 files changed

+268
-76
lines changed

examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ grpo:
3737
skip_reference_policy_logprobs_calculation: true
3838
seq_logprob_error_threshold: null
3939

40+
async_grpo:
41+
enabled: false # Set to true to enable async training mode
42+
# Max age (in training steps) for trajectories used in training
43+
max_trajectory_age_steps: 1
44+
in_flight_weight_updates: false # Set to true to enable in-flight weight updates
45+
recompute_kv_cache_after_weight_updates: false # Set to true to recompute kv cache after in-flight-weight-updates
46+
4047
loss_fn:
4148
reference_policy_kl_penalty: 0
4249
reference_policy_kl_type: "k3"
@@ -246,6 +253,10 @@ data:
246253
shuffle: true
247254
num_workers: 0
248255

256+
# use multiple dataloader for train
257+
# see https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details.
258+
use_multiple_dataloader: false
259+
249260
# Using the prepared train and validation datasets (downloaded from HuggingFace and split 90/10)
250261
# Train: 1129 samples, Validation: 126 samples
251262
train:

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,61 @@ def main() -> None:
231231
logger=logger,
232232
master_config=master_config,
233233
)
234+
# Check if async mode is enabled
235+
elif "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]:
236+
# Async GRPO does not support dynamic sampling, reward scaling, or reward shaping (DAPO features)
237+
unsupported_features = [
238+
"use_dynamic_sampling",
239+
"reward_scaling",
240+
"reward_shaping",
241+
]
242+
243+
for feature in unsupported_features:
244+
if feature not in config["grpo"]:
245+
continue
246+
247+
if feature == "use_dynamic_sampling":
248+
if config["grpo"][feature]:
249+
raise NotImplementedError(
250+
f"{feature} is not supported with async GRPO"
251+
)
252+
else:
253+
if config["grpo"][feature]["enabled"]:
254+
raise NotImplementedError(
255+
f"{feature} is not supported with async GRPO"
256+
)
257+
258+
# Async GRPO does not support multiple dataloaders
259+
if config["data"]["use_multiple_dataloader"]:
260+
raise NotImplementedError(
261+
"use_multiple_dataloader is not supported with async GRPO"
262+
)
263+
264+
from nemo_rl.algorithms.grpo import async_grpo_train
265+
266+
print("🚀 Running async GRPO training")
267+
268+
async_config = config["grpo"]["async_grpo"]
269+
# Run async GRPO training
270+
async_grpo_train(
271+
policy=policy,
272+
policy_generation=policy_generation,
273+
dataloader=dataloader,
274+
val_dataloader=val_dataloader,
275+
tokenizer=tokenizer,
276+
loss_fn=loss_fn,
277+
task_to_env=task_to_env,
278+
val_task_to_env=val_task_to_env,
279+
logger=logger,
280+
checkpointer=checkpointer,
281+
grpo_save_state=grpo_state,
282+
master_config=master_config,
283+
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
284+
)
234285
else:
286+
print("🚀 Running synchronous GRPO training")
287+
288+
# Run standard GRPO training
235289
grpo_train(
236290
policy,
237291
policy_generation,

nemo_rl/algorithms/async_utils.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -642,17 +642,40 @@ def _run_prompt_group_worker(
642642
prompt_idx: int,
643643
) -> None:
644644
try:
645+
# Import here to avoid circular dependency
646+
from nemo_rl.algorithms.grpo import _should_use_nemo_gym
647+
from nemo_rl.experience.rollouts import run_async_nemo_gym_rollout
648+
645649
# Run rollout for this prompt group
646650
# Async engine supports concurrent generation; avoid locking
647-
final_batch, rollout_metrics = run_async_multi_turn_rollout(
648-
policy_generation=self.policy_generation,
649-
input_batch=repeated_batch,
650-
tokenizer=self.tokenizer,
651-
task_to_env=self.task_to_env,
652-
max_seq_len=self.master_config["policy"]["max_total_sequence_length"],
653-
max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"],
654-
greedy=False,
655-
)
651+
# Check if we should use nemo_gym (similar to synchronous GRPO)
652+
if _should_use_nemo_gym(self.master_config):
653+
generation_config = self.master_config["policy"]["generation"]
654+
env_cfg = self.master_config.get("env") or {}
655+
nemo_gym_rollout_result = run_async_nemo_gym_rollout(
656+
policy_generation=self.policy_generation,
657+
input_batch=repeated_batch,
658+
tokenizer=self.tokenizer,
659+
task_to_env=self.task_to_env,
660+
max_seq_len=None,
661+
generation_config=generation_config,
662+
max_rollout_turns=None,
663+
greedy=False,
664+
)
665+
final_batch = nemo_gym_rollout_result.final_batch
666+
rollout_metrics = nemo_gym_rollout_result.rollout_metrics
667+
else:
668+
final_batch, rollout_metrics = run_async_multi_turn_rollout(
669+
policy_generation=self.policy_generation,
670+
input_batch=repeated_batch,
671+
tokenizer=self.tokenizer,
672+
task_to_env=self.task_to_env,
673+
max_seq_len=self.master_config["policy"][
674+
"max_total_sequence_length"
675+
],
676+
max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"],
677+
greedy=False,
678+
)
656679

657680
# Move to CPU and push to buffer (avoid blocking on GC/push)
658681
final_batch_cpu = final_batch.to("cpu")

nemo_rl/algorithms/grpo.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3059,11 +3059,24 @@ def async_grpo_train(
30593059
checkpointer.finalize_checkpoint(checkpoint_path)
30603060
policy.offload_after_refit()
30613061

3062-
log_data = {"content": flat_messages_content}
3062+
# Logging
3063+
# Log training data (match sync GRPO logging payload for parity)
3064+
log_data = {}
3065+
if "agent_ref" in repeated_batch:
3066+
log_data["agent_ref"] = repeated_batch["agent_ref"]
3067+
log_data["content"] = flat_messages_content
30633068
log_data["rewards"] = rewards.tolist()
3069+
if master_config["grpo"]["use_dynamic_sampling"]:
3070+
# In dynamic sampling, `rewards` corresponds to filtered rewards
3071+
log_data["filtered_rewards"] = rewards.tolist()
3072+
log_data["rewards"] = repeated_batch["total_reward"].tolist()
3073+
log_data["input_lengths"] = input_lengths.tolist()
3074+
log_data["token_ids"] = train_data["input_ids"].tolist()
3075+
log_data["token_loss_mask"] = train_data["token_mask"].tolist()
3076+
log_data["sample_loss_mask"] = train_data["sample_mask"].tolist()
3077+
log_data["advantages"] = train_data["advantages"].tolist()
30643078
log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist()
30653079
log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist()
3066-
log_data["input_lengths"] = input_lengths.tolist()
30673080
logger.log_batched_dict_as_jsonl(
30683081
log_data, f"train_data_step{step + 1}.jsonl"
30693082
)

nemo_rl/environments/nemo_gym.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,20 @@ def _postprocess_nemo_gym_to_nemo_rl_result(
232232
)
233233
output_item_dict.pop("generation_log_probs")
234234

235+
if not nemo_rl_message_log:
236+
input_messages = nemo_gym_result["responses_create_params"]["input"]
237+
prompt_token_ids = tokenizer.apply_chat_template(
238+
input_messages, tokenize=True
239+
)
240+
raise ValueError(
241+
f"NeMo Gym returned a result with no generation data. "
242+
f"This typically means the prompt for the first turn already exceeds the vLLM max_model_len, "
243+
f"so vLLM rejected the request before any tokens could be generated.\n"
244+
f" Prompt length: {len(prompt_token_ids)} tokens.\n"
245+
f" → Fix: increase `policy.max_total_sequence_length` and `policy.generation.vllm_cfg.max_model_len` "
246+
f"to a value larger than {len(prompt_token_ids)}."
247+
)
248+
235249
return {
236250
"message_log": nemo_rl_message_log,
237251
"input_message_log": nemo_rl_message_log[:1],

nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -369,20 +369,31 @@ async def _preprocess_chat(
369369
messages_for_replace_prefix_tokens = deepcopy(messages)
370370

371371
# res is conversation, [request_prompt], [engine_prompt]
372-
res = await super()._preprocess_chat(
373-
request,
374-
tokenizer,
375-
messages,
376-
chat_template,
377-
chat_template_content_format,
378-
add_generation_prompt,
379-
continue_final_message,
380-
tool_dicts,
381-
documents,
382-
chat_template_kwargs,
383-
tool_parser,
384-
add_special_tokens,
385-
)
372+
try:
373+
res = await super()._preprocess_chat(
374+
request,
375+
tokenizer,
376+
messages,
377+
chat_template,
378+
chat_template_content_format,
379+
add_generation_prompt,
380+
continue_final_message,
381+
tool_dicts,
382+
documents,
383+
chat_template_kwargs,
384+
tool_parser,
385+
add_special_tokens,
386+
)
387+
except ValueError as e:
388+
if "maximum context length" in str(e):
389+
import logging
390+
391+
# Print a clean one-liner warning that max model length has been exceeded
392+
# The exception is still raised, but later filtered out by the MaxContextLengthFilter
393+
logging.getLogger(__name__).warning(
394+
"Prompt exceeds max_model_len: %s", e
395+
)
396+
raise
386397

387398
if request.required_prefix_token_ids is None:
388399
return res
@@ -572,6 +583,24 @@ def filter(self, record: LogRecord) -> bool:
572583

573584
vllm_async_llm_logger.addFilter(CleanLoggingFilter())
574585

586+
from logging import getLogger as _getLogger
587+
588+
_getLogger("vllm.entrypoints.openai.protocol").addFilter(CleanLoggingFilter())
589+
590+
# Suppress the noisy vLLM traceback when a prompt exceeds max_model_len.
591+
# This is expected during multi-turn rollouts; we log a clean one-line
592+
# warning from _preprocess_chat instead.
593+
class MaxContextLengthFilter(LoggingFilter):
594+
def filter(self, record: LogRecord) -> bool:
595+
if record.exc_info and record.exc_info[1]:
596+
if "maximum context length" in str(record.exc_info[1]):
597+
return False
598+
return True
599+
600+
_getLogger("vllm.entrypoints.openai.serving_chat").addFilter(
601+
MaxContextLengthFilter()
602+
)
603+
575604
return app
576605

577606
def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":
@@ -602,6 +631,7 @@ def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":
602631
app,
603632
host="0.0.0.0",
604633
port=free_port,
634+
timeout_keep_alive=120, # Keep connections alive longer (default is 5s), fix for this error: Hit an exception while making a request (try 1): <class 'aiohttp.client_exceptions.ClientOSError'>: [Errno 104] Connection reset by peer
605635
)
606636
server = uvicorn.Server(config=config)
607637

tests/functional/L1_Functional_Tests_GPU.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh
4646
run_test uv run --no-sync bash ./tests/functional/eval.sh
4747
run_test uv run --no-sync bash ./tests/functional/eval_async.sh
4848
run_test fast uv run --no-sync bash ./tests/functional/grpo.sh
49+
run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym.sh
4950
run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh
5051
run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh
5152
run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh
5253
run_test uv run --no-sync bash ./tests/functional/grpo_megatron.sh
53-
run_test fast uv run --no-sync bash ./tests/functional/grpo_megatron_async.sh
5454
run_test uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
5555
run_test uv run --no-sync bash ./tests/functional/grpo_multiple_dataloaders.sh
5656
run_test uv run --no-sync bash ./tests/functional/grpo_multiturn.sh

tests/functional/grpo_async_gym.sh

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#!/bin/bash
2+
3+
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
4+
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
5+
# Mark the current repo as safe, since wandb fetches metadata about the repo
6+
git config --global --add safe.directory $PROJECT_ROOT
7+
8+
set -eou pipefail
9+
10+
EXP_NAME=$(basename $0 .sh)
11+
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
12+
LOG_DIR=$EXP_DIR/logs
13+
JSON_METRICS=$EXP_DIR/metrics.json
14+
RUN_LOG=$EXP_DIR/run.log
15+
CHECKPOINT_DIR=$EXP_DIR/checkpoints
16+
DATA_DIR=$EXP_DIR/data
17+
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
18+
19+
rm -rf $EXP_DIR $LOG_DIR
20+
mkdir -p $EXP_DIR $LOG_DIR $CHECKPOINT_DIR $DATA_DIR
21+
22+
cd $PROJECT_ROOT
23+
24+
# Follow nemo-gym instructions here to get this data:
25+
# https://docs.nvidia.com/nemo/gym/0.1.0/tutorials/nemo-rl-grpo/setup.html#training-nemo-rl-grpo-setup
26+
cd 3rdparty/Gym-workspace/Gym
27+
28+
# We need HF_TOKEN to download the data from huggingface
29+
if [[ ! -f env.yaml ]]; then
30+
if [[ -z "${HF_TOKEN:-}" ]]; then
31+
echo "[ERROR] HF_TOKEN is not set"
32+
exit 1
33+
fi
34+
echo "hf_token: $HF_TOKEN" >> env.yaml
35+
fi
36+
37+
config_paths="responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,\
38+
resources_servers/workplace_assistant/configs/workplace_assistant.yaml"
39+
40+
uv run ng_prepare_data "+config_paths=[${config_paths}]" \
41+
+output_dirpath=data/workplace_assistant \
42+
+mode=train_preparation \
43+
+should_download=true \
44+
+data_source=huggingface
45+
cd -
46+
47+
# This trimming of the workplace assistant dataset is necessary b/c with all the tools the first prompt is >4000 tokens
48+
# which will cause vllm to return nothing on the first prompt and crash RL. Since we want to keep this test short to
49+
# smoke test, we trim all but the first tool
50+
TRAIN_PATH=$DATA_DIR/workplace_assistant_train.jsonl
51+
VALIDATION_PATH=$DATA_DIR/workplace_assistant_validation.jsonl
52+
jq -c '.responses_create_params.tools |= (.[0:1])' 3rdparty/Gym-workspace/Gym/data/workplace_assistant/train.jsonl > $TRAIN_PATH
53+
jq -c '.responses_create_params.tools |= (.[0:1])' 3rdparty/Gym-workspace/Gym/data/workplace_assistant/validation.jsonl > $VALIDATION_PATH
54+
55+
uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
56+
$PROJECT_ROOT/examples/nemo_gym/run_grpo_nemo_gym.py \
57+
--config $PROJECT_ROOT/examples/nemo_gym/grpo_qwen3_30ba3b_instruct.yaml \
58+
policy.model_name=Qwen/Qwen3-0.6B \
59+
policy.dtensor_cfg.enabled=false \
60+
policy.megatron_cfg.enabled=true \
61+
policy.megatron_cfg.tensor_model_parallel_size=1 \
62+
policy.megatron_cfg.pipeline_model_parallel_size=1 \
63+
policy.megatron_cfg.expert_model_parallel_size=1 \
64+
policy.megatron_cfg.context_parallel_size=1 \
65+
policy.megatron_cfg.sequence_parallel=false \
66+
policy.generation.vllm_cfg.tensor_parallel_size=1 \
67+
policy.generation.vllm_cfg.async_engine=true \
68+
policy.max_total_sequence_length=512 \
69+
policy.generation.colocated.enabled=false \
70+
policy.generation.colocated.resources.num_nodes=1 \
71+
policy.generation.colocated.resources.gpus_per_node=1 \
72+
grpo.num_prompts_per_step=4 \
73+
grpo.num_generations_per_prompt=2 \
74+
grpo.max_num_steps=10 \
75+
grpo.async_grpo.enabled=true \
76+
grpo.async_grpo.max_trajectory_age_steps=1 \
77+
grpo.async_grpo.in_flight_weight_updates=true \
78+
policy.train_global_batch_size=4 \
79+
policy.train_micro_batch_size=1 \
80+
cluster.gpus_per_node=2 \
81+
loss_fn.use_importance_sampling_correction=true \
82+
logger.tensorboard_enabled=true \
83+
logger.log_dir=$LOG_DIR \
84+
logger.wandb_enabled=false \
85+
logger.monitor_gpus=true \
86+
checkpointing.enabled=false \
87+
checkpointing.checkpoint_dir=$CHECKPOINT_DIR \
88+
data.train.data_path=$TRAIN_PATH \
89+
data.validation.data_path=$VALIDATION_PATH \
90+
$@ \
91+
2>&1 | tee $RUN_LOG
92+
93+
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
94+
95+
# Observed to be between 0.8-1.3
96+
uv run tests/check_metrics.py $JSON_METRICS \
97+
'median(data["train/gen_kl_error"]) < 1.3'

0 commit comments

Comments
 (0)