Skip to content

Commit 02dd9b8

Browse files
authored
Merge pull request #295 from ServiceNow/scratch/refactor-hint-retrieval
Refactor Hint Retrieval
2 parents 15c5639 + ce866f6 commit 02dd9b8

File tree

9 files changed

+552
-386
lines changed

9 files changed

+552
-386
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import argparse
2+
3+
from dotenv import load_dotenv
4+
5+
load_dotenv()
6+
7+
import argparse
8+
import logging
9+
10+
from agentlab.agents.generic_agent.tmlr_config import get_base_agent
11+
from agentlab.experiments.study import Study
12+
from bgym import DEFAULT_BENCHMARKS
13+
14+
logging.getLogger().setLevel(logging.WARNING)
15+
16+
17+
def main():
    """Launch (or relaunch) a study of the base generic agent on a benchmark.

    Command-line options:
        --benchmark: key into ``DEFAULT_BENCHMARKS`` (required).
        --llm-config: identifier handed to ``get_base_agent`` (required).
        --relaunch: load the most recent study instead of creating a new one.
        --n-jobs: forwarded to ``Study.run`` as ``n_jobs`` (default 5).
        --n-relaunch: forwarded to ``Study.run`` as ``n_relaunch`` (default 3).
        --parallel-backend: forwarded to ``Study.run`` (default ``"ray"``).
        --reproducibility-mode: forwarded as ``strict_reproducibility``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", required=True)
    parser.add_argument("--llm-config", required=True)
    parser.add_argument("--relaunch", action="store_true")
    parser.add_argument("--n-jobs", type=int, default=5)
    parser.add_argument("--n-relaunch", type=int, default=3)
    parser.add_argument("--parallel-backend", type=str, default="ray")
    parser.add_argument("--reproducibility-mode", action="store_true")

    args = parser.parse_args()

    # instantiate agent
    agent_args = [get_base_agent(args.llm_config)]
    benchmark = DEFAULT_BENCHMARKS[args.benchmark]()

    ##################### Shuffle env args list, pick subset
    import numpy as np

    # Fixed seed so the same 33-task subset is selected on every run.
    rng = np.random.default_rng(42)
    rng.shuffle(benchmark.env_args_list)
    benchmark.env_args_list = benchmark.env_args_list[:33]
    #####################

    # for env_args in benchmark.env_args_list:
    #     env_args.max_steps = 100

    if args.relaunch:
        # relaunch an existing study
        # NOTE: the agent_args/benchmark built above are not passed to the
        # loaded study; only the study found on disk is used.
        study = Study.load_most_recent(contains=None)
        study.find_incomplete(include_errors=True)

    else:
        study = Study(
            agent_args,
            benchmark,
            logging_level=logging.WARNING,
            logging_level_stdout=logging.WARNING,
        )

    study.run(
        n_jobs=args.n_jobs,
        # Honor --parallel-backend; it was previously parsed but ignored
        # because "ray" was hard-coded here.
        parallel_backend=args.parallel_backend,
        strict_reproducibility=args.reproducibility_mode,
        n_relaunch=args.n_relaunch,
    )
62+
63+
64+
if __name__ == "__main__":
65+
main()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/bin/bash
# Launch the generic-agent experiment script with the settings below.

BENCHMARK="workarena_l1"

LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
# PARALLEL_BACKEND="sequential"
PARALLEL_BACKEND="ray"

N_JOBS=5
N_RELAUNCH=3

# Expansions are quoted so values containing spaces or glob characters
# reach the Python script as single, unmodified arguments.
python experiments/generic/run_generic_agent.py \
    --benchmark "$BENCHMARK" \
    --llm-config "$LLM_CONFIG" \
    --parallel-backend "$PARALLEL_BACKEND" \
    --n-jobs "$N_JOBS" \
    --n-relaunch "$N_RELAUNCH"
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
2+
from dotenv import load_dotenv
3+
import argparse
4+
5+
load_dotenv()
6+
7+
import logging
8+
import argparse
9+
10+
from agentlab.agents.generic_agent_hinter.generic_agent import GenericAgentArgs
11+
from agentlab.agents.generic_agent_hinter.agent_configs import CHAT_MODEL_ARGS_DICT, FLAGS_GPT_4o
12+
from bgym import DEFAULT_BENCHMARKS
13+
from agentlab.experiments.study import Study
14+
15+
logging.getLogger().setLevel(logging.WARNING)
16+
17+
18+
def main():
    """Launch (or relaunch) a study of the hint-augmented generic agent.

    Command-line options:
        --benchmark: key into ``DEFAULT_BENCHMARKS`` (required).
        --llm-config: key into ``CHAT_MODEL_ARGS_DICT`` (required).
        --relaunch: load the most recent study instead of creating a new one.
        --n-jobs: forwarded to ``Study.run`` as ``n_jobs`` (default 6).
        --n-relaunch: forwarded to ``Study.run`` as ``n_relaunch`` (default 3).
        --parallel-backend: forwarded to ``Study.run`` (default ``"ray"``).
        --reproducibility-mode: forwarded as ``strict_reproducibility``.
        --hint-*: copied onto the agent flags attributes of the same name.
        --debug: keep only a fixed 6-task subset of the benchmark.
    """
    import copy

    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", required=True)
    parser.add_argument("--llm-config", required=True)
    parser.add_argument("--relaunch", action="store_true")
    parser.add_argument("--n-jobs", type=int, default=6)
    # Default of 3 matches the value that used to be hard-coded in study.run.
    parser.add_argument("--n-relaunch", type=int, default=3)
    parser.add_argument("--parallel-backend", type=str, default="ray")
    parser.add_argument("--reproducibility-mode", action="store_true")
    # hint flags
    parser.add_argument("--hint-type", type=str, default="docs")
    parser.add_argument("--hint-index-type", type=str, default="sparse")
    parser.add_argument("--hint-query-type", type=str, default="direct")
    parser.add_argument("--hint-index-path", type=str, default="indexes/servicenow-docs-bm25")
    parser.add_argument("--hint-retriever-path", type=str, default="google/embeddinggemma-300m")
    parser.add_argument("--hint-num-results", type=int, default=5)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    # Work on a private copy so the module-level FLAGS_GPT_4o object shared
    # by other importers is not mutated by this script.
    flags = copy.deepcopy(FLAGS_GPT_4o)
    flags.use_task_hint = True
    flags.hint_type = args.hint_type
    flags.hint_index_type = args.hint_index_type
    flags.hint_query_type = args.hint_query_type
    flags.hint_index_path = args.hint_index_path
    flags.hint_retriever_path = args.hint_retriever_path
    flags.hint_num_results = args.hint_num_results

    # instantiate agent
    agent_args = [
        GenericAgentArgs(
            chat_model_args=CHAT_MODEL_ARGS_DICT[args.llm_config],
            flags=flags,
        )
    ]

    benchmark = DEFAULT_BENCHMARKS[args.benchmark]()

    if args.debug:
        # shuffle env_args_list with a fixed seed and keep the first 6 tasks
        import numpy as np

        rng = np.random.default_rng(42)
        rng.shuffle(benchmark.env_args_list)
        benchmark.env_args_list = benchmark.env_args_list[:6]

    if args.relaunch:
        # relaunch an existing study
        # NOTE: the agent_args/benchmark built above are not passed to the
        # loaded study; only the study found on disk is used.
        study = Study.load_most_recent(contains=None)
        study.find_incomplete(include_errors=True)

    else:
        study = Study(
            agent_args,
            benchmark,
            logging_level=logging.WARNING,
            logging_level_stdout=logging.WARNING,
        )

    study.run(
        n_jobs=args.n_jobs,
        parallel_backend=args.parallel_backend,
        strict_reproducibility=args.reproducibility_mode,
        n_relaunch=args.n_relaunch,
    )
80+
81+
82+
83+
if __name__ == "__main__":
84+
main()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/bin/bash
# Launch the hinter-agent experiment script with the settings below.

BENCHMARK="workarena_l1"

LLM_CONFIG="azure/gpt-5-mini-2025-08-07"
# PARALLEL_BACKEND="sequential"
PARALLEL_BACKEND="ray"

HINT_TYPE="docs" # human, llm, docs
HINT_INDEX_TYPE="sparse" # sparse, dense
# NOTE(review): run_hinter_agent.py defaults --hint-query-type to "direct",
# which is not listed here - confirm the set of accepted values.
HINT_QUERY_TYPE="goal" # goal, llm
HINT_NUM_RESULTS=3

HINT_INDEX_PATH="indexes/servicenow-docs-bm25"
# HINT_INDEX_PATH="indexes/servicenow-docs-embeddinggemma-300m"
HINT_RETRIEVER_PATH="google/embeddinggemma-300m"

N_JOBS=6

# Expansions are quoted so values containing spaces or glob characters
# reach the Python script as single, unmodified arguments.
python experiments/hinter/run_hinter_agent.py \
    --benchmark "$BENCHMARK" \
    --llm-config "$LLM_CONFIG" \
    --parallel-backend "$PARALLEL_BACKEND" \
    --n-jobs "$N_JOBS" \
    --hint-type "$HINT_TYPE" \
    --hint-index-type "$HINT_INDEX_TYPE" \
    --hint-query-type "$HINT_QUERY_TYPE" \
    --hint-index-path "$HINT_INDEX_PATH" \
    --hint-retriever-path "$HINT_RETRIEVER_PATH" \
    --hint-num-results "$HINT_NUM_RESULTS"

0 commit comments

Comments
 (0)