Skip to content

Commit 5fbbe57

Browse files
committed
phony command, joblib stuff, took think out of prompt
1 parent 82dbaba commit 5fbbe57

File tree

3 files changed

+45
-16
lines changed

3 files changed

+45
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,4 @@ exclude = '''
5757
[project.scripts]
5858
agentlab-assistant = "agentlab.ui_assistant:main"
5959
agentlab-xray = "agentlab.analyze.agent_xray:main"
60+
agentlab-analyze = "agentlab.analyze.error_analysis.pipeline:main"

src/agentlab/analyze/error_analysis/pipeline.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ def __call__(self, *args, **kwds):
2323
return "analysis"
2424

2525

26+
def analyze(exp_result, episode_summarizer, save_analysis_func):
27+
error_analysis = episode_summarizer(exp_result)
28+
save_analysis_func(exp_result, error_analysis)
29+
30+
2631
@dataclass
2732
class ErrorAnalysisPipeline:
2833
exp_dir: Path
@@ -36,12 +41,21 @@ def filter_exp_results(self) -> Generator[ExpResult, None, None]:
3641
if self.filter is None or self.filter in str(exp_result.exp_dir):
3742
yield exp_result
3843

39-
def run_analysis(self):
44+
def run_analysis(self, parallel=False, jobs=-1):
4045
filtered_results = self.filter_exp_results()
4146

42-
for exp_result in filtered_results:
43-
error_analysis = self.episode_summarizer(exp_result)
44-
self.save_analysis(exp_result, error_analysis)
47+
if parallel:
48+
import joblib
49+
50+
joblib.Parallel(n_jobs=jobs, backend="threading")(
51+
joblib.delayed(analyze)(exp_result, self.episode_summarizer, self.save_analysis)
52+
for exp_result in filtered_results
53+
)
54+
55+
else:
56+
for exp_result in filtered_results:
57+
error_analysis = self.episode_summarizer(exp_result)
58+
self.save_analysis(exp_result, error_analysis)
4559

4660
def save_analysis(self, exp_result: ExpResult, error_analysis: dict, exists_ok=True):
4761
"""Save the analysis to json"""
@@ -56,28 +70,37 @@ def save_analysis(self, exp_result: ExpResult, error_analysis: dict, exists_ok=T
5670
HTML_FORMATTER = lambda x: x.get("pruned_html", "No HTML available")
5771

5872

59-
if __name__ == "__main__":
73+
def main():
6074
import argparse
6175

6276
parser = argparse.ArgumentParser()
6377
parser.add_argument("-e", "--exp_dir", type=str)
6478
parser.add_argument("-f", "--filter", type=str, default=None)
79+
parser.add_argument("-p", "--parallel", action="store_true")
80+
parser.add_argument("-j", "--jobs", type=int, default=-1)
6581

6682
args = parser.parse_args()
83+
84+
assert args.exp_dir is not None, "Please provide an exp_dir, e.g., -e /path/to/exp_dir"
85+
6786
exp_dir = Path(args.exp_dir)
6887
filter = args.filter
88+
parallel = args.parallel
89+
jobs = args.jobs
6990

7091
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
7192

7293
llm = CHAT_MODEL_ARGS_DICT["azure/gpt-4o-2024-08-06"].make_model()
7394

74-
step_summarizer = ChangeSummarizer(llm, lambda x: x)
75-
episode_summarizer = EpisodeSummarizer()
76-
7795
pipeline = ErrorAnalysisPipeline(
7896
exp_dir=exp_dir,
7997
filter=filter,
8098
episode_summarizer=EpisodeErrorSummarizer(ChangeSummarizer(llm, AXTREE_FORMATTER), llm),
8199
)
82100

83-
pipeline.run_analysis()
101+
pipeline.run_analysis(parallel=parallel, jobs=jobs)
102+
103+
104+
if __name__ == "__main__":
105+
106+
main()

src/agentlab/analyze/error_analysis/summarizer.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
CHANGE_SUMMARIZER_PROMPT,
77
ERROR_CLASSIFICATION_PROMPT,
88
)
9-
from agentlab.analyze.inspect_results import summarize
109
from agentlab.llm.llm_utils import json_parser, parse_html_tags
10+
from agentlab.llm.tracking import set_tracker
1111

1212

1313
def _diff(past_obs, current_obs):
@@ -94,14 +94,20 @@ def __call__(self, exp_results: ExpResult) -> EpisodeAnalysis:
9494
# if exp_results.steps_info[-1].reward == 1:
9595
# return {"analysis": "Success", "summaries": {}}
9696

97-
summaries = self.make_change_summaries(exp_results)
97+
with set_tracker("summary") as summaries_tracker:
98+
summaries = self.make_change_summaries(exp_results)
9899
prompt = self.make_prompt(exp_results, summaries)
99-
raw_analysis = self.llm(prompt)["content"]
100+
101+
with set_tracker("analysis") as analysis_tracker:
102+
raw_analysis = self.llm(prompt)["content"]
100103
analysis = self.parse(raw_analysis)
101-
return {
104+
res = {
102105
"analysis": analysis,
103106
"summaries": {i: a for i, a in enumerate(summaries)},
104107
}
108+
res.update(analysis_tracker.stats)
109+
res.update(summaries_tracker.stats)
110+
return res
105111

106112
def make_change_summaries(self, exp_result: ExpResult) -> list[str]:
107113
summaries = [] # type: list[str]
@@ -136,16 +142,15 @@ def format_summary(summary):
136142

137143
txt_summaries = "\n".join([format_summary(summary) for summary in summaries])
138144

139-
thoughts = [step.agent_info.think for step in exp_results.steps_info[:-1]]
140145
actions = [step.action for step in exp_results.steps_info[:-1]]
141146
action_errors = "\n".join(
142147
[step.obs["last_action_error"] for step in exp_results.steps_info[1:]]
143148
)
144149

145150
txt_actions = "\n".join(
146151
[
147-
f"Thoughts: {thought}\nAction: {action}\nAction Error: {action_error}"
148-
for action, thought, action_error in zip(actions, thoughts, action_errors)
152+
f"Action: {action}\nAction Error: {action_error}"
153+
for action, action_error in zip(actions, action_errors)
149154
]
150155
)
151156
return ERROR_CLASSIFICATION_PROMPT.format(

0 commit comments

Comments
 (0)