@@ -225,35 +225,39 @@ async def run_async(self) -> bool:
225225 return False
226226
227227 rollout_obj = Rollout (rollout_id = task .rollout_id ) # Default empty rollout
228-
229228 try :
230- try :
231- self .agent .on_rollout_start (task , self , self .tracer )
232- except Exception :
233- logger .exception (f"{ self ._log_prefix (rollout_id )} Exception during on_rollout_start hook." )
234-
235- with self .tracer .trace_context (name = f"rollout_{ rollout_id } " ):
236- start_time = time .time ()
237- rollout_method = (
238- self .agent .training_rollout_async if task .mode == "train" else self .agent .validation_rollout_async
239- )
240- # Pass the task input, not the whole task object
241- result = await rollout_method (task .input , task .rollout_id , resources_update .resources )
242- rollout_obj = self ._to_rollout_object (result , task .rollout_id )
243- end_time = time .time ()
244- logger .info (
245- f"{ self ._log_prefix (rollout_id )} Completed in "
246- f"{ end_time - start_time :.2f} s. Reward: { rollout_obj .final_reward } "
247- )
229+ self .agent .on_rollout_start (task , self , self .tracer )
248230 except Exception :
249- logger .exception (f"{ self ._log_prefix (rollout_id )} Exception during rollout." )
250- finally :
231+ logger .exception (f"{ self ._log_prefix (rollout_id )} Exception during on_rollout_start hook." )
232+ MAX_TRY = 3
233+ while MAX_TRY > 0 :
251234 try :
252- self .agent .on_rollout_end (task , rollout_obj , self , self .tracer )
235+ with self .tracer .trace_context (name = f"rollout_{ rollout_id } " ):
236+ start_time = time .time ()
237+ rollout_method = (
238+ self .agent .training_rollout_async if task .mode == "train" else self .agent .validation_rollout_async
239+ )
240+ # Pass the task input, not the whole task object
241+ result = await rollout_method (task .input , task .rollout_id , resources_update .resources )
242+ rollout_obj = self ._to_rollout_object (result , task .rollout_id )
243+ end_time = time .time ()
244+ logger .info (
245+ f"{ self ._log_prefix (rollout_id )} Completed in "
246+ f"{ end_time - start_time :.2f} s. Reward: { rollout_obj .final_reward } "
247+ )
248+ break
253249 except Exception :
254- logger .exception (f"{ self ._log_prefix (rollout_id )} Exception during on_rollout_end hook." )
255- await self .client .post_rollout_async (rollout_obj )
256-
250+ logger .exception (f"{ self ._log_prefix (rollout_id )} Exception during rollout." )
251+ MAX_TRY = MAX_TRY - 1
252+ finally :
253+ if rollout_obj .triplets :
254+ try :
255+ self .agent .on_rollout_end (task , rollout_obj , self , self .tracer )
256+ except Exception :
257+ logger .exception (f"{ self ._log_prefix (rollout_id )} Exception during on_rollout_end hook." )
258+ await self .client .post_rollout_async (rollout_obj )
259+ else :
260+ raise Exception ("rollout_obj.triplets is EMPTY" )
257261 return True
258262
259263 async def iter_async (self ) -> int :
0 commit comments