Skip to content

Commit 4ea3217

Browse files
HeyyyyyyG authored and terrykong committed
FROM 72e87b1bcff34bc1c2a67f352202d3a2cdbb3b84 async+gym (but missing genrm change)
Signed-off-by: Jiaqi Zeng <jiaqiz@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
1 parent 02febf1 commit 4ea3217

File tree

2 files changed

+91
-9
lines changed

2 files changed

+91
-9
lines changed

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,55 @@ def main() -> None:
231231
logger=logger,
232232
master_config=master_config,
233233
)
234+
# Check if async mode is enabled
235+
elif "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]:
236+
# Async GRPO does not support dynamic sampling, reward scaling, or reward shaping (DAPO features)
237+
unsupported_features = [
238+
"use_dynamic_sampling",
239+
"reward_scaling",
240+
"reward_shaping",
241+
]
242+
243+
for feature in unsupported_features:
244+
if feature not in config["grpo"]:
245+
continue
246+
247+
if feature == "use_dynamic_sampling":
248+
if config["grpo"][feature]:
249+
raise NotImplementedError(
250+
f"{feature} is not supported with async GRPO"
251+
)
252+
else:
253+
if config["grpo"][feature]["enabled"]:
254+
raise NotImplementedError(
255+
f"{feature} is not supported with async GRPO"
256+
)
257+
258+
from nemo_rl.algorithms.grpo import async_grpo_train
259+
260+
print("🚀 Running async GRPO training")
261+
262+
async_config = config["grpo"]["async_grpo"]
263+
# Run async GRPO training
264+
async_grpo_train(
265+
policy=policy,
266+
policy_generation=policy_generation,
267+
dataloader=dataloader,
268+
val_dataloader=val_dataloader,
269+
tokenizer=tokenizer,
270+
loss_fn=loss_fn,
271+
task_to_env=task_to_env,
272+
val_task_to_env=val_task_to_env,
273+
logger=logger,
274+
checkpointer=checkpointer,
275+
grpo_save_state=grpo_state,
276+
master_config=master_config,
277+
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
278+
)
234279
else:
280+
print("🚀 Running synchronous GRPO training")
281+
282+
# Run standard GRPO training
235283
grpo_train(
236284
policy,
237285
policy_generation,

nemo_rl/algorithms/async_utils.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -642,17 +642,51 @@ def _run_prompt_group_worker(
642642
prompt_idx: int,
643643
) -> None:
644644
try:
645+
# Import here to avoid circular dependency
646+
from nemo_rl.algorithms.grpo import _should_use_nemo_gym
647+
from nemo_rl.experience.rollouts import run_async_nemo_gym_rollout
648+
645649
# Run rollout for this prompt group
646650
# Async engine supports concurrent generation; avoid locking
647-
final_batch, rollout_metrics = run_async_multi_turn_rollout(
648-
policy_generation=self.policy_generation,
649-
input_batch=repeated_batch,
650-
tokenizer=self.tokenizer,
651-
task_to_env=self.task_to_env,
652-
max_seq_len=self.master_config["policy"]["max_total_sequence_length"],
653-
max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"],
654-
greedy=False,
655-
)
651+
# Check if we should use nemo_gym (similar to synchronous GRPO)
652+
if _should_use_nemo_gym(self.master_config):
653+
generation_config = self.master_config["policy"]["generation"]
654+
env_cfg = self.master_config.get("env") or {}
655+
nemo_gym_rollout_result = run_async_nemo_gym_rollout(
656+
policy_generation=self.policy_generation,
657+
input_batch=repeated_batch,
658+
tokenizer=self.tokenizer,
659+
task_to_env=self.task_to_env,
660+
max_seq_len=None,
661+
generation_config=generation_config,
662+
max_rollout_turns=None,
663+
greedy=False,
664+
# GenRM compare config
665+
use_genrm_compare=env_cfg.get("use_genrm_compare", False),
666+
num_generations_per_prompt=self.master_config["grpo"][
667+
"num_generations_per_prompt"
668+
],
669+
genrm_compare_server_name=env_cfg.get(
670+
"genrm_compare_server_name", "genrm_compare"
671+
),
672+
genrm_agent_names=env_cfg.get(
673+
"genrm_agent_names", ["genrm_simple_agent"]
674+
),
675+
)
676+
final_batch = nemo_gym_rollout_result.final_batch
677+
rollout_metrics = nemo_gym_rollout_result.rollout_metrics
678+
else:
679+
final_batch, rollout_metrics = run_async_multi_turn_rollout(
680+
policy_generation=self.policy_generation,
681+
input_batch=repeated_batch,
682+
tokenizer=self.tokenizer,
683+
task_to_env=self.task_to_env,
684+
max_seq_len=self.master_config["policy"][
685+
"max_total_sequence_length"
686+
],
687+
max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"],
688+
greedy=False,
689+
)
656690

657691
# Move to CPU and push to buffer (avoid blocking on GC/push)
658692
final_batch_cpu = final_batch.to("cpu")

0 commit comments

Comments (0)