Skip to content

Commit 539893e

Browse files
committed
revert async change and enable fast test
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent 665219d commit 539893e

File tree

2 files changed

+42
-66
lines changed

2 files changed

+42
-66
lines changed

nemo_rl/experience/rollouts.py

Lines changed: 41 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -934,35 +934,25 @@ async def run_single_sample_with_error_handling(i, sample_state):
934934

935935
# Reconstruct batch from sample results
936936
batch_size = len(final_sample_states)
937-
final_batch_dict = {
938-
"message_log": [state["message_log"] for state in final_sample_states],
939-
"extra_env_info": [
940-
state["extra_env_info"] for state in final_sample_states
941-
],
942-
"task_name": [state["task_name"] for state in final_sample_states],
943-
"total_reward": torch.stack(
944-
[state["total_reward"] for state in final_sample_states]
945-
),
946-
"idx": [
947-
state.get("idx", i) for i, state in enumerate(final_sample_states)
948-
],
949-
"truncated": torch.tensor(
950-
[metrics["truncated"] for metrics in all_sample_metrics],
951-
dtype=torch.bool,
952-
),
953-
}
954-
955-
# Add any reward component keys (reward1, reward2, ...) from the first state
956-
reward_keys = [
957-
k for k in final_sample_states[0]
958-
if k.startswith("reward") and k[6:].isdigit()
959-
]
960-
reward_keys = sorted(reward_keys, key=lambda k: int(k[6:]))
961-
for key in reward_keys:
962-
final_batch_dict[key] = torch.stack(
963-
[state[key] for state in final_sample_states]
964-
)
965-
final_batch = BatchedDataDict[DatumSpec](final_batch_dict)
937+
final_batch = BatchedDataDict[DatumSpec](
938+
{
939+
"message_log": [state["message_log"] for state in final_sample_states],
940+
"extra_env_info": [
941+
state["extra_env_info"] for state in final_sample_states
942+
],
943+
"task_name": [state["task_name"] for state in final_sample_states],
944+
"total_reward": torch.stack(
945+
[state["total_reward"] for state in final_sample_states]
946+
),
947+
"idx": [
948+
state.get("idx", i) for i, state in enumerate(final_sample_states)
949+
],
950+
"truncated": torch.tensor(
951+
[metrics["truncated"] for metrics in all_sample_metrics],
952+
dtype=torch.bool,
953+
),
954+
}
955+
)
966956

967957
# Preserve additional fields from the original input_batch
968958
for key in input_batch.keys():
@@ -1237,42 +1227,28 @@ def run_async_nemo_gym_rollout(
12371227
)
12381228
input_ids = batched_flat["token_ids"]
12391229

1240-
final_batch_dict = {
1241-
"agent_ref": [r["agent_ref"] for r in results],
1242-
"message_log": [r["message_log"] for r in results],
1243-
# length is used downstream for mean_prompt_length
1244-
"length": torch.tensor(
1245-
[len(r["input_message_log"][0]["token_ids"]) for r in results]
1246-
),
1247-
"loss_multiplier": input_batch["loss_multiplier"],
1248-
# Unnecessary parts of the DatumSpec unused by the GRPO algorithm
1249-
# extra_env_info: dict[str, Any]
1250-
# idx: int
1251-
# task_name: NotRequired[str]
1252-
# stop_strings: NotRequired[list[str]] # Optional stop strings for generation
1253-
# Extra information not in the DatumSpec used by the GRPO algorithm
1254-
"total_reward": torch.tensor([r["full_result"]["reward"] for r in results]),
1255-
# Add truncated field to match other rollout paths (reusing hit_max_tokens logic)
1256-
"truncated": torch.tensor(
1257-
[m["hit_max_tokens"] for m in all_sample_metrics], dtype=torch.bool
1258-
),
1259-
}
1260-
1261-
# Add any reward component keys (reward1, reward2, ...) from full_result
1262-
if results:
1263-
full_result = results[0].get("full_result", {})
1264-
reward_keys = sorted(
1265-
[
1266-
k for k in full_result
1267-
if isinstance(k, str) and k.startswith("reward") and k[6:].isdigit()
1268-
],
1269-
key=lambda k: int(k[6:]),
1270-
)
1271-
for key in reward_keys:
1272-
final_batch_dict[key] = torch.tensor(
1273-
[r["full_result"][key] for r in results]
1274-
)
1275-
final_batch = BatchedDataDict[DatumSpec](final_batch_dict)
1230+
final_batch = BatchedDataDict[DatumSpec](
1231+
{
1232+
"agent_ref": [r["agent_ref"] for r in results],
1233+
"message_log": [r["message_log"] for r in results],
1234+
# length is used downstream for mean_prompt_length
1235+
"length": torch.tensor(
1236+
[len(r["input_message_log"][0]["token_ids"]) for r in results]
1237+
),
1238+
"loss_multiplier": input_batch["loss_multiplier"],
1239+
# Unnecessary parts of the DatumSpec unused by the GRPO algorithm
1240+
# extra_env_info: dict[str, Any]
1241+
# idx: int
1242+
# task_name: NotRequired[str]
1243+
# stop_strings: NotRequired[list[str]] # Optional stop strings for generation
1244+
# Extra information not in the DatumSpec used by the GRPO algorithm
1245+
"total_reward": torch.tensor([r["full_result"]["reward"] for r in results]),
1246+
# Add truncated field to match other rollout paths (reusing hit_max_tokens logic)
1247+
"truncated": torch.tensor(
1248+
[m["hit_max_tokens"] for m in all_sample_metrics], dtype=torch.bool
1249+
),
1250+
}
1251+
)
12761252

12771253
return AsyncNemoGymRolloutResult(
12781254
input_ids=input_ids,

tests/functional/L1_Functional_Tests_GPU.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ run_test uv run --no-sync bash ./tests/functional/dpo_automodel_lora.sh
4545
run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh
4646
run_test uv run --no-sync bash ./tests/functional/eval.sh
4747
run_test uv run --no-sync bash ./tests/functional/eval_async.sh
48-
run_test uv run --no-sync bash ./tests/functional/gdpo.sh
48+
run_test fast uv run --no-sync bash ./tests/functional/gdpo.sh
4949
run_test fast uv run --no-sync bash ./tests/functional/grpo.sh
5050
run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym.sh
5151
run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh

0 commit comments

Comments (0)