@@ -934,35 +934,25 @@ async def run_single_sample_with_error_handling(i, sample_state):
934934
935935 # Reconstruct batch from sample results
936936 batch_size = len (final_sample_states )
937- final_batch_dict = {
938- "message_log" : [state ["message_log" ] for state in final_sample_states ],
939- "extra_env_info" : [
940- state ["extra_env_info" ] for state in final_sample_states
941- ],
942- "task_name" : [state ["task_name" ] for state in final_sample_states ],
943- "total_reward" : torch .stack (
944- [state ["total_reward" ] for state in final_sample_states ]
945- ),
946- "idx" : [
947- state .get ("idx" , i ) for i , state in enumerate (final_sample_states )
948- ],
949- "truncated" : torch .tensor (
950- [metrics ["truncated" ] for metrics in all_sample_metrics ],
951- dtype = torch .bool ,
952- ),
953- }
954-
955- # Add any reward component keys (reward1, reward2, ...) from the first state
956- reward_keys = [
957- k for k in final_sample_states [0 ]
958- if k .startswith ("reward" ) and k [6 :].isdigit ()
959- ]
960- reward_keys = sorted (reward_keys , key = lambda k : int (k [6 :]))
961- for key in reward_keys :
962- final_batch_dict [key ] = torch .stack (
963- [state [key ] for state in final_sample_states ]
964- )
965- final_batch = BatchedDataDict [DatumSpec ](final_batch_dict )
937+ final_batch = BatchedDataDict [DatumSpec ](
938+ {
939+ "message_log" : [state ["message_log" ] for state in final_sample_states ],
940+ "extra_env_info" : [
941+ state ["extra_env_info" ] for state in final_sample_states
942+ ],
943+ "task_name" : [state ["task_name" ] for state in final_sample_states ],
944+ "total_reward" : torch .stack (
945+ [state ["total_reward" ] for state in final_sample_states ]
946+ ),
947+ "idx" : [
948+ state .get ("idx" , i ) for i , state in enumerate (final_sample_states )
949+ ],
950+ "truncated" : torch .tensor (
951+ [metrics ["truncated" ] for metrics in all_sample_metrics ],
952+ dtype = torch .bool ,
953+ ),
954+ }
955+ )
966956
967957 # Preserve additional fields from the original input_batch
968958 for key in input_batch .keys ():
@@ -1237,42 +1227,28 @@ def run_async_nemo_gym_rollout(
12371227 )
12381228 input_ids = batched_flat ["token_ids" ]
12391229
1240- final_batch_dict = {
1241- "agent_ref" : [r ["agent_ref" ] for r in results ],
1242- "message_log" : [r ["message_log" ] for r in results ],
1243- # length is used downstream for mean_prompt_length
1244- "length" : torch .tensor (
1245- [len (r ["input_message_log" ][0 ]["token_ids" ]) for r in results ]
1246- ),
1247- "loss_multiplier" : input_batch ["loss_multiplier" ],
1248- # Unnecessary parts of the DatumSpec unused by the GRPO algorithm
1249- # extra_env_info: dict[str, Any]
1250- # idx: int
1251- # task_name: NotRequired[str]
1252- # stop_strings: NotRequired[list[str]] # Optional stop strings for generation
1253- # Extra information not in the DatumSpec used by the GRPO algorithm
1254- "total_reward" : torch .tensor ([r ["full_result" ]["reward" ] for r in results ]),
1255- # Add truncated field to match other rollout paths (reusing hit_max_tokens logic)
1256- "truncated" : torch .tensor (
1257- [m ["hit_max_tokens" ] for m in all_sample_metrics ], dtype = torch .bool
1258- ),
1259- }
1260-
1261- # Add any reward component keys (reward1, reward2, ...) from full_result
1262- if results :
1263- full_result = results [0 ].get ("full_result" , {})
1264- reward_keys = sorted (
1265- [
1266- k for k in full_result
1267- if isinstance (k , str ) and k .startswith ("reward" ) and k [6 :].isdigit ()
1268- ],
1269- key = lambda k : int (k [6 :]),
1270- )
1271- for key in reward_keys :
1272- final_batch_dict [key ] = torch .tensor (
1273- [r ["full_result" ][key ] for r in results ]
1274- )
1275- final_batch = BatchedDataDict [DatumSpec ](final_batch_dict )
1230+ final_batch = BatchedDataDict [DatumSpec ](
1231+ {
1232+ "agent_ref" : [r ["agent_ref" ] for r in results ],
1233+ "message_log" : [r ["message_log" ] for r in results ],
1234+ # length is used downstream for mean_prompt_length
1235+ "length" : torch .tensor (
1236+ [len (r ["input_message_log" ][0 ]["token_ids" ]) for r in results ]
1237+ ),
1238+ "loss_multiplier" : input_batch ["loss_multiplier" ],
1239+ # Unnecessary parts of the DatumSpec unused by the GRPO algorithm
1240+ # extra_env_info: dict[str, Any]
1241+ # idx: int
1242+ # task_name: NotRequired[str]
1243+ # stop_strings: NotRequired[list[str]] # Optional stop strings for generation
1244+ # Extra information not in the DatumSpec used by the GRPO algorithm
1245+ "total_reward" : torch .tensor ([r ["full_result" ]["reward" ] for r in results ]),
1246+ # Add truncated field to match other rollout paths (reusing hit_max_tokens logic)
1247+ "truncated" : torch .tensor (
1248+ [m ["hit_max_tokens" ] for m in all_sample_metrics ], dtype = torch .bool
1249+ ),
1250+ }
1251+ )
12761252
12771253 return AsyncNemoGymRolloutResult (
12781254 input_ids = input_ids ,
0 commit comments