Skip to content

Commit 3272e2b

Browse files
authored
refactor: Add AgentStep (#4431)
1 parent 4d19bd1 commit 3272e2b

File tree

4 files changed

+304
-114
lines changed

4 files changed

+304
-114
lines changed

haystack/agents/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
from haystack.agents.agent_step import AgentStep
12
from haystack.agents.base import Agent
23
from haystack.agents.base import Tool

haystack/agents/agent_step.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import re
5+
from typing import Optional, Dict, Tuple, Any
6+
7+
from haystack import Answer
8+
from haystack.errors import AgentError
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class AgentStep:
    """
    One step in an Agent's iterative execution.

    An AgentStep carries the step counter, the PromptNode response produced at
    this step, and the running transcript used to build the next prompt. It
    knows how to parse tool invocations and the final answer out of the
    response text via regular expressions.
    """

    def __init__(
        self,
        current_step: int = 1,
        max_steps: int = 10,
        final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
        prompt_node_response: str = "",
        transcript: str = "",
    ):
        """
        :param current_step: The current step in the execution of the agent.
        :param max_steps: The maximum number of steps the agent can execute.
        :param final_answer_pattern: Regex used to pull the final answer out of the PromptNode response.
        :param prompt_node_response: The PromptNode response received at this step.
        :param transcript: The full execution transcript so far (initial prompt template plus all text
            generated up to this step); it is used to build the next prompt.
        """
        self.current_step = current_step
        self.max_steps = max_steps
        self.final_answer_pattern = final_answer_pattern
        self.prompt_node_response = prompt_node_response
        self.transcript = transcript

    def prepare_prompt(self) -> str:
        """
        Return the prompt for the next step (the accumulated transcript).
        """
        return self.transcript

    def create_next_step(self, prompt_node_response: Any) -> "AgentStep":
        """
        Build the successor step from the PromptNode response received for this step.

        :param prompt_node_response: The PromptNode response; must be a non-empty list of str.
        :raises AgentError: If the response is not a list or is empty.
        """
        # Validate the response shape up front with guard clauses.
        if not isinstance(prompt_node_response, list):
            raise AgentError(
                f"Agent output must be a list of str, but {prompt_node_response} received. "
                f"Transcript:\n{self.transcript}"
            )
        if not prompt_node_response:
            raise AgentError(
                f"Agent output must be a non empty list of str, but {prompt_node_response} received. "
                f"Transcript:\n{self.transcript}"
            )
        return AgentStep(
            current_step=self.current_step + 1,
            max_steps=self.max_steps,
            final_answer_pattern=self.final_answer_pattern,
            prompt_node_response=prompt_node_response[0],
            transcript=self.transcript,
        )

    def extract_tool_name_and_tool_input(self, tool_pattern: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Parse the tool name and tool input out of the PromptNode response.

        :param tool_pattern: Regex whose group 1 is the tool name and group 3 is the tool input.
        :return: (tool_name, tool_input), or (None, None) when the pattern does not match.
        """
        match = re.search(tool_pattern, self.prompt_node_response)
        if match is None:
            return None, None
        # Trim quoting, brackets and whitespace the model may wrap the values in.
        return match.group(1).strip('" []\n').strip(), match.group(3).strip('" \n')

    def final_answer(self, query: str) -> Dict[str, Any]:
        """
        Format the final result as a Pipeline-style dict with `query`, `answers`
        and the full execution `transcript`.

        If the step budget was exhausted, or the final answer pattern is absent
        from the response, a warning is logged and the answer text is empty.

        :param query: The search query
        """
        answer_text = ""
        if self.current_step >= self.max_steps:
            logger.warning(
                "Maximum number of iterations (%s) reached for query (%s). Increase max_steps "
                "or no answer can be provided for this query.",
                self.max_steps,
                query,
            )
        else:
            extracted = self.extract_final_answer()
            if not extracted:
                logger.warning(
                    "Final answer pattern (%s) not found in PromptNode response (%s).",
                    self.final_answer_pattern,
                    self.prompt_node_response,
                )
            else:
                answer_text = extracted
        return {
            "query": query,
            "answers": [Answer(answer=answer_text, type="generative")],
            "transcript": self.transcript,
        }

    def extract_final_answer(self) -> Optional[str]:
        """
        Parse the final answer from the PromptNode response.

        :return: The final answer, or None when the pattern does not match.
        :raises AgentError: If this step is not terminal.
        """
        if not self.is_last():
            raise AgentError("Cannot extract final answer from non terminal step.")

        match = re.search(self.final_answer_pattern, self.prompt_node_response)
        return match.group(1).strip('" ') if match else None

    def is_final_answer_pattern_found(self) -> bool:
        """
        Return True if the final answer pattern occurs in the PromptNode response.
        """
        return re.search(self.final_answer_pattern, self.prompt_node_response) is not None

    def is_last(self) -> bool:
        """
        Return True if this is the Agent's last step: either the final answer
        pattern was found or the step budget is exhausted.
        """
        return self.is_final_answer_pattern_found() or self.current_step >= self.max_steps

    def completed(self, observation: Optional[str]):
        """
        Append this step's response (and the observation, when present) to the transcript.

        :param observation: received observation from the Agent environment.
        """
        if observation:
            self.transcript += f"{self.prompt_node_response}\nObservation: {observation}\nThought:"
        else:
            self.transcript += self.prompt_node_response

0 commit comments

Comments
 (0)