11from __future__ import annotations
22
33import logging
4+ from string import Template
45from typing import Any , Dict , List , Optional , Union
56
6- from colorama import Fore
7-
87from autochain .agent .base_agent import BaseAgent
9- from autochain .agent .message import ChatMessageHistory , SystemMessage
8+ from autochain .agent .message import ChatMessageHistory , SystemMessage , UserMessage
109from autochain .agent .openai_functions_agent .output_parser import (
1110 OpenAIFunctionOutputParser ,
1211)
12+ from autochain .agent .openai_functions_agent .prompt import ESTIMATE_CONFIDENCE_PROMPT
1313from autochain .agent .structs import AgentAction , AgentFinish
1414from autochain .models .base import BaseLanguageModel , Generation
1515from autochain .tools .base import Tool
1616from autochain .utils import print_with_color
17+ from colorama import Fore
1718
1819logger = logging .getLogger (__name__ )
1920
@@ -30,6 +31,7 @@ class OpenAIFunctionsAgent(BaseAgent):
3031 allowed_tools : Dict [str , Tool ] = {}
3132 tools : List [Tool ] = []
3233 prompt : Optional [str ] = None
34+ min_confidence : int = 3
3335
3436 @classmethod
3537 def from_llm_and_tools (
@@ -38,6 +40,7 @@ def from_llm_and_tools(
3840 tools : Optional [List [Tool ]] = None ,
3941 output_parser : Optional [OpenAIFunctionOutputParser ] = None ,
4042 prompt : str = None ,
43+ min_confidence : int = 3 ,
4144 ** kwargs : Any ,
4245 ) -> OpenAIFunctionsAgent :
4346 tools = tools or []
@@ -50,39 +53,98 @@ def from_llm_and_tools(
5053 output_parser = _output_parser ,
5154 tools = tools ,
5255 prompt = prompt ,
56+ min_confidence = min_confidence ,
5357 ** kwargs ,
5458 )
5559
5660 def plan (
5761 self ,
5862 history : ChatMessageHistory ,
5963 intermediate_steps : List [AgentAction ],
64+ retries : int = 2 ,
6065 ** kwargs : Any ,
6166 ) -> Union [AgentAction , AgentFinish ]:
62- print_with_color ("Planning" , Fore .LIGHTYELLOW_EX )
67+ while retries > 0 :
68+ print_with_color ("Planning" , Fore .LIGHTYELLOW_EX )
6369
64- final_messages = []
65- if self .prompt :
66- final_messages .append (SystemMessage (content = self .prompt ))
67- final_messages += history .messages
70+ final_messages = []
71+ if self .prompt :
72+ final_messages .append (SystemMessage (content = self .prompt ))
73+ final_messages += history .messages
6874
69- logger .info (f"\n Planning Input: { [m .content for m in final_messages ]} \n " )
70- full_output : Generation = self .llm .generate (
71- final_messages , self .tools
72- ).generations [0 ]
75+ logger .info (f"\n Planning Input: { [m .content for m in final_messages ]} \n " )
76+ full_output : Generation = self .llm .generate (
77+ final_messages , self .tools
78+ ).generations [0 ]
7379
74- agent_output : Union [AgentAction , AgentFinish ] = self .output_parser .parse (
75- full_output .message
80+ agent_output : Union [AgentAction , AgentFinish ] = self .output_parser .parse (
81+ full_output .message
82+ )
83+ print (
84+ f"Planning output: \n message content: { repr (full_output .message .content )} ; "
85+ f"function_call: "
86+ f"{ repr (full_output .message .function_call )} " ,
87+ Fore .YELLOW ,
88+ )
89+ if isinstance (agent_output , AgentAction ):
90+ print_with_color (
91+ f"Plan to take action '{ agent_output .tool } '" , Fore .LIGHTYELLOW_EX
92+ )
93+
94+ generation_is_confident = self .is_generation_confident (
95+ history = history ,
96+ agent_output = agent_output ,
97+ min_confidence = self .min_confidence ,
98+ )
99+ if not generation_is_confident :
100+ retries -= 1
101+ print_with_color (
102+ f"Generation is not confident, { retries } retries left" ,
103+ Fore .LIGHTYELLOW_EX ,
104+ )
105+ continue
106+ else :
107+ return agent_output
108+
109+ def is_generation_confident (
110+ self ,
111+ history : ChatMessageHistory ,
112+ agent_output : Union [AgentAction , AgentFinish ],
113+ min_confidence : int = 3 ,
114+ ) -> bool :
115+ """
116+ Estimate the confidence of the generation
117+ Args:
118+ history: history of the conversation
119+ agent_output: the output from the agent
120+ min_confidence: minimum confidence score to be considered as confident
121+ """
122+
123+ def _format_assistant_message (action_output : Union [AgentAction , AgentFinish ]):
124+ if isinstance (action_output , AgentFinish ):
125+ assistant_message = f"Assistant: { action_output .message } "
126+ elif isinstance (action_output , AgentAction ):
127+ assistant_message = f"Action: { action_output .tool } with input: { action_output .tool_input } "
128+ else :
129+ raise ValueError ("Unsupported action for estimating confidence score" )
130+
131+ return assistant_message
132+
133+ prompt = Template (ESTIMATE_CONFIDENCE_PROMPT ).substitute (
134+ policy = self .prompt ,
135+ conversation_history = history .format_message (),
136+ assistant_message = _format_assistant_message (agent_output ),
76137 )
77- print (
78- f"Planning output: \n message content: { repr (full_output .message .content )} ; "
79- f"function_call: "
80- f"{ repr (full_output .message .function_call )} " ,
81- Fore .YELLOW ,
138+ logger .info (f"\n Estimate confidence prompt: { prompt } \n " )
139+
140+ message = UserMessage (content = prompt )
141+
142+ full_output : Generation = self .llm .generate ([message ], self .tools ).generations [
143+ 0
144+ ]
145+
146+ estimated_confidence = self .output_parser .parse_estimated_confidence (
147+ full_output .message
82148 )
83- if isinstance (agent_output , AgentAction ):
84- print_with_color (
85- f"Plan to take action '{ agent_output .tool } '" , Fore .LIGHTYELLOW_EX
86- )
87149
88- return agent_output
150+ return estimated_confidence >= min_confidence
0 commit comments