
Commit deaece6

feat: Add support for multi-turn generations and RL (tools, games, etc) (#218)
Signed-off-by: Sahil Jain <[email protected]>
Parent: 1245c50

File tree: 21 files changed (+2132 additions, −263 deletions)

.gitignore (2 additions, 1 deletion)

@@ -17,7 +17,8 @@ dist/
 *.vscode/

 # Test
-.coverage
+coverage.json
+.coverage*
 test_assets/

 # Cache

examples/configs/grpo_math_1B.yaml (1 addition, 0 deletions)

@@ -2,6 +2,7 @@
 grpo:
   num_prompts_per_step: 32
   num_generations_per_prompt: 16
+  max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question)
   max_num_steps: 1000000
   normalize_rewards: true
   use_leave_one_out_baseline: true
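
The new max_rollout_turns knob caps how many assistant/environment exchanges a single rollout may take: with 1, as in this math config, a rollout is one generation plus one environment scoring step, while larger values let the environment's reply (tool output, game state, etc.) feed back into further generation turns. A minimal sketch of reading the knob, assuming a plain PyYAML load rather than the repo's actual config loader:

import yaml

cfg_text = """
grpo:
  num_generations_per_prompt: 16
  max_rollout_turns: 1
"""
master_config = yaml.safe_load(cfg_text)

# 1 turn: generate once, score once (math QA). Values > 1 bound the
# multi-turn loop used for tools and games (see the grpo.py diff below).
max_turns = master_config["grpo"]["max_rollout_turns"]
assert max_turns == 1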

nemo_reinforcer/algorithms/grpo.py (34 additions, 160 deletions)

@@ -24,7 +24,10 @@
 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
 from nemo_reinforcer.algorithms.utils import calculate_baseline_and_std_per_prompt

-from nemo_reinforcer.environments.interfaces import EnvironmentInterface
+from nemo_reinforcer.environments.interfaces import (
+    EnvironmentInterface,
+    EnvironmentReturn,
+)
 from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
 from nemo_reinforcer.data.interfaces import (
     DatumSpec,

@@ -59,6 +62,7 @@
 from nemo_reinforcer.utils.logger import Logger, LoggerConfig
 from nemo_reinforcer.utils.timer import Timer
 from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig
+from nemo_reinforcer.experience.rollouts import run_multi_turn_rollout


 # ===============================================================================

@@ -73,6 +77,7 @@ class GRPOConfig(TypedDict):
     normalize_rewards: bool
     use_leave_one_out_baseline: bool
     val_period: int
+    val_batch_size: int
     val_at_start: bool
     checkpoint_dir: str

@@ -94,7 +99,7 @@ def _default_grpo_save_state() -> GRPOSaveState:
 class MasterConfig(TypedDict):
     policy: PolicyConfig
     loss_fn: ClippedPGLossConfig
-    math_env: MathEnvConfig
+    env_configs: Dict[str, Any]
     data: DataConfig
     grpo: GRPOConfig
     logger: LoggerConfig

@@ -283,120 +288,6 @@ def refit_policy_generation(
     policy.offload_after_refit()


-def generate_responses(
-    policy_generation: GenerationInterface,
-    generation_input_data: BatchedDataDict[GenerationDatumSpec],
-    batch: BatchedDataDict[DatumSpec],
-    tokenizer,
-    input_lengths: torch.Tensor,
-    include_logprobs: bool = True,
-) -> Tuple[BatchedDataDict[DatumSpec], List[List[int]], Dict[str, float | int]]:
-    """Generate responses from policy."""
-    # Generate responses
-    generation_outputs = policy_generation.generate(generation_input_data)
-
-    # Extract generated tokens
-    generated_ids = []
-    unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"]
-    for output_ids, input_length, total_length in zip(
-        generation_outputs["output_ids"], input_lengths, unpadded_sequence_lengths
-    ):
-        generated_ids.append(output_ids[input_length:total_length])
-
-    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
-    # Append to message log
-    for i, (text, input_length, total_length) in enumerate(
-        zip(generated_texts, input_lengths, unpadded_sequence_lengths)
-    ):
-        message = {
-            "role": "assistant",
-            "content": text,
-            "token_ids": generation_outputs["output_ids"][i, input_length:total_length],
-        }
-
-        if include_logprobs and "logprobs" in generation_outputs:
-            message["generation_logprobs"] = generation_outputs["logprobs"][
-                i, input_length:total_length
-            ]
-
-        batch["message_log"][i].append(message)
-
-    metrics = {
-        "mean_generation_length": (
-            torch.sum(unpadded_sequence_lengths) - torch.sum(input_lengths)
-        ).item()
-        / len(unpadded_sequence_lengths),
-        "max_seqlen": torch.max(unpadded_sequence_lengths).item(),
-    }
-
-    return batch, generated_ids, metrics
-
-
-def calculate_rewards(
-    batch: BatchedDataDict[DatumSpec],
-    task_to_env: Dict[str, EnvironmentInterface],
-) -> Tuple[torch.Tensor, List[LLMMessageLogType]]:
-    """Calculate rewards for generated responses.
-
-    Args:
-        batch: Batch containing message_log (LLMMessageLogType) with generated responses
-        task_to_env: Dictionary mapping task names to their corresponding environments
-
-    Returns:
-        rewards: Tensor of rewards
-        to_env: Simplified message logs sent to environment (LLMMessageLogType format)
-    """
-    # Extract message logs for environment
-    to_env = [
-        get_keys_from_message_log(batch["message_log"][i], ["role", "content"])
-        for i in range(len(batch["message_log"]))
-    ]
-    task_names = [batch["task_name"][i] for i in range(len(batch["task_name"]))]
-
-    # Group messages by task type
-    task_groups = {}
-    for i, task_name in enumerate(task_names):
-        if task_name not in task_groups:
-            task_groups[task_name] = []
-        task_groups[task_name].append((i, to_env[i]))
-
-    # Calculate rewards for each task group concurrently
-    futures = []
-    future_to_indices = {}  # Map future to its corresponding indices
-    for task_name, group in task_groups.items():
-        if task_name not in task_to_env:
-            raise ValueError(f"No environment found for task type: {task_name}")
-
-        # Extract indices and messages for this group
-        indices = [idx for idx, _ in group]
-        messages = [msg for _, msg in group]
-
-        # Get corresponding environment info
-        env_info = [batch["extra_env_info"][i] for i in indices]
-
-        # Submit task to environment and store future
-        future = task_to_env[task_name].step.remote(messages, env_info)
-        futures.append(future)
-        future_to_indices[future] = indices
-
-    results = ray.get(futures)
-    all_rewards = []
-    for future, result in zip(futures, results):
-        indices = future_to_indices[future]
-        _, _, task_rewards, _ = result
-
-        # Store results with their original indices
-        for idx, reward in zip(indices, task_rewards):
-            all_rewards.append((idx, reward))
-
-    # Sort results by original index to maintain order
-    all_rewards.sort(key=lambda x: x[0])
-    rewards = torch.tensor([reward for _, reward in all_rewards])
-
-    return rewards, to_env
-
-
 # ===============================================================================
 # Training & Validation
 # ===============================================================================

@@ -463,7 +354,7 @@ def grpo_train(
         print("▶ Preparing batch...")
         with timer.time("data_processing"):
             # Repeat batch items
-            repeated_batch = batch.repeat_interleave(
+            repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave(
                 master_config["grpo"]["num_generations_per_prompt"]
             )
             # Convert LLMMessageLogType to FlatMessagesType for generation

@@ -472,36 +363,33 @@
                 pad_value_dict={"token_ids": tokenizer.pad_token_id},
             )
             input_ids = batched_flat["token_ids"]
-            # Create generation-specific input structure
-            generation_input_data = BatchedDataDict[GenerationDatumSpec](
-                {
-                    "input_ids": input_ids,
-                    "input_lengths": input_lengths,
-                }
-            )

         # Generate responses - this updates the LLMMessageLogType in repeated_batch
-        print(f"▶ Generating responses for batch of size {len(input_ids)}...")
+        print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
         with timer.time("prepare_for_generation"):
             if NEED_REFIT and POLICY_GENERATION_STALE:
                 refit_policy_generation(policy, policy_generation)
                 POLICY_GENERATION_STALE = False
             else:
                 policy_generation.prepare_for_generation()
+
         with timer.time("generation"):
-            repeated_batch, _, gen_metrics = generate_responses(
-                policy_generation,
-                generation_input_data,
-                repeated_batch,
-                tokenizer,
-                input_lengths,
+            repeated_batch, rollout_metrics = run_multi_turn_rollout(
+                policy_generation=policy_generation,
+                input_batch=repeated_batch,
+                tokenizer=tokenizer,
+                task_to_env=task_to_env,
+                max_seq_len=master_config["policy"]["max_total_sequence_length"],
+                max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
+                greedy=False,
             )
             policy_generation.finish_generation()

-        # Calculate rewards & advantages based on the updated LLMMessageLogType
-        print("▶ Calculating rewards...")
+        # Calculate rewards & advantages
+        print("▶ Processing rewards...")
         with timer.time("reward_calculation"):
-            rewards, _ = calculate_rewards(repeated_batch, task_to_env)
+            # Extract rewards from final_batch
+            rewards = repeated_batch["total_reward"]

         print("▶ Computing advantages...")
         baseline, std = calculate_baseline_and_std_per_prompt(

@@ -665,14 +553,14 @@ def grpo_train(
                 metrics[k] = np.sum(v).item()
             else:
                 metrics[k] = np.mean(v).item()
-        metrics.update(gen_metrics)
+        metrics.update(rollout_metrics)

         timing_metrics = timer.get_timing_metrics(reduction_op="sum")

         print(f" • Loss: {metrics['loss']:.4f}")
         print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
         print(
-            f" • Mean Generation Length: {gen_metrics['mean_generation_length']:.4f}"
+            f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}"
         )

         print("\n⏱️ Timing:")

@@ -726,39 +614,25 @@ def validate(
         if batch_idx >= max_batches:
             break

-        # Convert LLMMessageLogType to FlatMessagesType for generation
-        batched_flat, input_lengths = batched_message_log_to_flat_message(
-            val_batch["message_log"],
-            pad_value_dict={"token_ids": tokenizer.pad_token_id},
-        )
-        # Extract input IDs
-        input_ids = batched_flat["token_ids"]
-        # Create generation-specific input structure
-        generation_input_data = BatchedDataDict(
-            {
-                "input_ids": input_ids,
-                "input_lengths": input_lengths,
-            }
-        )
-
         # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs)
-        val_batch, generated_ids, gen_metrics = generate_responses(
+        val_batch, gen_metrics = run_multi_turn_rollout(
             policy_generation,
-            generation_input_data,
             val_batch,
             tokenizer,
-            input_lengths,
-            include_logprobs=False,
+            val_task_to_env,
+            max_seq_len=master_config["policy"]["max_total_sequence_length"],
+            max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
+            greedy=False,
         )
-
-        # Calculate rewards based on the updated LLMMessageLogType
-        with timer.time("reward_calculation"):
-            rewards, to_env = calculate_rewards(val_batch, val_task_to_env)
+        rewards = val_batch["total_reward"]

         total_rewards.extend(rewards.tolist())
-        total_lengths.extend([len(ids) for ids in generated_ids])
+        total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])

         # Collect message logs for later display
+        to_env = get_keys_from_message_log(
+            val_batch["message_log"], ["role", "content"]
+        )
         all_message_logs.extend(to_env)

     # Calculate validation metrics
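
The body of run_multi_turn_rollout (added under nemo_reinforcer/experience/rollouts.py) is not among the hunks shown above, so the sketch below is inferred purely from its call sites: it consumes the repeated batch, a tokenizer, a task-to-environment map, a context budget, and a turn limit, and hands back the batch with an accumulated total_reward plus rollout metrics such as mean_gen_tokens_per_sample. The names sketch_multi_turn_rollout, generate_turn, and env_step are hypothetical stand-ins, not the repo's API.

from typing import Any, Callable, Dict, List, Tuple

import torch

Message = Dict[str, Any]        # e.g. {"role": ..., "content": ..., "token_ids": ...}
MessageLog = List[Message]


def sketch_multi_turn_rollout(
    message_logs: List[MessageLog],
    generate_turn: Callable[[MessageLog], Message],
    env_step: Callable[[MessageLog], Tuple[Message, float, bool]],
    max_rollout_turns: int,
    max_seq_len: int,
) -> Tuple[List[MessageLog], torch.Tensor, Dict[str, float]]:
    """Run up to max_rollout_turns assistant/environment exchanges per sample."""
    total_reward = torch.zeros(len(message_logs))
    gen_tokens = torch.zeros(len(message_logs))
    done = [False] * len(message_logs)

    for _turn in range(max_rollout_turns):
        for i, log in enumerate(message_logs):
            if done[i]:
                continue
            # 1) Sample the next assistant message from the policy.
            assistant_msg = generate_turn(log)
            log.append(assistant_msg)
            gen_tokens[i] += len(assistant_msg.get("token_ids", []))

            # 2) Let the task environment respond (tool output, game state, reward).
            env_msg, reward, terminated = env_step(log)
            total_reward[i] += reward

            # 3) Stop on termination or when the context budget is exhausted.
            seq_len = sum(len(m.get("token_ids", [])) for m in log)
            if terminated or seq_len >= max_seq_len:
                done[i] = True
            else:
                log.append(env_msg)
        if all(done):
            break

    metrics = {"mean_gen_tokens_per_sample": gen_tokens.mean().item()}
    return message_logs, total_reward, metrics

With max_rollout_turns=1 the loop collapses to the old single-turn flow that generate_responses and calculate_rewards implemented: one assistant message, then one environment scoring step.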

nemo_reinforcer/algorithms/loss_functions.py (3 additions, 0 deletions)

@@ -98,6 +98,9 @@ def __call__(

         lp_error = torch.abs(generation_logprobs - prev_logprobs)  # noqa: F841 (precommit ignore for now)
         mult_prob_error = masked_mean(torch.exp(lp_error), mask).item()
+        if mult_prob_error == 0.0:
+            # this sometimes gets 0 (everything masked/invalid). Doing this to avoid screwing up stats too much
+            mult_prob_error = 1.0

         next_token_logits = next_token_logits.to(torch.float32)
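
The guard exists because mult_prob_error is a multiplicative probability-ratio statistic whose neutral value is 1.0; a value of 0.0 only appears in the degenerate case where the mask selects nothing. A minimal illustration, assuming a conventional masked_mean of the form sum(x * mask) / max(sum(mask), eps), since the repo's own helper is not shown in this diff:

import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Assumed semantics for illustration only.
    return (x * mask).sum() / mask.sum().clamp(min=eps)


lp_error = torch.tensor([0.1, 0.2, 0.3])
mask = torch.zeros(3)  # e.g. the whole sample was masked out as invalid

mult_prob_error = masked_mean(torch.exp(lp_error), mask).item()
assert mult_prob_error == 0.0  # degenerate value, not a real probability ratio

if mult_prob_error == 0.0:
    # Fall back to 1.0, the neutral multiplicative error, so aggregated stats
    # are not dragged toward zero by fully masked samples.
    mult_prob_error = 1.0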

nemo_reinforcer/data/datasets.py (4 additions, 0 deletions)

@@ -124,6 +124,9 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict:
     idx = [datum_spec["idx"] for datum_spec in data_batch]
     batch_max_length = torch.ones_like(length) * length.max()

+    # Extract stop_strings if present
+    stop_strings = [datum.get("stop_strings", None) for datum in data_batch]
+
     output = BatchedDataDict(
         message_log=message_log,
         length=length,

@@ -132,6 +135,7 @@
         task_name=task_names,
         idx=idx,
         batch_max_length=batch_max_length,
+        stop_strings=stop_strings,
     )
     return output

nemo_reinforcer/data/interfaces.py (1 addition, 0 deletions)

@@ -32,6 +32,7 @@ class DatumSpec(TypedDict):
     loss_multiplier: float  # multiplier for the loss for this datum. 0 to mask out (say the sample is invalid)
     idx: int
     task_name: Optional[str] = "default"
+    stop_strings: Optional[List[str]] = None  # Optional stop strings for generation
     __extra__: Any  # This allows additional fields of any type
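
Together with the rl_collate_fn change above, each datum can now opt into its own stop strings, and the collated batch carries a parallel stop_strings list for the generation backend to use. A small sketch of that flow; the task names and the "</tool>" stop string are made-up examples, and the dicts below carry only the fields relevant here rather than a full DatumSpec:

data_batch = [
    {"idx": 0, "task_name": "math", "stop_strings": None},             # plain QA: no extra stops
    {"idx": 1, "task_name": "tool_use", "stop_strings": ["</tool>"]},  # stop when a tool call closes
]

# Same pattern as the added collate line: missing keys default to None,
# so existing datasets that never set stop_strings keep working unchanged.
stop_strings = [datum.get("stop_strings", None) for datum in data_batch]
assert stop_strings == [None, ["</tool>"]]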

nemo_reinforcer/data/llm_message_utils.py (5 additions, 2 deletions)

@@ -289,8 +289,11 @@ def batched_message_log_to_flat_message(
     # Create input_lengths tensor
     input_lengths = []
     for seq in sequenced_lists:
-        seq_len = next(
-            (v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), 0
+        # Find the maximum length among all tensors in the dictionary, default to 0 if none exist
+        # Use maximum here since there may be keys that aren't populated for all messages yet.
+        # For example, logprobs don't get populated for non-generated tokens until post-processing.
+        seq_len = max(
+            (v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), default=0
         )
         input_lengths.append(seq_len)
     input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32)
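
The switch from next(...) to max(..., default=0) matters once a per-message dict holds tensors of different lengths: next() returns the size of whichever tensor happens to come first in insertion order, while max() always reports the longest one, which is the true sequence length. A small sketch with made-up sizes:

import torch

seq = {
    "generation_logprobs": torch.zeros(5),  # only covers generated tokens so far
    "token_ids": torch.zeros(12),           # full prompt + generation
}

old_len = next((v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), 0)
new_len = max((v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), default=0)

assert old_len == 5   # depends on dict order: here it under-reports the length
assert new_len == 12  # always the longest tensor, i.e. the true sequence length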
