 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
 from nemo_reinforcer.algorithms.utils import calculate_baseline_and_std_per_prompt

-from nemo_reinforcer.environments.interfaces import EnvironmentInterface
+from nemo_reinforcer.environments.interfaces import (
+    EnvironmentInterface,
+    EnvironmentReturn,
+)
 from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
 from nemo_reinforcer.data.interfaces import (
     DatumSpec,
@@ -59,6 +62,7 @@
 from nemo_reinforcer.utils.logger import Logger, LoggerConfig
 from nemo_reinforcer.utils.timer import Timer
 from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig
+from nemo_reinforcer.experience.rollouts import run_multi_turn_rollout


 # ===============================================================================
@@ -73,6 +77,7 @@ class GRPOConfig(TypedDict):
     normalize_rewards: bool
     use_leave_one_out_baseline: bool
     val_period: int
+    val_batch_size: int
     val_at_start: bool
     checkpoint_dir: str

@@ -94,7 +99,7 @@ def _default_grpo_save_state() -> GRPOSaveState:
 class MasterConfig(TypedDict):
     policy: PolicyConfig
     loss_fn: ClippedPGLossConfig
-    math_env: MathEnvConfig
+    env_configs: Dict[str, Any]
     data: DataConfig
     grpo: GRPOConfig
     logger: LoggerConfig
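
With this change the master config carries a dictionary of environment configurations keyed by task name instead of the single `math_env` entry. A minimal sketch of what such a mapping could look like (the `math` key and its fields are hypothetical, not taken from this diff):

```python
# Hypothetical example of an env_configs mapping; keys and fields are illustrative.
env_configs = {
    "math": {            # task_name used to route rollouts to an environment
        "num_workers": 8,
    },
}
```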
@@ -283,120 +288,6 @@ def refit_policy_generation(
     policy.offload_after_refit()


-def generate_responses(
-    policy_generation: GenerationInterface,
-    generation_input_data: BatchedDataDict[GenerationDatumSpec],
-    batch: BatchedDataDict[DatumSpec],
-    tokenizer,
-    input_lengths: torch.Tensor,
-    include_logprobs: bool = True,
-) -> Tuple[BatchedDataDict[DatumSpec], List[List[int]], Dict[str, float | int]]:
-    """Generate responses from policy."""
-    # Generate responses
-    generation_outputs = policy_generation.generate(generation_input_data)
-
-    # Extract generated tokens
-    generated_ids = []
-    unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"]
-    for output_ids, input_length, total_length in zip(
-        generation_outputs["output_ids"], input_lengths, unpadded_sequence_lengths
-    ):
-        generated_ids.append(output_ids[input_length:total_length])
-
-    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
-    # Append to message log
-    for i, (text, input_length, total_length) in enumerate(
-        zip(generated_texts, input_lengths, unpadded_sequence_lengths)
-    ):
-        message = {
-            "role": "assistant",
-            "content": text,
-            "token_ids": generation_outputs["output_ids"][i, input_length:total_length],
-        }
-
-        if include_logprobs and "logprobs" in generation_outputs:
-            message["generation_logprobs"] = generation_outputs["logprobs"][
-                i, input_length:total_length
-            ]
-
-        batch["message_log"][i].append(message)
-
-    metrics = {
-        "mean_generation_length": (
-            torch.sum(unpadded_sequence_lengths) - torch.sum(input_lengths)
-        ).item()
-        / len(unpadded_sequence_lengths),
-        "max_seqlen": torch.max(unpadded_sequence_lengths).item(),
-    }
-
-    return batch, generated_ids, metrics
-
-
-def calculate_rewards(
-    batch: BatchedDataDict[DatumSpec],
-    task_to_env: Dict[str, EnvironmentInterface],
-) -> Tuple[torch.Tensor, List[LLMMessageLogType]]:
-    """Calculate rewards for generated responses.
-
-    Args:
-        batch: Batch containing message_log (LLMMessageLogType) with generated responses
-        task_to_env: Dictionary mapping task names to their corresponding environments
-
-    Returns:
-        rewards: Tensor of rewards
-        to_env: Simplified message logs sent to environment (LLMMessageLogType format)
-    """
-    # Extract message logs for environment
-    to_env = [
-        get_keys_from_message_log(batch["message_log"][i], ["role", "content"])
-        for i in range(len(batch["message_log"]))
-    ]
-    task_names = [batch["task_name"][i] for i in range(len(batch["task_name"]))]
-
-    # Group messages by task type
-    task_groups = {}
-    for i, task_name in enumerate(task_names):
-        if task_name not in task_groups:
-            task_groups[task_name] = []
-        task_groups[task_name].append((i, to_env[i]))
-
-    # Calculate rewards for each task group concurrently
-    futures = []
-    future_to_indices = {}  # Map future to its corresponding indices
-    for task_name, group in task_groups.items():
-        if task_name not in task_to_env:
-            raise ValueError(f"No environment found for task type: {task_name}")
-
-        # Extract indices and messages for this group
-        indices = [idx for idx, _ in group]
-        messages = [msg for _, msg in group]
-
-        # Get corresponding environment info
-        env_info = [batch["extra_env_info"][i] for i in indices]
-
-        # Submit task to environment and store future
-        future = task_to_env[task_name].step.remote(messages, env_info)
-        futures.append(future)
-        future_to_indices[future] = indices
-
-    results = ray.get(futures)
-    all_rewards = []
-    for future, result in zip(futures, results):
-        indices = future_to_indices[future]
-        _, _, task_rewards, _ = result
-
-        # Store results with their original indices
-        for idx, reward in zip(indices, task_rewards):
-            all_rewards.append((idx, reward))
-
-    # Sort results by original index to maintain order
-    all_rewards.sort(key=lambda x: x[0])
-    rewards = torch.tensor([reward for _, reward in all_rewards])
-
-    return rewards, to_env
-
-
 # ===============================================================================
 # Training & Validation
 # ===============================================================================
@@ -463,7 +354,7 @@ def grpo_train(
         print("▶ Preparing batch...")
         with timer.time("data_processing"):
             # Repeat batch items
-            repeated_batch = batch.repeat_interleave(
+            repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave(
                 master_config["grpo"]["num_generations_per_prompt"]
             )
             # Convert LLMMessageLogType to FlatMessagesType for generation
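
The repeat step above duplicates each prompt `num_generations_per_prompt` times so that group-wise statistics (baseline and std per prompt) can be computed later. A standalone sketch of the idea with plain lists (not this codebase's `BatchedDataDict`):

```python
# Standalone illustration of repeat-interleave; names here are not from the diff.
prompts = ["p0", "p1"]
num_generations_per_prompt = 3
repeated = [p for p in prompts for _ in range(num_generations_per_prompt)]
print(repeated)  # ['p0', 'p0', 'p0', 'p1', 'p1', 'p1']
```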
@@ -472,36 +363,33 @@
                 pad_value_dict={"token_ids": tokenizer.pad_token_id},
             )
             input_ids = batched_flat["token_ids"]
-            # Create generation-specific input structure
-            generation_input_data = BatchedDataDict[GenerationDatumSpec](
-                {
-                    "input_ids": input_ids,
-                    "input_lengths": input_lengths,
-                }
-            )

         # Generate responses - this updates the LLMMessageLogType in repeated_batch
-        print(f"▶ Generating responses for batch of size {len(input_ids)}...")
+        print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
         with timer.time("prepare_for_generation"):
             if NEED_REFIT and POLICY_GENERATION_STALE:
                 refit_policy_generation(policy, policy_generation)
                 POLICY_GENERATION_STALE = False
             else:
                 policy_generation.prepare_for_generation()
+
         with timer.time("generation"):
-            repeated_batch, _, gen_metrics = generate_responses(
-                policy_generation,
-                generation_input_data,
-                repeated_batch,
-                tokenizer,
-                input_lengths,
+            repeated_batch, rollout_metrics = run_multi_turn_rollout(
+                policy_generation=policy_generation,
+                input_batch=repeated_batch,
+                tokenizer=tokenizer,
+                task_to_env=task_to_env,
+                max_seq_len=master_config["policy"]["max_total_sequence_length"],
+                max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
+                greedy=False,
             )
             policy_generation.finish_generation()

-        # Calculate rewards & advantages based on the updated LLMMessageLogType
-        print("▶ Calculating rewards...")
+        # Calculate rewards & advantages
+        print("▶ Processing rewards...")
         with timer.time("reward_calculation"):
-            rewards, _ = calculate_rewards(repeated_batch, task_to_env)
+            # Extract rewards from final_batch
+            rewards = repeated_batch["total_reward"]

             print("▶ Computing advantages...")
             baseline, std = calculate_baseline_and_std_per_prompt(
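
`calculate_baseline_and_std_per_prompt` (call truncated by the hunk) groups the repeated samples by prompt; with `use_leave_one_out_baseline` enabled, each sample's baseline is, in the usual GRPO-style formulation, the mean reward of the other generations for the same prompt. A rough sketch of that assumed semantics, not the library's implementation:

```python
import torch

def leave_one_out_baseline(rewards: torch.Tensor) -> torch.Tensor:
    """Assumed semantics: each sample's baseline is the mean reward of the
    other generations produced for the same prompt."""
    # rewards: (num_prompts, num_generations_per_prompt)
    n = rewards.shape[1]
    return (rewards.sum(dim=1, keepdim=True) - rewards) / (n - 1)

rewards = torch.tensor([[1.0, 0.0, 1.0, 1.0]])  # one prompt, four generations
print(leave_one_out_baseline(rewards))          # tensor([[0.6667, 1.0000, 0.6667, 0.6667]])
```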
@@ -665,14 +553,14 @@ def grpo_train(
                 metrics[k] = np.sum(v).item()
             else:
                 metrics[k] = np.mean(v).item()
-        metrics.update(gen_metrics)
+        metrics.update(rollout_metrics)

         timing_metrics = timer.get_timing_metrics(reduction_op="sum")

         print(f"  • Loss: {metrics['loss']:.4f}")
         print(f"  • Avg Reward: {np.mean(rewards.numpy()):.4f}")
         print(
-            f"  • Mean Generation Length: {gen_metrics['mean_generation_length']:.4f}"
+            f"  • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}"
         )

         print("\n⏱️ Timing:")
@@ -726,39 +614,25 @@ def validate(
         if batch_idx >= max_batches:
             break

-        # Convert LLMMessageLogType to FlatMessagesType for generation
-        batched_flat, input_lengths = batched_message_log_to_flat_message(
-            val_batch["message_log"],
-            pad_value_dict={"token_ids": tokenizer.pad_token_id},
-        )
-        # Extract input IDs
-        input_ids = batched_flat["token_ids"]
-        # Create generation-specific input structure
-        generation_input_data = BatchedDataDict(
-            {
-                "input_ids": input_ids,
-                "input_lengths": input_lengths,
-            }
-        )
-
         # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs)
-        val_batch, generated_ids, gen_metrics = generate_responses(
+        val_batch, gen_metrics = run_multi_turn_rollout(
             policy_generation,
-            generation_input_data,
             val_batch,
             tokenizer,
-            input_lengths,
-            include_logprobs=False,
+            val_task_to_env,
+            max_seq_len=master_config["policy"]["max_total_sequence_length"],
+            max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
+            greedy=False,
         )
-
-        # Calculate rewards based on the updated LLMMessageLogType
-        with timer.time("reward_calculation"):
-            rewards, to_env = calculate_rewards(val_batch, val_task_to_env)
+        rewards = val_batch["total_reward"]

         total_rewards.extend(rewards.tolist())
-        total_lengths.extend([len(ids) for ids in generated_ids])
+        total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])

         # Collect message logs for later display
+        to_env = get_keys_from_message_log(
+            val_batch["message_log"], ["role", "content"]
+        )
         all_message_logs.extend(to_env)

     # Calculate validation metrics
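
The validation loop now keeps only `role`/`content` fields of the collected message logs for later display. A standalone sketch of that kind of key filtering (assumed semantics of `get_keys_from_message_log`, shown here with a differently named helper, not its actual implementation):

```python
# Assumed behavior: keep only the requested keys from each message dict.
def filter_message_keys(message_log, keys):
    return [{k: m[k] for k in keys if k in m} for m in message_log]

log = [
    {"role": "user", "content": "2+2?", "token_ids": [1, 2, 3]},
    {"role": "assistant", "content": "4", "token_ids": [4]},
]
print(filter_message_keys(log, ["role", "content"]))
# [{'role': 'user', 'content': '2+2?'}, {'role': 'assistant', 'content': '4'}]
```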