Skip to content

Commit 1b89a43

Browse files
authored
(wip) Harrison/serialize agents (#725)
1 parent cc70565 commit 1b89a43

File tree

9 files changed

+363
-48
lines changed

9 files changed

+363
-48
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "bfe18e28",
6+
"metadata": {},
7+
"source": [
8+
"# Serialization\n",
9+
"\n",
10+
"This notebook goes over how to serialize agents. For this notebook, it is important to understand the distinction we draw between `agents` and `tools`. An agent is the LLM powered decision maker that decides which actions to take and in which order. Tools are various instruments (functions) an agent has access to, through which an agent can interact with the outside world. When people generally use agents, they primarily talk about using an agent WITH tools. However, when we talk about serialization of agents, we are talking about the agent by itself. We plan to add support for serializing an agent WITH tools sometime in the future.\n",
11+
"\n",
12+
"Let's start by creating an agent with tools as we normally do:"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": 1,
18+
"id": "eb729f16",
19+
"metadata": {},
20+
"outputs": [],
21+
"source": [
22+
"from langchain.agents import load_tools\n",
23+
"from langchain.agents import initialize_agent\n",
24+
"from langchain.llms import OpenAI\n",
25+
"\n",
26+
"llm = OpenAI(temperature=0)\n",
27+
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
28+
"agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
29+
]
30+
},
31+
{
32+
"cell_type": "markdown",
33+
"id": "0578f566",
34+
"metadata": {},
35+
"source": [
36+
"Let's now serialize the agent. To be explicit that we are serializing ONLY the agent, we will call the `save_agent` method."
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": 2,
42+
"id": "dc544de6",
43+
"metadata": {},
44+
"outputs": [],
45+
"source": [
46+
"agent.save_agent('agent.json')"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": 3,
52+
"id": "62dd45bf",
53+
"metadata": {},
54+
"outputs": [
55+
{
56+
"name": "stdout",
57+
"output_type": "stream",
58+
"text": [
59+
"{\r\n",
60+
" \"llm_chain\": {\r\n",
61+
" \"memory\": null,\r\n",
62+
" \"verbose\": false,\r\n",
63+
" \"prompt\": {\r\n",
64+
" \"input_variables\": [\r\n",
65+
" \"input\",\r\n",
66+
" \"agent_scratchpad\"\r\n",
67+
" ],\r\n",
68+
" \"output_parser\": null,\r\n",
69+
" \"template\": \"Answer the following questions as best you can. You have access to the following tools:\\n\\nSearch: A search engine. Useful for when you need to answer questions about current events. Input should be a search query.\\nCalculator: Useful for when you need to answer questions about math.\\n\\nUse the following format:\\n\\nQuestion: the input question you must answer\\nThought: you should always think about what to do\\nAction: the action to take, should be one of [Search, Calculator]\\nAction Input: the input to the action\\nObservation: the result of the action\\n... (this Thought/Action/Action Input/Observation can repeat N times)\\nThought: I now know the final answer\\nFinal Answer: the final answer to the original input question\\n\\nBegin!\\n\\nQuestion: {input}\\nThought:{agent_scratchpad}\",\r\n",
70+
" \"template_format\": \"f-string\"\r\n",
71+
" },\r\n",
72+
" \"llm\": {\r\n",
73+
" \"model_name\": \"text-davinci-003\",\r\n",
74+
" \"temperature\": 0.0,\r\n",
75+
" \"max_tokens\": 256,\r\n",
76+
" \"top_p\": 1,\r\n",
77+
" \"frequency_penalty\": 0,\r\n",
78+
" \"presence_penalty\": 0,\r\n",
79+
" \"n\": 1,\r\n",
80+
" \"best_of\": 1,\r\n",
81+
" \"request_timeout\": null,\r\n",
82+
" \"logit_bias\": {},\r\n",
83+
" \"_type\": \"openai\"\r\n",
84+
" },\r\n",
85+
" \"output_key\": \"text\",\r\n",
86+
" \"_type\": \"llm_chain\"\r\n",
87+
" },\r\n",
88+
" \"return_values\": [\r\n",
89+
" \"output\"\r\n",
90+
" ],\r\n",
91+
" \"_type\": \"zero-shot-react-description\"\r\n",
92+
"}"
93+
]
94+
}
95+
],
96+
"source": [
97+
"!cat agent.json"
98+
]
99+
},
100+
{
101+
"cell_type": "markdown",
102+
"id": "0eb72510",
103+
"metadata": {},
104+
"source": [
105+
"We can now load the agent back in"
106+
]
107+
},
108+
{
109+
"cell_type": "code",
110+
"execution_count": 6,
111+
"id": "eb660b76",
112+
"metadata": {},
113+
"outputs": [],
114+
"source": [
115+
"agent = initialize_agent(tools, llm, agent_path=\"agent.json\", verbose=True)"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"id": "aa624ea5",
122+
"metadata": {},
123+
"outputs": [],
124+
"source": []
125+
}
126+
],
127+
"metadata": {
128+
"kernelspec": {
129+
"display_name": "Python 3 (ipykernel)",
130+
"language": "python",
131+
"name": "python3"
132+
},
133+
"language_info": {
134+
"codemirror_mode": {
135+
"name": "ipython",
136+
"version": 3
137+
},
138+
"file_extension": ".py",
139+
"mimetype": "text/x-python",
140+
"name": "python",
141+
"nbconvert_exporter": "python",
142+
"pygments_lexer": "ipython3",
143+
"version": "3.10.9"
144+
}
145+
},
146+
"nbformat": 4,
147+
"nbformat_minor": 5
148+
}

langchain/agents/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Interface for agents."""
22
from langchain.agents.agent import Agent, AgentExecutor
33
from langchain.agents.conversational.base import ConversationalAgent
4+
from langchain.agents.initialize import initialize_agent
45
from langchain.agents.load_tools import get_all_tool_names, load_tools
5-
from langchain.agents.loading import initialize_agent
6+
from langchain.agents.loading import load_agent
67
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
78
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
89
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
@@ -21,4 +22,5 @@
2122
"load_tools",
2223
"get_all_tool_names",
2324
"ConversationalAgent",
25+
"load_agent",
2426
]

langchain/agents/agent.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Chain that takes in an input and produces an action and action input."""
22
from __future__ import annotations
33

4+
import json
45
import logging
56
from abc import abstractmethod
7+
from pathlib import Path
68
from typing import Any, Dict, List, Optional, Tuple, Union
79

10+
import yaml
811
from pydantic import BaseModel, root_validator
912

1013
from langchain.agents.tools import Tool
@@ -30,6 +33,7 @@ class Agent(BaseModel):
3033
"""
3134

3235
llm_chain: LLMChain
36+
allowed_tools: List[str]
3337
return_values: List[str] = ["output"]
3438

3539
@abstractmethod
@@ -146,7 +150,8 @@ def from_llm_and_tools(
146150
prompt=cls.create_prompt(tools),
147151
callback_manager=callback_manager,
148152
)
149-
return cls(llm_chain=llm_chain, **kwargs)
153+
tool_names = [tool.name for tool in tools]
154+
return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
150155

151156
def return_stopped_response(
152157
self,
@@ -192,6 +197,50 @@ def return_stopped_response(
192197
f"got {early_stopping_method}"
193198
)
194199

200+
@property
201+
@abstractmethod
202+
def _agent_type(self) -> str:
203+
"""Return Identifier of agent type."""
204+
205+
def dict(self, **kwargs: Any) -> Dict:
206+
"""Return dictionary representation of agent."""
207+
_dict = super().dict()
208+
_dict["_type"] = self._agent_type
209+
return _dict
210+
211+
def save(self, file_path: Union[Path, str]) -> None:
212+
"""Save the agent.
213+
214+
Args:
215+
file_path: Path to file to save the agent to.
216+
217+
Example:
218+
.. code-block:: python
219+
220+
# If working with agent executor
221+
agent.agent.save(file_path="path/agent.yaml")
222+
"""
223+
# Convert file to Path object.
224+
if isinstance(file_path, str):
225+
save_path = Path(file_path)
226+
else:
227+
save_path = file_path
228+
229+
directory_path = save_path.parent
230+
directory_path.mkdir(parents=True, exist_ok=True)
231+
232+
# Fetch dictionary to save
233+
agent_dict = self.dict()
234+
235+
if save_path.suffix == ".json":
236+
with open(file_path, "w") as f:
237+
json.dump(agent_dict, f, indent=4)
238+
elif save_path.suffix == ".yaml":
239+
with open(file_path, "w") as f:
240+
yaml.dump(agent_dict, f, default_flow_style=False)
241+
else:
242+
raise ValueError(f"{save_path} must be json or yaml")
243+
195244

196245
class AgentExecutor(Chain, BaseModel):
197246
"""Consists of an agent using tools."""
@@ -215,6 +264,30 @@ def from_agent_and_tools(
215264
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
216265
)
217266

267+
@root_validator()
268+
def validate_tools(cls, values: Dict) -> Dict:
269+
"""Validate that tools are compatible with agent."""
270+
agent = values["agent"]
271+
tools = values["tools"]
272+
if set(agent.allowed_tools) != set([tool.name for tool in tools]):
273+
raise ValueError(
274+
f"Allowed tools ({agent.allowed_tools}) different than "
275+
f"provided tools ({[tool.name for tool in tools]})"
276+
)
277+
return values
278+
279+
def save(self, file_path: Union[Path, str]) -> None:
280+
"""Raise error - saving not supported for Agent Executors."""
281+
raise ValueError(
282+
"Saving not supported for agent executors. "
283+
"If you are trying to save the agent, please use the "
284+
"`.save_agent(...)`"
285+
)
286+
287+
def save_agent(self, file_path: Union[Path, str]) -> None:
288+
"""Save the underlying agent."""
289+
return self.agent.save(file_path)
290+
218291
@property
219292
def input_keys(self) -> List[str]:
220293
"""Return the input keys.

langchain/agents/conversational/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ class ConversationalAgent(Agent):
1818

1919
ai_prefix: str = "AI"
2020

21+
@property
22+
def _agent_type(self) -> str:
23+
"""Return Identifier of agent type."""
24+
return "conversational-react-description"
25+
2126
@property
2227
def observation_prefix(self) -> str:
2328
"""Prefix to append the observation with."""
@@ -100,4 +105,7 @@ def from_llm_and_tools(
100105
prompt=prompt,
101106
callback_manager=callback_manager,
102107
)
103-
return cls(llm_chain=llm_chain, ai_prefix=ai_prefix, **kwargs)
108+
tool_names = [tool.name for tool in tools]
109+
return cls(
110+
llm_chain=llm_chain, allowed_tools=tool_names, ai_prefix=ai_prefix, **kwargs
111+
)

langchain/agents/initialize.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Load agent."""
2+
from typing import Any, List, Optional
3+
4+
from langchain.agents.agent import AgentExecutor
5+
from langchain.agents.loading import AGENT_TO_CLASS, load_agent
6+
from langchain.agents.tools import Tool
7+
from langchain.callbacks.base import BaseCallbackManager
8+
from langchain.llms.base import BaseLLM
9+
10+
11+
def initialize_agent(
12+
tools: List[Tool],
13+
llm: BaseLLM,
14+
agent: Optional[str] = None,
15+
callback_manager: Optional[BaseCallbackManager] = None,
16+
agent_path: Optional[str] = None,
17+
**kwargs: Any,
18+
) -> AgentExecutor:
19+
"""Load agent given tools and LLM.
20+
21+
Args:
22+
tools: List of tools this agent has access to.
23+
llm: Language model to use as the agent.
24+
agent: The agent to use. Valid options are:
25+
`zero-shot-react-description`
26+
`react-docstore`
27+
`self-ask-with-search`
28+
`conversational-react-description`
29+
If None and agent_path is also None, will default to
30+
`zero-shot-react-description`.
31+
callback_manager: CallbackManager to use. Global callback manager is used if
32+
not provided. Defaults to None.
33+
agent_path: Path to serialized agent to use.
34+
**kwargs: Additional key word arguments to pass to the agent.
35+
36+
Returns:
37+
An agent.
38+
"""
39+
if agent is None and agent_path is None:
40+
agent = "zero-shot-react-description"
41+
if agent is not None and agent_path is not None:
42+
raise ValueError(
43+
"Both `agent` and `agent_path` are specified, "
44+
"but at most only one should be."
45+
)
46+
if agent is not None:
47+
if agent not in AGENT_TO_CLASS:
48+
raise ValueError(
49+
f"Got unknown agent type: {agent}. "
50+
f"Valid types are: {AGENT_TO_CLASS.keys()}."
51+
)
52+
agent_cls = AGENT_TO_CLASS[agent]
53+
agent_obj = agent_cls.from_llm_and_tools(
54+
llm, tools, callback_manager=callback_manager
55+
)
56+
elif agent_path is not None:
57+
agent_obj = load_agent(agent_path, callback_manager=callback_manager)
58+
else:
59+
raise ValueError(
60+
"Somehow both `agent` and `agent_path` are None, "
61+
"this should never happen."
62+
)
63+
return AgentExecutor.from_agent_and_tools(
64+
agent=agent_obj,
65+
tools=tools,
66+
callback_manager=callback_manager,
67+
**kwargs,
68+
)

0 commit comments

Comments
 (0)