77 ERROR_CLASSIFICATION_PROMPT ,
88)
99from agentlab .analyze .inspect_results import summarize
10+ from agentlab .llm .llm_utils import json_parser
1011
1112
1213def _diff (past_obs , current_obs ):
@@ -21,7 +22,7 @@ def _diff(past_obs, current_obs):
2122class ChangeSummarizer :
2223
2324 llm : callable # language model
24- obs_formatter : callable = lambda x : x .get ("axtree_txt " , "No AXTREE available" )
25+ obs_formatter : callable = lambda x : x .get ("dom_txt " , "No AXTREE available" )
2526 use_diff : bool = False
2627
2728 def summarize (self , obs : StepInfo , next_obs : StepInfo , past_summaries : list [str ]) -> str :
@@ -74,20 +75,35 @@ class EpisodeAnalysis:
class EpisodeSummarizer:
    """Summarize an episode step by step, then ask an LLM to analyse it.

    Subclasses provide ``make_prompt``; here it is only a stub.
    """

    # Object whose ``summarize(step, next_step, past_summaries)`` describes one transition.
    change_summarizer: ChangeSummarizer = None
    # Callable taking a prompt and returning a mapping with a "content" entry.
    llm: callable = None
    # Turns raw LLM text into a parsed value; keeps the first item returned by json_parser.
    parser: callable = lambda x: json_parser(x)[0]

    def make_prompt(self, exp_results: ExpResult, summaries: list[str]): ...

    def __call__(self, exp_results: ExpResult) -> dict:
        """Analyse one episode and return the analysis with its step summaries.

        Args:
            exp_results: Episode to analyse; only ``steps_info`` is read here.

        Returns:
            A dict with two keys. ``"analysis"`` is the string ``"Success"``
            when the last step has ``reward == 1``, otherwise the parsed LLM
            answer. ``"summaries"`` maps each step index to its parsed change
            summary (empty on success).
        """
        # A successful episode needs no error analysis, so no LLM call is made.
        if exp_results.steps_info[-1].reward == 1:
            return {"analysis": "Success", "summaries": {}}

        summaries = self.make_change_summaries(exp_results)
        prompt = self.make_prompt(exp_results, summaries)
        raw_analysis = self.llm(prompt)["content"]
        return {
            "analysis": self.parser(raw_analysis),
            "summaries": {
                index: self.parser(summary) for index, summary in enumerate(summaries)
            },
        }

    def make_change_summaries(self, exp_result: ExpResult) -> list[str]:
        """Summarize every transition between consecutive steps of the episode.

        Args:
            exp_result: Episode whose ``steps_info`` is walked pairwise.

        Returns:
            One summary text per transition, in step order.
        """
        summaries: list[str] = []
        # This assumes there is always an extra step at the end of the episode.
        # That is generally the case, but experiments can fail in a weird way
        # and not save the last step_info.
        # TODO(thibault): add checks for a missing final step.
        steps = exp_result.steps_info
        for step, next_step in zip(steps[:-1], steps[1:]):
            # The running list is passed on purpose: each summary can build on
            # the ones produced for earlier transitions.
            response = self.change_summarizer.summarize(step, next_step, summaries)
            summaries.append(response["content"])
        return summaries
92108
93109
@@ -96,12 +112,26 @@ class EpisodeErrorSummarizer(EpisodeSummarizer):
96112
97113 change_summarizer : ChangeSummarizer = None
98114
def make_prompt(self, exp_results: ExpResult, summaries: list[str]):
    """Build the error-classification prompt for one episode.

    The goal comes from the first step's observation. Every step except the
    last contributes its thought (``agent_info.think``) and its action; the
    error shown with it is the ``last_action_error`` of the following step's
    observation.

    Args:
        exp_results: Episode whose ``steps_info`` supplies goal, thoughts,
            actions and action errors.
        summaries: Change summaries, joined with newlines into the prompt.

    Returns:
        ``ERROR_CLASSIFICATION_PROMPT`` formatted with ``goal``,
        ``historical_summaries`` and ``action_history``.
    """
    steps = exp_results.steps_info
    goal = steps[0].obs["goal"]

    txt_summaries = "\n".join(summaries)

    thoughts = [step.agent_info.think for step in steps[:-1]]
    actions = [step.action for step in steps[:-1]]
    # Keep the errors as a list. Joining them into one string before zipping
    # would pair each action with a single character of that string.
    action_errors = [step.obs["last_action_error"] for step in steps[1:]]

    txt_actions = "\n".join(
        f"Thoughts: {thought}\nAction: {action}\nAction Error: {action_error}"
        for thought, action, action_error in zip(thoughts, actions, action_errors)
    )
    return ERROR_CLASSIFICATION_PROMPT.format(
        goal=goal,
        historical_summaries=txt_summaries,
        action_history=txt_actions,
    )
0 commit comments