@@ -325,12 +325,17 @@ def _init(self):
325325 embedder_server = self .embedder_server ,
326326 )
327327
328- def apply (self , llm , discussion : StructuredDiscussion , task_name : str ) -> dict :
328+ def apply (self , llm , discussion : StructuredDiscussion , obs : dict , task_name : str ) -> dict :
329329 if not self .use_task_hint :
330330 return {}
331331
332- goal = "\n " .join ([c .get ("text" , "" ) for c in discussion .groups [0 ].messages [1 ].content ])
333- task_hints = self .hints_source .choose_hints (llm , task_name , goal )
332+ # goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
333+ try :
334+ goal_text = obs ["goal_object" ][0 ]["text" ]
335+ except (KeyError , IndexError ):
336+ Warning ("Goal text not found in observation" )
337+ goal_text = ""
338+ task_hints = self .hints_source .choose_hints (llm , task_name , goal_text )
334339
335340 hints = []
336341 for hint in task_hints :
@@ -472,7 +477,7 @@ def get_action(self, obs: Any) -> float:
472477
473478 self .config .summarizer .apply_init (self .llm , self .discussion )
474479 self .config .general_hints .apply (self .llm , self .discussion )
475- self .task_hint .apply (self .llm , self .discussion , self .task_name )
480+ self .task_hint .apply (self .llm , self .discussion , obs = obs , task_name = self .task_name )
476481
477482 self .discussion .new_group ()
478483
0 commit comments