Commit 5a1203b

Introduce confidence to autochain (#166)
1 parent 6c756ee commit 5a1203b

9 files changed: +219 -39 lines

autochain/agent/base_agent.py

Lines changed: 10 additions & 2 deletions
@@ -4,13 +4,12 @@
 from string import Template
 from typing import Any, List, Optional, Sequence, Union
 
-from pydantic import BaseModel, Extra
-
 from autochain.agent.message import ChatMessageHistory
 from autochain.agent.prompt_formatter import JSONPromptTemplate
 from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser
 from autochain.models.base import BaseLanguageModel
 from autochain.tools.base import Tool
+from pydantic import BaseModel
 
 
 class BaseAgent(BaseModel, ABC):
@@ -104,3 +103,12 @@ def get_prompt_template(
         if input_variables is None:
             input_variables = ["input", "agent_scratchpad"]
         return JSONPromptTemplate(template=template, input_variables=input_variables)
+
+    def is_generation_confident(
+        self,
+        history: ChatMessageHistory,
+        agent_output: Union[AgentAction, AgentFinish],
+        min_confidence: int = 3,
+    ) -> bool:
+        """Check if the generation is confident enough to take action"""
+        return True

autochain/agent/message.py

Lines changed: 13 additions & 2 deletions
@@ -1,6 +1,6 @@
 import enum
 from abc import abstractmethod
-from typing import List, Any, Dict
+from typing import Any, Dict, List
 
 from pydantic import BaseModel, Field
 
@@ -60,6 +60,7 @@ class FunctionMessage(BaseMessage):
     """Type of message that is a function message."""
 
     name: str
+    conversational_message: str = ""
 
     @property
     def type(self) -> str:
@@ -76,14 +77,24 @@ def save_message(self, message: str, message_type: MessageType, **kwargs):
         elif message_type == MessageType.UserMessage:
             self.messages.append(UserMessage(content=message))
         elif message_type == MessageType.FunctionMessage:
-            self.messages.append(FunctionMessage(content=message, name=kwargs["name"]))
+            self.messages.append(
+                FunctionMessage(
+                    content=message,
+                    name=kwargs["name"],
+                    conversational_message=kwargs["conversational_message"],
+                )
+            )
         elif message_type == MessageType.SystemMessage:
             self.messages.append(SystemMessage(content=message))
 
     def format_message(self):
         string_messages = []
         if len(self.messages) > 0:
             for m in self.messages:
+                if isinstance(m, FunctionMessage):
+                    string_messages.append(f"Action: {m.conversational_message}")
+                    continue
+
                 if isinstance(m, UserMessage):
                     role = "User"
                 elif isinstance(m, AIMessage):
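
A minimal sketch (not part of the commit) of how the new conversational_message field surfaces in the formatted history: FunctionMessage entries are now rendered as "Action: ..." lines, while other message types keep their usual role-prefixed form. The tool name and input below are invented for illustration.

from autochain.agent.message import ChatMessageHistory, MessageType

history = ChatMessageHistory()
history.save_message("What's the weather in Paris?", MessageType.UserMessage)
history.save_message(
    "Sunny, 22C",
    MessageType.FunctionMessage,
    name="get_weather",  # hypothetical tool name
    conversational_message="get_weather with input: {'location': 'Paris'}",
)
# format_message() should now include a line like:
#   Action: get_weather with input: {'location': 'Paris'}
print(history.format_message())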

autochain/agent/openai_functions_agent/openai_functions_agent.py

Lines changed: 86 additions & 24 deletions
@@ -1,19 +1,20 @@
 from __future__ import annotations
 
 import logging
+from string import Template
 from typing import Any, Dict, List, Optional, Union
 
-from colorama import Fore
-
 from autochain.agent.base_agent import BaseAgent
-from autochain.agent.message import ChatMessageHistory, SystemMessage
+from autochain.agent.message import ChatMessageHistory, SystemMessage, UserMessage
 from autochain.agent.openai_functions_agent.output_parser import (
     OpenAIFunctionOutputParser,
 )
+from autochain.agent.openai_functions_agent.prompt import ESTIMATE_CONFIDENCE_PROMPT
 from autochain.agent.structs import AgentAction, AgentFinish
 from autochain.models.base import BaseLanguageModel, Generation
 from autochain.tools.base import Tool
 from autochain.utils import print_with_color
+from colorama import Fore
 
 logger = logging.getLogger(__name__)
 
@@ -30,6 +31,7 @@ class OpenAIFunctionsAgent(BaseAgent):
     allowed_tools: Dict[str, Tool] = {}
     tools: List[Tool] = []
     prompt: Optional[str] = None
+    min_confidence: int = 3
 
     @classmethod
     def from_llm_and_tools(
@@ -38,6 +40,7 @@ def from_llm_and_tools(
         tools: Optional[List[Tool]] = None,
         output_parser: Optional[OpenAIFunctionOutputParser] = None,
         prompt: str = None,
+        min_confidence: int = 3,
         **kwargs: Any,
     ) -> OpenAIFunctionsAgent:
         tools = tools or []
@@ -50,39 +53,98 @@
             output_parser=_output_parser,
             tools=tools,
             prompt=prompt,
+            min_confidence=min_confidence,
             **kwargs,
         )
 
     def plan(
         self,
         history: ChatMessageHistory,
         intermediate_steps: List[AgentAction],
+        retries: int = 2,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
-        print_with_color("Planning", Fore.LIGHTYELLOW_EX)
+        while retries > 0:
+            print_with_color("Planning", Fore.LIGHTYELLOW_EX)
 
-        final_messages = []
-        if self.prompt:
-            final_messages.append(SystemMessage(content=self.prompt))
-        final_messages += history.messages
+            final_messages = []
+            if self.prompt:
+                final_messages.append(SystemMessage(content=self.prompt))
+            final_messages += history.messages
 
-        logger.info(f"\nPlanning Input: {[m.content for m in final_messages]} \n")
-        full_output: Generation = self.llm.generate(
-            final_messages, self.tools
-        ).generations[0]
+            logger.info(f"\nPlanning Input: {[m.content for m in final_messages]} \n")
+            full_output: Generation = self.llm.generate(
+                final_messages, self.tools
+            ).generations[0]
 
-        agent_output: Union[AgentAction, AgentFinish] = self.output_parser.parse(
-            full_output.message
+            agent_output: Union[AgentAction, AgentFinish] = self.output_parser.parse(
+                full_output.message
+            )
+            print(
+                f"Planning output: \nmessage content: {repr(full_output.message.content)}; "
+                f"function_call: "
+                f"{repr(full_output.message.function_call)}",
+                Fore.YELLOW,
+            )
+            if isinstance(agent_output, AgentAction):
+                print_with_color(
+                    f"Plan to take action '{agent_output.tool}'", Fore.LIGHTYELLOW_EX
+                )
+
+            generation_is_confident = self.is_generation_confident(
+                history=history,
+                agent_output=agent_output,
+                min_confidence=self.min_confidence,
+            )
+            if not generation_is_confident:
+                retries -= 1
+                print_with_color(
+                    f"Generation is not confident, {retries} retries left",
+                    Fore.LIGHTYELLOW_EX,
+                )
+                continue
+            else:
+                return agent_output
+
+    def is_generation_confident(
+        self,
+        history: ChatMessageHistory,
+        agent_output: Union[AgentAction, AgentFinish],
+        min_confidence: int = 3,
+    ) -> bool:
+        """
+        Estimate the confidence of the generation
+        Args:
+            history: history of the conversation
+            agent_output: the output from the agent
+            min_confidence: minimum confidence score to be considered as confident
+        """
+
+        def _format_assistant_message(action_output: Union[AgentAction, AgentFinish]):
+            if isinstance(action_output, AgentFinish):
+                assistant_message = f"Assistant: {action_output.message}"
+            elif isinstance(action_output, AgentAction):
+                assistant_message = f"Action: {action_output.tool} with input: {action_output.tool_input}"
+            else:
+                raise ValueError("Unsupported action for estimating confidence score")
+
+            return assistant_message
+
+        prompt = Template(ESTIMATE_CONFIDENCE_PROMPT).substitute(
+            policy=self.prompt,
+            conversation_history=history.format_message(),
+            assistant_message=_format_assistant_message(agent_output),
        )
-        print(
-            f"Planning output: \nmessage content: {repr(full_output.message.content)}; "
-            f"function_call: "
-            f"{repr(full_output.message.function_call)}",
-            Fore.YELLOW,
+        logger.info(f"\nEstimate confidence prompt: {prompt} \n")
+
+        message = UserMessage(content=prompt)
+
+        full_output: Generation = self.llm.generate([message], self.tools).generations[
+            0
+        ]
+
+        estimated_confidence = self.output_parser.parse_estimated_confidence(
+            full_output.message
        )
-        if isinstance(agent_output, AgentAction):
-            print_with_color(
-                f"Plan to take action '{agent_output.tool}'", Fore.LIGHTYELLOW_EX
-            )
 
-        return agent_output
+        return estimated_confidence >= min_confidence
return estimated_confidence >= min_confidence

autochain/agent/openai_functions_agent/output_parser.py

Lines changed: 27 additions & 0 deletions
@@ -1,9 +1,13 @@
 import json
+import logging
+import re
 from typing import Union
 
 from autochain.agent.message import AIMessage
 from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser
 
+logger = logging.getLogger(__name__)
+
 
 class OpenAIFunctionOutputParser(AgentOutputParser):
     def parse(self, message: AIMessage) -> Union[AgentAction, AgentFinish]:
@@ -18,3 +22,26 @@ def parse(self, message: AIMessage) -> Union[AgentAction, AgentFinish]:
             )
         else:
             return AgentFinish(message=message.content, log=message.content)
+
+    def parse_estimated_confidence(self, message: AIMessage) -> int:
+        """Parse estimated confidence from the message"""
+
+        def find_first_integer(input_string):
+            # Define a regular expression pattern to match integers
+            pattern = re.compile(r"\d+")
+
+            # Search for the first match in the input string
+            match = pattern.search(input_string)
+
+            # Check if a match is found
+            if match:
+                # Extract and return the matched integer
+                return int(match.group())
+            else:
+                # Return 0 if no integer is found
+                logger.info(f"\nCannot find confidence in message: {input_string}\n")
+                return 0
+
+        content = message.content.strip()
+
+        return find_first_integer(content)
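
A quick sketch (not part of the commit) of what the new parser pulls out of a model reply: the first integer in the text is taken as the score, and 0 is returned when no integer is found. The message contents below are invented.

from autochain.agent.message import AIMessage
from autochain.agent.openai_functions_agent.output_parser import (
    OpenAIFunctionOutputParser,
)

parser = OpenAIFunctionOutputParser()
print(parser.parse_estimated_confidence(AIMessage(content="Confidence: 4")))  # 4
print(parser.parse_estimated_confidence(AIMessage(content="I am not sure")))  # 0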

autochain/agent/openai_functions_agent/prompt.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+ESTIMATE_CONFIDENCE_PROMPT = """Given the system policy assistant needs to strictly follow and
+the conversation history between user and assistant so far,
+"System policy: ${policy}
+${conversation_history}"
+
+How confident are you the next step from assistant should be the following:
+"${assistant_message}"
+
+Estimate the confidence from 1-5, 1 being the least confident and 5 being the most confident.
+Confidence:
+"""

autochain/agent/structs.py

Lines changed: 7 additions & 7 deletions
@@ -1,17 +1,13 @@
 import json
-import re
 from abc import abstractmethod
-from typing import Union, Any, Dict, List
+from typing import Any, Dict, List, Union
 
+from autochain.agent.message import BaseMessage, UserMessage
+from autochain.chain import constants
 from autochain.models.base import Generation
-
 from autochain.models.chat_openai import ChatOpenAI
 from pydantic import BaseModel
 
-from autochain.agent.message import BaseMessage, UserMessage
-from autochain.chain import constants
-from autochain.errors import OutputParserException
-
 
 class AgentAction(BaseModel):
     """Agent's action to take."""
@@ -89,3 +85,7 @@ def parse_clarification(
     ) -> Union[AgentAction, AgentFinish]:
         """Parse clarification outputs"""
         return agent_action
+
+    def parse_estimated_confidence(self, message: BaseMessage) -> int:
+        """Parse estimated confidence from the message"""
+        return 1

autochain/chain/base_chain.py

Lines changed: 2 additions & 0 deletions
@@ -137,6 +137,8 @@ def _run(
             self.memory.save_conversation(
                 message=str(next_step_output.tool_output),
                 name=next_step_output.tool,
+                conversational_message=f"{next_step_output.tool} with input: "
+                f"{next_step_output.tool_input}",
                 message_type=MessageType.FunctionMessage,
             )
 
autochain/chain/chain.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def handle_repeated_action(self, agent_action: AgentAction) -> AgentFinish:
         print("No response from agent. Gracefully exit due to repeated action")
         return AgentFinish(
             message=self.graceful_exit_tool.run(),
-            log=f"Gracefully exit due to repeated action",
+            log="Gracefully exit due to repeated action",
         )
 
     def take_next_step(
