Skip to content

Commit e2a7fed

Browse files
authored
Harrison/serialize from llm and tools (#760)
1 parent 12dc7f2 commit e2a7fed

File tree

4 files changed

+74
-4
lines changed

4 files changed

+74
-4
lines changed

langchain/agents/conversational/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,22 @@ def from_llm_and_tools(
9191
llm: BaseLLM,
9292
tools: List[Tool],
9393
callback_manager: Optional[BaseCallbackManager] = None,
94+
prefix: str = PREFIX,
95+
suffix: str = SUFFIX,
9496
ai_prefix: str = "AI",
9597
human_prefix: str = "Human",
98+
input_variables: Optional[List[str]] = None,
9699
**kwargs: Any,
97100
) -> Agent:
98101
"""Construct an agent from an LLM and tools."""
99102
cls._validate_tools(tools)
100103
prompt = cls.create_prompt(
101-
tools, ai_prefix=ai_prefix, human_prefix=human_prefix
104+
tools,
105+
ai_prefix=ai_prefix,
106+
human_prefix=human_prefix,
107+
prefix=prefix,
108+
suffix=suffix,
109+
input_variables=input_variables,
102110
)
103111
llm_chain = LLMChain(
104112
llm=llm,

langchain/agents/initialize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def initialize_agent(
5454
llm, tools, callback_manager=callback_manager
5555
)
5656
elif agent_path is not None:
57-
agent_obj = load_agent(agent_path, callback_manager=callback_manager)
57+
agent_obj = load_agent(
58+
agent_path, llm=llm, tools=tools, callback_manager=callback_manager
59+
)
5860
else:
5961
raise ValueError(
6062
"Somehow both `agent` and `agent_path` are None, "

langchain/agents/loading.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import tempfile
55
from pathlib import Path
6-
from typing import Any, Union
6+
from typing import Any, List, Optional, Union
77

88
import requests
99
import yaml
@@ -13,7 +13,9 @@
1313
from langchain.agents.mrkl.base import ZeroShotAgent
1414
from langchain.agents.react.base import ReActDocstoreAgent
1515
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
16+
from langchain.agents.tools import Tool
1617
from langchain.chains.loading import load_chain, load_chain_from_config
18+
from langchain.llms.base import BaseLLM
1719

1820
AGENT_TO_CLASS = {
1921
"zero-shot-react-description": ZeroShotAgent,
@@ -25,10 +27,42 @@
2527
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
2628

2729

28-
def load_agent_from_config(config: dict, **kwargs: Any) -> Agent:
30+
def _load_agent_from_tools(
31+
config: dict, llm: BaseLLM, tools: List[Tool], **kwargs: Any
32+
) -> Agent:
33+
config_type = config.pop("_type")
34+
if config_type not in AGENT_TO_CLASS:
35+
raise ValueError(f"Loading {config_type} agent not supported")
36+
37+
if config_type not in AGENT_TO_CLASS:
38+
raise ValueError(f"Loading {config_type} agent not supported")
39+
agent_cls = AGENT_TO_CLASS[config_type]
40+
combined_config = {**config, **kwargs}
41+
return agent_cls.from_llm_and_tools(llm, tools, **combined_config)
42+
43+
44+
def load_agent_from_config(
45+
config: dict,
46+
llm: Optional[BaseLLM] = None,
47+
tools: Optional[List[Tool]] = None,
48+
**kwargs: Any,
49+
) -> Agent:
2950
"""Load agent from Config Dict."""
3051
if "_type" not in config:
3152
raise ValueError("Must specify an agent Type in config")
53+
load_from_tools = config.pop("load_from_llm_and_tools", False)
54+
if load_from_tools:
55+
if llm is None:
56+
raise ValueError(
57+
"If `load_from_llm_and_tools` is set to True, "
58+
"then LLM must be provided"
59+
)
60+
if tools is None:
61+
raise ValueError(
62+
"If `load_from_llm_and_tools` is set to True, "
63+
"then tools must be provided"
64+
)
65+
return _load_agent_from_tools(config, llm, tools, **kwargs)
3266
config_type = config.pop("_type")
3367

3468
if config_type not in AGENT_TO_CLASS:

langchain/agents/mrkl/base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from langchain.agents.agent import Agent, AgentExecutor
88
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
99
from langchain.agents.tools import Tool
10+
from langchain.callbacks.base import BaseCallbackManager
11+
from langchain.chains import LLMChain
1012
from langchain.llms.base import BaseLLM
1113
from langchain.prompts import PromptTemplate
1214

@@ -92,6 +94,30 @@ def create_prompt(
9294
input_variables = ["input", "agent_scratchpad"]
9395
return PromptTemplate(template=template, input_variables=input_variables)
9496

97+
@classmethod
98+
def from_llm_and_tools(
99+
cls,
100+
llm: BaseLLM,
101+
tools: List[Tool],
102+
callback_manager: Optional[BaseCallbackManager] = None,
103+
prefix: str = PREFIX,
104+
suffix: str = SUFFIX,
105+
input_variables: Optional[List[str]] = None,
106+
**kwargs: Any,
107+
) -> Agent:
108+
"""Construct an agent from an LLM and tools."""
109+
cls._validate_tools(tools)
110+
prompt = cls.create_prompt(
111+
tools, prefix=prefix, suffix=suffix, input_variables=input_variables
112+
)
113+
llm_chain = LLMChain(
114+
llm=llm,
115+
prompt=prompt,
116+
callback_manager=callback_manager,
117+
)
118+
tool_names = [tool.name for tool in tools]
119+
return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
120+
95121
@classmethod
96122
def _validate_tools(cls, tools: List[Tool]) -> None:
97123
for tool in tools:

0 commit comments

Comments
 (0)