44It is based on the dynamic_prompting module from the agentlab package.
55"""
66
7+ import json
78import logging
89from dataclasses import dataclass
910from pathlib import Path
@@ -67,6 +68,8 @@ class GenericPromptFlags(dp.Flags):
6768 hint_index_path : str = None
6869 hint_retriever_path : str = None
6970 hint_num_results : int = 5
71+ n_retrieval_queries : int = 3
72+ hint_level : Literal ["episode" , "step" ] = "episode"
7073
7174
7275class MainPrompt (dp .Shrinkable ):
@@ -81,6 +84,7 @@ def __init__(
8184 step : int ,
8285 flags : GenericPromptFlags ,
8386 llm : ChatModel ,
87+ queries : list [str ] | None = None ,
8488 ) -> None :
8589 super ().__init__ ()
8690 self .flags = flags
@@ -130,6 +134,8 @@ def time_for_caution():
130134 hint_index_path = flags .hint_index_path ,
131135 hint_retriever_path = flags .hint_retriever_path ,
132136 hint_num_results = flags .hint_num_results ,
137+ hint_level = flags .hint_level ,
138+ queries = queries ,
133139 )
134140 self .plan = Plan (previous_plan , step , lambda : flags .use_plan ) # TODO add previous plan
135141 self .criticise = Criticise (visible = lambda : flags .use_criticise )
@@ -324,6 +330,8 @@ def __init__(
324330 hint_num_results : int = 5 ,
325331 skip_hints_for_current_task : bool = False ,
326332 hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = "direct" ,
333+ hint_level : Literal ["episode" , "step" ] = "episode" ,
334+ queries : list [str ] | None = None ,
327335 ) -> None :
328336 super ().__init__ (visible = use_task_hint )
329337 self .use_task_hint = use_task_hint
@@ -339,6 +347,8 @@ def __init__(
339347 self .skip_hints_for_current_task = skip_hints_for_current_task
340348 self .goal = goal
341349 self .llm = llm
350+ self .hint_level : Literal ["episode" , "step" ] = hint_level
351+ self .queries : list [str ] | None = queries
342352 self ._init ()
343353
344354 _prompt = "" # Task hints are added dynamically in MainPrompt
@@ -394,6 +404,7 @@ def _init(self):
394404 else :
395405 print (f"Warning: Hint database not found at { hint_db_path } " )
396406 self .hint_db = pd .DataFrame (columns = ["task_name" , "hint" ])
407+
397408 self .hints_source = HintsSource (
398409 hint_db_path = hint_db_path .as_posix (),
399410 hint_retrieval_mode = self .hint_retrieval_mode ,
@@ -448,7 +459,16 @@ def get_hints_for_task(self, task_name: str) -> str:
448459 return ""
449460
450461 try :
451- task_hints = self .hints_source .choose_hints (self .llm , task_name , self .goal )
462+ # When step-level, pass queries as goal string to fit the llm_prompt
463+ goal_or_queries = self .goal
464+ if self .hint_level == "step" and self .queries :
465+ goal_or_queries = "\n " .join (self .queries )
466+
467+ task_hints = self .hints_source .choose_hints (
468+ self .llm ,
469+ task_name ,
470+ goal_or_queries ,
471+ )
452472
453473 hints = []
454474 for hint in task_hints :
@@ -466,3 +486,78 @@ def get_hints_for_task(self, task_name: str) -> str:
466486 print (f"Warning: Error getting hints for task { task_name } : { e } " )
467487
468488 return ""
489+
490+
class StepWiseContextIdentificationPrompt(dp.Shrinkable):
    """Prompt asking the LLM to describe the current step as retrieval queries.

    The prompt is assembled from the goal instructions, the latest
    observation and the interaction history, followed by a section that asks
    the model to answer with a ``<think>`` block and a ``<queries>`` block
    containing a JSON list of exactly ``n_queries`` strings.
    """

    def __init__(
        self,
        obs_history: list[dict],
        actions: list[str],
        thoughts: list[str],
        obs_flags: dp.ObsFlags,
        n_queries: int = 1,
    ) -> None:
        """Build the sub-prompts from the episode state.

        Args:
            obs_history: Observations so far; the last entry is used as the
                current observation and must contain a ``"goal_object"`` key.
            actions: Actions taken so far, forwarded to ``dp.History``.
            thoughts: Thoughts recorded so far, forwarded to ``dp.History``.
            obs_flags: Observation flags shared by the history and the
                observation sub-prompts.
            n_queries: Number of queries requested from the model; also caps
                how many example queries are shown in the prompt.
        """
        super().__init__()
        self.obs_flags = obs_flags
        self.n_queries = n_queries
        self.history = dp.History(obs_history, actions, None, thoughts, obs_flags)
        self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"])
        self.obs = dp.Observation(obs_history[-1], obs_flags)

        self.think = dp.Think(visible=True)  # To replace with static text maybe

    @property
    def _prompt(self) -> HumanMessage:
        """Return the full message: goal, observation, history, query request."""
        prompt = HumanMessage(self.instructions.prompt)

        prompt.add_text(
            f"""\
{self.obs.prompt}\
{self.history.prompt}\
"""
        )

        example_queries = [
            "The user has started sorting a table and needs to apply multiple column criteria simultaneously.",
            "The user is attempting to configure advanced sorting options but the interface is unclear.",
            "The user has selected the first sort column and is now looking for how to add a second sort criterion.",
            "The user is in the middle of a multi-step sorting process and needs guidance on the next action.",
        ]

        # Show at most n_queries examples so the example matches the request
        # (only four examples exist, so larger n_queries shows all four).
        example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2)

        prompt.add_text(
            f"""
# Querying memory

Before choosing an action, let's search our available documentation and memory for relevant context.
Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follow
<think>chain of thought</think>
<queries>json list of strings</queries> for the queries. Return exactly {self.n_queries}
queries in the list.

# Concrete Example

<think>
I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if
I will be able to sort by both at the same time.
</think>

<queries>
{example_queries_str}
</queries>
"""
        )

        return self.obs.add_screenshot(prompt)

    def shrink(self):
        """Shrink the history and the observation sub-prompts."""
        self.history.shrink()
        self.obs.shrink()

    def _parse_answer(self, text_answer):
        """Extract ``think`` and ``queries`` from the model answer.

        Args:
            text_answer: Raw text returned by the model.

        Returns:
            The dict produced by ``parse_html_tags_raise`` with
            ``"queries"`` replaced by a list of strings.

        Raises:
            json.JSONDecodeError: If the ``<queries>`` content is not valid JSON.
            ValueError: If the decoded JSON is neither a list, a string nor null.
        """
        # NOTE(review): presumably parse_html_tags_raise raises when a
        # requested tag is missing - confirm against its definition.
        ans_dict = parse_html_tags_raise(
            text_answer, keys=["think", "queries"], merge_multiple=True
        )
        decoded = json.loads(ans_dict.get("queries", "[]"))
        ans_dict["queries"] = self._normalize_queries(decoded)
        return ans_dict

    @staticmethod
    def _normalize_queries(decoded: object) -> list[str]:
        """Coerce a decoded ``<queries>`` payload into a list of strings.

        Consumers join the queries with ``"\\n".join(...)``, which silently
        splits a bare string into characters and fails on non-string items,
        so the shape is fixed here, where the model output is parsed.
        """
        if decoded is None:
            return []
        if isinstance(decoded, str):
            # The model returned a single query instead of a one-item list.
            return [decoded]
        if isinstance(decoded, list):
            return [
                item if isinstance(item, str) else json.dumps(item)
                for item in decoded
            ]
        raise ValueError(
            "Expected <queries> to contain a JSON list of strings, "
            f"got {type(decoded).__name__}"
        )
0 commit comments