11import fnmatch
22import json
3+ import logging
34from abc import ABC , abstractmethod
5+ from collections import defaultdict
46from copy import copy
57from dataclasses import asdict , dataclass , field
8+ from email .policy import default
69from pathlib import Path
7- from typing import Any
10+ from typing import Any , Literal
811
912import bgym
1013import pandas as pd
3437)
3538from agentlab .llm .tracking import cost_tracker_decorator
3639
40+ logger = logging .getLogger (__name__ )
41+
3742
3843@dataclass
3944class Block (ABC ):
@@ -296,6 +301,10 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
296301class TaskHint (Block ):
297302 use_task_hint : bool = True
298303 hint_db_rel_path : str = "hint_db.csv"
304+ hint_retrieval_mode : Literal ["direct" , "llm" ] = "direct" # "direct": match hints by task name; "llm": let the LLM pick a hint topic
305+ llm_prompt : str = """We're choosing hints to help solve the following task:\n {goal}.\n
306+ You need to choose the most relevant hints topic from the following list:\n \n Hint topics:\n {topics}\n
307+ Choose hint for the task and return only its numbers, e.g. 1. If you don't know the answer, return -1."""
299308
300309 def _init (self ):
301310 """Initialize the block."""
@@ -306,9 +315,8 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
306315 if not self .use_task_hint :
307316 return
308317
309- task_hints = self .hint_db [
310- self .hint_db ["task_name" ].apply (lambda x : fnmatch .fnmatch (x , task_name ))
311- ]
318+ goal = "\n " .join ([c .get ("text" , "" ) for c in discussion .groups [0 ].messages [1 ].content ])
319+ task_hints = self .choose_hints (llm , task_name , goal )
312320
313321 hints = []
314322 for hint in task_hints ["hint" ]:
@@ -325,6 +333,44 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
325333
326334 discussion .append (msg )
327335
def choose_hints(self, llm, task_name: str, goal: str) -> pd.DataFrame:
    """Select hint rows using the strategy named by ``hint_retrieval_mode``.

    Args:
        llm: Model handle; only forwarded to ``choose_hints_llm`` in "llm" mode.
        task_name: Task name; only forwarded to ``choose_hints_direct`` in
            "direct" mode.
        goal: Goal text; only forwarded to ``choose_hints_llm`` in "llm" mode.

    Returns:
        Whatever the selected strategy method returns.

    Raises:
        ValueError: If ``hint_retrieval_mode`` is neither "llm" nor "direct".
    """
    mode = self.hint_retrieval_mode
    if mode == "direct":
        return self.choose_hints_direct(task_name)
    if mode == "llm":
        return self.choose_hints_llm(llm, goal)
    # Any other value is a configuration error.
    raise ValueError(f"Unknown hint retrieval mode: {mode} ")
344+
def choose_hints_llm(self, llm, goal: str) -> pd.DataFrame:
    """Ask the LLM to pick one hint topic for ``goal`` and return its hints.

    The distinct ``semantic_keys`` values of ``hint_db`` are listed to the
    model, numbered from 0 in first-seen order, using ``llm_prompt``. The
    model's reply (``response.think``) is parsed as JSON and must be an
    integer index into that list.

    Args:
        llm: Callable model handle; invoked with an ``APIPayload`` holding a
            single user message built via ``llm.msg.user().add_text(...)``.
        goal: Goal text substituted into ``llm_prompt``.

    Returns:
        A copy of the ``hint_db`` rows belonging to the chosen topic, with
        duplicate ``hint`` values removed (first occurrence kept). An empty
        frame with ``hint_db``'s columns is returned when the reply cannot be
        parsed, is not an integer, or is out of range (including -1).
    """
    # Group row *labels* by topic; dict insertion order fixes the numbering
    # shown to the model.
    topic_to_hints = defaultdict(list)
    for i, row in self.hint_db.iterrows():
        topic_to_hints[row["semantic_keys"]].append(i)
    hint_topics = list(topic_to_hints.keys())
    topics = "\n ".join([f"{i} . {h} " for i, h in enumerate(hint_topics)])
    prompt = self.llm_prompt.format(goal=goal, topics=topics)
    response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)]))

    no_hints = pd.DataFrame(columns=self.hint_db.columns)

    # Keep the try body to the one call that can fail on a malformed reply.
    # TypeError covers a reply that is not str/bytes (e.g. None).
    try:
        hint_topic_idx = json.loads(response.think)
    except (json.JSONDecodeError, TypeError):
        logger.error(f"Failed to parse LLM hint id response: {response.think} , return no hints")
        return no_hints

    # Valid JSON is not necessarily an index: reject floats, lists, strings
    # and booleans (bool is an int subclass) as well as out-of-range values.
    is_index = isinstance(hint_topic_idx, int) and not isinstance(hint_topic_idx, bool)
    if not is_index or not 0 <= hint_topic_idx < len(hint_topics):
        logger.error(f"Wrong LLM hint id response: {response.think} , return no hints")
        return no_hints

    hint_topic = hint_topics[hint_topic_idx]
    hint_indices = topic_to_hints[hint_topic]
    # iterrows() yields index labels, so select with .loc (label-based);
    # .iloc is positional and only coincides with labels on a default index.
    df = self.hint_db.loc[hint_indices].copy()
    df = df.drop_duplicates(subset=["hint"], keep="first")  # leave only unique hints
    logger.debug(f"LLM hint topic {hint_topic_idx} , chosen hints: {df['hint'].tolist()} ")
    return df
368+
def choose_hints_direct(self, task_name: str) -> pd.DataFrame:
    """Return the ``hint_db`` rows whose ``task_name`` matches ``task_name``.

    Matching uses ``fnmatch.fnmatch`` with the stored ``task_name`` value as
    the name and the ``task_name`` argument as the pattern.

    Args:
        task_name: Shell-style pattern the stored task names are tested against.

    Returns:
        The matching subset of ``hint_db``.
    """

    def matches(stored_name):
        # Stored value is the candidate name; the argument is the pattern.
        return fnmatch.fnmatch(stored_name, task_name)

    row_mask = self.hint_db["task_name"].apply(matches)
    return self.hint_db[row_mask]
373+
328374
329375@dataclass
330376class PromptConfig :
@@ -583,7 +629,7 @@ def get_action(self, obs: Any) -> float:
583629 ),
584630 summarizer = Summarizer (do_summary = True ),
585631 general_hints = GeneralHints (use_hints = False ),
586- task_hint = TaskHint (use_task_hint = True ),
632+ task_hint = TaskHint (use_task_hint = True , hint_retrieval_mode = "llm" ),
587633 keep_last_n_obs = None ,
588634 multiaction = True , # whether to use multi-action or not
589635 # action_subsets=("bid",),
0 commit comments