Skip to content

Commit 70d701e

Browse files
update hinter agent
1 parent eafd5fc commit 70d701e

File tree

5 files changed

+57
-35
lines changed

5 files changed

+57
-35
lines changed

experiments/generic/run_generic_agent.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ def main():
3030
agent_args = [get_base_agent(args.llm_config)]
3131
benchmark = DEFAULT_BENCHMARKS[args.benchmark]()
3232

33+
##################### Shuffle env args list, pick subset
34+
import numpy as np
35+
rng = np.random.default_rng(42)
36+
rng.shuffle(benchmark.env_args_list)
37+
benchmark.env_args_list = benchmark.env_args_list[:33]
38+
#####################
39+
40+
# for env_args in benchmark.env_args_list:
41+
# env_args.max_steps = 100
42+
3343
if args.relaunch:
3444
# relaunch an existing study
3545
study = Study.load_most_recent(contains=None)

experiments/hinter/run_hinter_agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def main():
4949

5050
benchmark = DEFAULT_BENCHMARKS[args.benchmark]()
5151

52+
# # shuffle env_args_list and pick first 33
53+
# import numpy as np
54+
# rng = np.random.default_rng(42)
55+
# rng.shuffle(benchmark.env_args_list)
56+
# benchmark.env_args_list = benchmark.env_args_list[:33]
57+
5258

5359
if args.relaunch:
5460
# relaunch an existing study

experiments/hinter/run_hinter_agent.sh

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ PARALLEL_BACKEND="ray"
99
HINT_TYPE="docs" # human, llm, docs
1010
HINT_INDEX_TYPE="sparse" # sparse, dense
1111
HINT_QUERY_TYPE="goal" # goal, llm
12-
HINT_NUM_RESULTS=5
12+
HINT_NUM_RESULTS=3
1313

1414
HINT_INDEX_PATH="indexes/servicenow-docs-bm25"
1515
# HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m"
1616
HINT_RETRIEVER_PATH="google/embeddinggemma-300m"
1717

1818
N_JOBS=6
1919

20-
python experiments/hint/run_hinter_agent.py \
20+
python experiments/hinter/run_hinter_agent.py \
2121
--benchmark $BENCHMARK \
2222
--llm-config $LLM_CONFIG \
2323
--parallel-backend $PARALLEL_BACKEND \
@@ -27,5 +27,4 @@ python experiments/hint/run_hinter_agent.py \
2727
--hint-query-type $HINT_QUERY_TYPE \
2828
--hint-index-path $HINT_INDEX_PATH \
2929
--hint-retriever-path $HINT_RETRIEVER_PATH \
30-
--hint-num-results $HINT_NUM_RESULTS \
31-
--relaunch
30+
--hint-num-results $HINT_NUM_RESULTS

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ def get_action(self, obs):
120120
queries if getattr(self.flags, "hint_level", "episode") == "step" else None
121121
)
122122

123+
# get hints
124+
if self.flags.use_hints:
125+
task_hints = self._get_task_hints()
126+
else:
127+
task_hints = []
128+
123129
main_prompt = MainPrompt(
124130
action_set=self.action_set,
125131
obs_history=self.obs_history,
@@ -130,7 +136,7 @@ def get_action(self, obs):
130136
step=self.plan_step,
131137
flags=self.flags,
132138
llm=self.chat_llm,
133-
queries=queries_for_hints,
139+
task_hints=task_hints,
134140
)
135141

136142
# Set task name for task hints if available
@@ -303,60 +309,57 @@ def _init_hints_index(self):
303309
print(f"Warning: Could not load hint database: {e}")
304310
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
305311

306-
def get_hints_for_task(self, task_name: str) -> str:
312+
def _get_task_hints(self) -> list[str]:
307313
"""Get hints for a specific task."""
308-
if not self.use_task_hint:
309-
return ""
314+
if not self.flags.use_task_hint:
315+
return []
310316

311-
if self.hint_type == "docs":
317+
if self.flags.hint_type == "docs":
312318
if not hasattr(self, "hint_index"):
313319
print("Initializing hint index new time")
314320
self._init()
315-
if self.hint_query_type == "goal":
316-
query = self.goal
317-
elif self.hint_query_type == "llm":
321+
if self.flags.hint_query_type == "goal":
322+
query = self.obs_history[-1]["goal_object"][0]["text"]
323+
elif self.flags.hint_query_type == "llm":
318324
query = self.llm.generate(self._prompt + self._abstract_ex + self._concrete_ex)
319325
else:
320-
raise ValueError(f"Unknown hint query type: {self.hint_query_type}")
326+
raise ValueError(f"Unknown hint query type: {self.flags.hint_query_type}")
321327

322-
if self.hint_index_type == "sparse":
328+
print(f"Query: {query}")
329+
if self.flags.hint_index_type == "sparse":
323330
import bm25s
324331
query_tokens = bm25s.tokenize(query)
325-
docs, _ = self.hint_index.retrieve(query_tokens, k=self.hint_num_results)
332+
docs, _ = self.hint_index.retrieve(query_tokens, k=self.flags.hint_num_results)
326333
docs = [elem["text"] for elem in docs[0]]
327334
# HACK: truncate to 20k characters (should cover >99% of the cases)
328335
for doc in docs:
329336
if len(doc) > 20000:
330337
doc = doc[:20000]
331338
doc += " ...[truncated]"
332-
elif self.hint_index_type == "dense":
339+
elif self.flags.hint_index_type == "dense":
333340
query_embedding = self.hint_retriever.encode(query)
334-
_, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.hint_num_results)
341+
_, docs = self.hint_index.get_nearest_examples("embeddings", query_embedding, k=self.flags.hint_num_results)
335342
docs = docs["text"]
336343

337-
hints_str = (
338-
"# Hints:\nHere are some hints for the task you are working on:\n"
339-
+ "\n".join(docs)
340-
)
341-
return hints_str
344+
return docs
342345

343346
# Check if hint_db has the expected structure
344347
if (
345348
self.hint_db.empty
346349
or "task_name" not in self.hint_db.columns
347350
or "hint" not in self.hint_db.columns
348351
):
349-
return ""
352+
return []
350353

351354
try:
352355
# When step-level, pass queries as goal string to fit the llm_prompt
353-
goal_or_queries = self.goal
354-
if self.hint_level == "step" and self.queries:
356+
goal_or_queries = self.obs_history[-1]["goal_object"][0]["text"]
357+
if self.flags.hint_level == "step" and self.queries:
355358
goal_or_queries = "\n".join(self.queries)
356359

357360
task_hints = self.hints_source.choose_hints(
358361
self.llm,
359-
task_name,
362+
self.task_name,
360363
goal_or_queries,
361364
)
362365

@@ -366,13 +369,8 @@ def get_hints_for_task(self, task_name: str) -> str:
366369
if hint:
367370
hints.append(f"- {hint}")
368371

369-
if len(hints) > 0:
370-
hints_str = (
371-
"# Hints:\nHere are some hints for the task you are working on:\n"
372-
+ "\n".join(hints)
373-
)
374-
return hints_str
372+
return hints
375373
except Exception as e:
376-
print(f"Warning: Error getting hints for task {task_name}: {e}")
374+
print(f"Warning: Error getting hints for task {self.task_name}: {e}")
377375

378-
return ""
376+
return []

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ class GenericPromptFlags(dp.Flags):
5959
max_trunc_itr: int = 20
6060
flag_group: str = None
6161

62+
# hint related
63+
use_task_hint: bool = False
64+
hint_type: str = "docs"
65+
hint_index_type: str = "sparse"
66+
hint_query_type: str = "direct"
67+
hint_index_path: str = "indexes/servicenow-docs-bm25"
68+
hint_retriever_path: str = "google/embeddinggemma-300m"
69+
hint_num_results: int = 5
70+
n_retrieval_queries: int = 1
71+
6272

6373
class MainPrompt(dp.Shrinkable):
6474
def __init__(
@@ -68,7 +78,6 @@ def __init__(
6878
actions: list[str],
6979
memories: list[str],
7080
thoughts: list[str],
71-
hints: list[str],
7281
previous_plan: str,
7382
step: int,
7483
flags: GenericPromptFlags,

0 commit comments

Comments
 (0)