|
| 1 | +""" |
| 2 | +This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo. |
| 3 | +""" |
| 4 | + |
| 5 | +import json |
| 6 | +import logging |
| 7 | +import typing |
| 8 | +from typing import Sequence, Optional, Union |
| 9 | + |
| 10 | +import langchain_core.prompts |
| 11 | +import langchain_core.messages |
| 12 | +from langchain_core.runnables import Runnable, RunnablePassthrough |
| 13 | +from langchain.agents.agent import AgentOutputParser |
| 14 | +from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser |
| 15 | +from langchain.prompts.chat import ChatPromptTemplate |
| 16 | +from langchain.output_parsers import OutputFixingParser |
| 17 | +from langchain.schema import AgentAction, AgentFinish, OutputParserException |
| 18 | +from langchain.schema.language_model import BaseLanguageModel |
| 19 | +from langchain.tools.base import BaseTool |
| 20 | +from pydantic.v1 import Field |
| 21 | + |
| 22 | +from pydantic.v1.schema import model_schema |
| 23 | + |
| 24 | +logger = logging.getLogger(__name__) |
| 25 | + |
| 26 | +SYSTEM_PROMPT = "Answer the following questions as best as you can. You have access to the following tools:\n{tools}" |
| 27 | +HUMAN_MESSAGE = "Let's start! Human:{input}\n\n{agent_scratchpad}" |
| 28 | + |
| 29 | + |
class StructuredGLM3ChatOutputParser(AgentOutputParser):
    """Output parser for a structured chat agent backed by ChatGLM3-6B.

    ChatGLM3 emits tool invocations as ``tool_name\\n...tool_call(arg='val')``
    snippets instead of the JSON blob the stock structured-chat parser
    expects. This parser normalizes the raw model output into the canonical
    ``Action:`` + fenced-JSON form, then delegates to ``base_parser`` (or to
    ``output_fixing_parser`` when one is configured) to produce an
    ``AgentAction`` / ``AgentFinish``.
    """

    # Parser that understands the canonical "Action: ```{json}```" format.
    base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
    # Optional LLM-backed parser that retries/repairs malformed output.
    output_fixing_parser: Optional[OutputFixingParser] = None

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse raw ChatGLM3 output into an AgentAction or AgentFinish.

        Args:
            text: Raw text generated by the model.

        Raises:
            OutputParserException: if the normalized action string still
                cannot be parsed by the underlying parser.
        """
        # Fix: was a leftover debug ``print(text)`` that polluted stdout in
        # library code; use the module logger with lazy %-formatting instead.
        logger.debug("Raw model output: %s", text)

        # Truncate at the first special token — anything after it belongs to
        # the next conversation turn and must not be parsed.
        special_tokens = ["Action:", "<|observation|>"]
        first_index = min(
            text.find(token) if token in text else len(text)
            for token in special_tokens
        )
        text = text[:first_index]

        if "tool_call" in text:
            # The tool name is whatever precedes the closing code fence.
            action_end = text.find("```")
            action = text[:action_end].strip()
            # Arguments live between the outermost parentheses as
            # ``key='value'`` pairs.
            params_str_start = text.find("(") + 1
            params_str_end = text.rfind(")")
            params_str = text[params_str_start:params_str_end]

            # NOTE(review): this naive split breaks on values that contain
            # ',' or '=' — acceptable for simple scalar args, revisit for
            # complex payloads.
            params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
            params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}

            action_json = {
                "action": action,
                "action_input": params
            }
        else:
            # No tool call present: treat the whole text as the final answer.
            action_json = {
                "action": "Final Answer",
                "action_input": text
            }
        # Re-emit the decision in the canonical structured-chat format so the
        # stock parser can handle it.
        action_str = f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
        try:
            if self.output_fixing_parser is not None:
                parsed_obj: Union[
                    AgentAction, AgentFinish
                ] = self.output_fixing_parser.parse(action_str)
            else:
                parsed_obj = self.base_parser.parse(action_str)
            return parsed_obj
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e

    @property
    def _type(self) -> str:
        """Identifier used by langchain's serialization machinery."""
        return "StructuredGLM3ChatOutputParser"
| 82 | + |
| 83 | + |
def create_structured_glm3_chat_agent(
        llm: BaseLanguageModel, tools: Sequence[BaseTool]
) -> Runnable:
    """Build a structured-chat agent Runnable tailored to ChatGLM3-6B.

    Args:
        llm: The chat model; generation is stopped at the "<|observation|>"
            special token so the agent loop can inject tool results.
        tools: Tools to expose to the agent; each is rendered into the
            system prompt as a simplified schema dict.

    Returns:
        A Runnable pipeline: scratchpad passthrough -> prompt -> llm with
        stop token -> StructuredGLM3ChatOutputParser.
    """
    # Render each tool as a simplified {name, description, parameters} dict
    # for injection into the system prompt.
    tools_json = []
    for tool in tools:
        tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
        # Langchain tool descriptions are often "name - description"; keep
        # only the description part when that pattern is present.
        description = (
            tool.description.split(" - ")[1].strip()
            if tool.description and " - " in tool.description
            else tool.description
        )
        # Drop the auto-generated 'title' key from each parameter schema.
        parameters = {
            k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k != 'title'}
            for k, v in tool_schema.get("properties", {}).items()
        }
        tools_json.append({
            "name": tool.name,
            "description": description,
            "parameters": parameters,
        })
    # Fix: the original rebound the `tools` parameter to a str, shadowing the
    # Sequence[BaseTool] argument; use a distinct name for the rendered text.
    tools_text = "\n".join(str(tool_spec) for tool_spec in tools_json)

    prompt = ChatPromptTemplate(
        input_variables=["input", "agent_scratchpad"],
        input_types={'chat_history': typing.List[typing.Union[
            langchain_core.messages.ai.AIMessage,
            langchain_core.messages.human.HumanMessage,
            langchain_core.messages.chat.ChatMessage,
            langchain_core.messages.system.SystemMessage,
            langchain_core.messages.function.FunctionMessage,
            langchain_core.messages.tool.ToolMessage]]
        },
        messages=[
            langchain_core.prompts.SystemMessagePromptTemplate(
                prompt=langchain_core.prompts.PromptTemplate(
                    input_variables=['tools'],
                    template=SYSTEM_PROMPT)
            ),
            # Prior conversation turns are optional.
            langchain_core.prompts.MessagesPlaceholder(
                variable_name='chat_history',
                optional=True
            ),
            langchain_core.prompts.HumanMessagePromptTemplate(
                prompt=langchain_core.prompts.PromptTemplate(
                    input_variables=['agent_scratchpad', 'input'],
                    template=HUMAN_MESSAGE
                )
            )
        ]
    ).partial(tools=tools_text)

    # Stop generation when the model asks to observe a tool result.
    llm_with_stop = llm.bind(stop=["<|observation|>"])
    agent = (
        RunnablePassthrough.assign(
            # Expose intermediate steps to the prompt as the scratchpad.
            agent_scratchpad=lambda x: x["intermediate_steps"],
        )
        | prompt
        | llm_with_stop
        | StructuredGLM3ChatOutputParser()
    )
    return agent
0 commit comments