@@ -100,15 +100,13 @@ def __init__(
         self.use_multistep_prompt = workflow_args.get("use_multistep_prompt", False)
         self.desc = workflow_args.get("desc", None)
         self.is_slippery = workflow_args.get("is_slippery", False)
-        print(f"{self.rollout_args = }")
         self.max_response_tokens = self.rollout_args.get("max_response_tokens", 10240)

         # Extract task-specific arguments
         self.raw_task = task.raw_task if hasattr(task, "raw_task") else {}
         self.size = self.raw_task.get("size", 1)
         self.p = self.raw_task.get("p", 0.8)
         self.seed = self.raw_task.get("seed", 42)
-        print("self.size: ", self.size, "self.p: ", self.p, "self.seed: ", self.seed)

         if self.desc is None:
             random_map, goal_position = generate_random_map(
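An editorial aside on the map setup: the generate_random_map called here returns both the map and the goal position, unlike gymnasium's helper of the same name, which returns only the map description. A minimal sketch of what such a two-value variant might look like; the retry loop and the _reachable helper are assumptions, not the repo's code:

import random
from collections import deque

def generate_random_map(size: int, p: float, seed: int):
    # Hypothetical sketch: sample maps until the goal is reachable.
    rng = random.Random(seed)
    while True:
        grid = [["F" if rng.random() < p else "H" for _ in range(size)]
                for _ in range(size)]
        grid[0][0] = "S"
        grid[size - 1][size - 1] = "G"
        if _reachable(grid, size):
            # Two-value return matches the call site above.
            return ["".join(row) for row in grid], (size - 1, size - 1)

def _reachable(grid, size):
    # BFS from the start tile; holes ("H") block movement.
    seen, queue = {(0, 0)}, deque([(0, 0)])
    while queue:
        r, c = queue.popleft()
        if grid[r][c] == "G":
            return True
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if (0 <= nr < size and 0 <= nc < size
                    and (nr, nc) not in seen and grid[nr][nc] != "H"):
                seen.add((nr, nc))
                queue.append((nr, nc))
    return False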
@@ -241,11 +239,17 @@ def render(self, mode="tiny_rgb_array"):
         room_state = self.render(mode="state").tolist()

         if mode == "list":
-            lookup = lambda cell: GRID_LOOKUP.get(cell, "?").strip("\t ").strip()
+
+            def lookup(cell):
+                return GRID_LOOKUP.get(cell, "?").strip("\t ").strip()
+
             return [" ".join(lookup(cell) for cell in row) for row in room_state]

         if mode == "tiny_rgb_array":
-            lookup = lambda cell: GRID_LOOKUP.get(cell, "?")
+
+            def lookup(cell):
+                return GRID_LOOKUP.get(cell, "?")
+
             result = "\n".join("".join(lookup(cell) for cell in row) for row in room_state)
             return result

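The lambda-to-def rewrite above is stylistic (it satisfies lint rule E731) and behavior-preserving. For reference, the two render modes in isolation; the GRID_LOOKUP values here are invented for illustration, the real table is defined elsewhere in the file:

# Hypothetical lookup table; the real mapping lives elsewhere in the file.
GRID_LOOKUP = {0: " # \t", 1: " _ \t", 2: " G \t", 3: " P \t"}

def render_tiny_rgb_array(room_state):
    # One character cluster per cell, rows joined by newlines.
    return "\n".join("".join(GRID_LOOKUP.get(cell, "?") for cell in row)
                     for row in room_state)

def render_list(room_state):
    # Same lookup, but whitespace-stripped and space-separated per row.
    return [" ".join(GRID_LOOKUP.get(cell, "?").strip("\t ").strip()
                     for cell in row)
            for row in room_state]

For example, render_list([[0, 1], [1, 2]]) yields ["# _", "_ G"] under this made-up table.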
@@ -271,7 +275,6 @@ async def run_async(self) -> List[Experience]:

         # Run episode until done or max_steps reached
         for step in range(self.max_steps):
-            print("Current step: ", step)
             # Format observation for the model
             current_obs_str = str(self.current_observation)
             user_prompt_content = (
@@ -301,11 +304,9 @@ async def run_async(self) -> List[Experience]:
             else:
                 response_token_len = messages_token_len - init_prompt_token_len
             max_tokens = self.max_response_tokens - response_token_len
-            print(
-                f"!!!Debug: {max_tokens = } used_response_tokens = {self.max_response_tokens - max_tokens} {messages_token_len = } {init_prompt_token_len = }"
-            )

             if max_tokens <= 0:
+                # messages = messages[:-1]  # TODO: apply this?
                 self.done = False
                 self.final_reward = 0
                 break
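The surviving logic budgets response tokens across the whole episode: every assistant turn spends from a single max_response_tokens allowance, measured as everything in the transcript beyond the initial prompt. A standalone sketch of the arithmetic; the helper name is hypothetical:

def remaining_budget(max_response_tokens, messages_token_len, init_prompt_token_len):
    # Tokens already spent on responses = full transcript minus the initial prompt.
    response_token_len = messages_token_len - init_prompt_token_len
    return max_response_tokens - response_token_len

When the budget runs out, the loop breaks with done = False and final_reward = 0, so an over-budget episode is scored as a failure rather than sampled with a zero-token cap.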
@@ -314,13 +315,9 @@ async def run_async(self) -> List[Experience]:
             rollout_args = self.rollout_args.copy()
             rollout_args["n"] = 1
             rollout_args["max_tokens"] = max_tokens
-            # print("Current step: ", step, rollout_args)
             responses = await self.model.chat_async(messages, **rollout_args)
             response_text = responses[0].response_text
             messages.append({"role": "assistant", "content": response_text})
-            print(
-                "raw response: ", response_text
-            )  # sometimes has <think></think> and <action>, sometimes not

             # Parse action from response
             _, action_str = self._parse_model_response(response_text)
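A design note on the sampling call: n is pinned to 1 regardless of what the configured rollout_args request, because the episode is sequential; each action must be executed before the next observation exists, so n > 1 would fork the dialogue mid-episode. The per-step override in isolation (sample_turn is a hypothetical wrapper; chat_async and response_text follow the call sites above):

async def sample_turn(model, messages, rollout_args, max_tokens):
    # Pin one completion per turn, capped by the remaining token budget.
    step_args = {**rollout_args, "n": 1, "max_tokens": max_tokens}
    responses = await model.chat_async(messages, **step_args)
    return responses[0].response_text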
@@ -349,15 +346,6 @@ async def run_async(self) -> List[Experience]:
349346 "success" : 1 if self .final_reward == 1.0 else 0 ,
350347 },
351348 )
352- print ("\n \n \n " )
353- print ("full messages: " , messages )
354- # print("experience.tokens: ", len(experience.tokens))
355- # print("experience.logprobs: ", len(experience.logprobs))
356- # print("experience.action_mask: ", len(experience.action_mask))
357- # print("experience.prompt_length: ", experience.prompt_length)
358- # print("experience.reward: ", experience.reward)
359- # print("experience.prompt_text: ", experience.prompt_text)
360- # print("experience.response_text: ", experience.response_text, "\n\n\n")
361349 return [experience ]
362350
363351 def _parse_model_response (self , response : str ) -> tuple [str , str ]:
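The deleted "raw response" comment doubles as a spec for _parse_model_response: completions sometimes carry <think> and <action> tags and sometimes do not, so the parser must tolerate both. A minimal tolerant sketch, an assumption since the actual method body is not shown in this diff:

import re

def parse_model_response(response: str) -> tuple[str, str]:
    # Tolerate tagged and untagged completions alike.
    think = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
    action = re.search(r"<action>(.*?)</action>", response, re.DOTALL)
    thought = think.group(1).strip() if think else ""
    # Fall back to the whole response when the action tag is missing.
    action_str = action.group(1).strip() if action else response.strip()
    return thought, action_str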