# -------------------------------
# 1. Define the Python tool
# -------------------------------
import io
import random
import sys
from typing import List

import ray
import torch
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from tool_calling_llm import ToolCallingLLM
from transformers import AutoTokenizer

SYSTEM_PROMPT_TEMPLATE = """{task_description}. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {input}
Thought:{agent_scratchpad}"""

SYSTEM_PROMPT = PromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE)

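# Example of rendering the prompt above (illustrative values). Note that SYSTEM_PROMPT is not
# wired into the agent built below, which instead constructs its own tool-describing system
# message via ToolCallingLLM:
#
#     SYSTEM_PROMPT.format(
#         task_description="You are a helpful math assistant",
#         tools="python: executes a string of Python code and returns the printed output",
#         tool_names="python",
#         input="What is the least common multiple of 18 and 24?",
#         agent_scratchpad="",
#     )
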
class Capturing(list):
    """Capture stdout prints produced inside exec()."""

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = io.StringIO()
        return self

    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        sys.stdout = self._stdout

@tool
def python(code: str) -> str:
    """
    Execute a string of Python code and return its printed output.
    The code must print its result and must import every library it uses.
    """
    local_vars = {}
    with Capturing() as output:
        exec(code, {}, local_vars)
    if not output:
        return "Error: No output printed from the code. Please ensure you print the output."
    return "\n".join(output)

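# Quick sanity check for the tool (illustrative; `math.lcm` requires Python >= 3.9):
#
#     result = python.invoke({"code": "import math\nprint(math.lcm(18, 24))"})
#     print(result)  # expected output: "72"
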
# -------------------------------
# 2. Define a Custom API LLM wrapper
# -------------------------------
class CustomOpenAIAPILLM:
    def __init__(self, cfg: dict, producer_idx, generation_workers=None):
        self.producer_idx = producer_idx
        # Ray actor handles exposing `get_producer_load` and `generate` remote methods.
        self.generation_workers = generation_workers
        self.load_balancer_idx = producer_idx % len(self.generation_workers)
        assert "model" in cfg, "Please specify the model name in the config"
        self.tokenizer = AutoTokenizer.from_pretrained(cfg["model"])
        # Map LangChain message types onto the chat-template roles expected by the tokenizer.
        self.role_mapping = {
            "system": "system",
            "user": "user",
            "assistant": "assistant",
            "human": "user",
            "tool": "tool",
        }

    def invoke(self, messages: List[BaseMessage], **kwargs) -> str:
        """
        messages: list of LangChain messages; each message's `.type` is mapped to a chat role.
        """
        # Load balancing: route the request to the generation worker with the lowest load.
        load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers]
        min_load = min(load)
        candidates = [i for i, l in enumerate(load) if l == min_load]
        # Break ties randomly.
        self.load_balancer_idx = random.choice(candidates)
        generation_worker = self.generation_workers[self.load_balancer_idx]
        # Convert LangChain messages into the plain role/content dicts expected by the chat template.
        transformer_messages = []
        for message in messages:
            transformer_messages.append({"role": self.role_mapping[message.type], "content": message.content})
        input_ids = self.tokenizer.apply_chat_template(
            transformer_messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
        )
        attention_mask = torch.ones_like(input_ids)
        rollouts = ray.get(generation_worker.generate.remote(input_ids, attention_mask, **kwargs))
        # Decode only the newly generated tokens (everything after the prompt).
        response = self.tokenizer.batch_decode(
            rollouts["input_ids"][0][:, input_ids.size(-1) :], skip_special_tokens=True
        )[0]
        return response

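# How this wrapper is expected to be constructed (a sketch; `GenerationWorker` and the cfg
# layout are hypothetical and depend on the surrounding RL training framework, which owns
# the Ray generation workers):
#
#     workers = [GenerationWorker.remote(...) for _ in range(4)]
#     client = CustomOpenAIAPILLM(
#         cfg={"model": "Qwen/Qwen2.5-7B-Instruct"},  # for example: any HF model with a chat template
#         producer_idx=0,
#         generation_workers=workers,
#     )
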
class LangChainCustomLLM(ToolCallingLLM, BaseChatModel):
    client: CustomOpenAIAPILLM = None

    def __init__(self, client: CustomOpenAIAPILLM):
        super().__init__()
        self.client = client

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        # Build the tool-describing system message and the function schemas from the bound tools.
        system_message, functions = self._generate_system_message_and_functions(kwargs)
        sample_params = {"stop": stop} if stop is not None else {}
        sample_params.update({k: v for k, v in kwargs.items() if k in ["temperature", "top_p", "top_k", "max_tokens"]})
        response_message = self.client.invoke(  # type: ignore[safe-super]
            [system_message] + messages, sample_params=sample_params
        )
        # Parse tool calls (if any) out of the raw text completion.
        response = self._process_response(AIMessage(content=response_message), functions)
        return ChatResult(generations=[ChatGeneration(message=response)])

    @property
    def _llm_type(self) -> str:
        return "custom-api-llm"

# -------------------------------
# 3. Build a ReAct Agent with LangGraph
# -------------------------------
def build_agent(cfg: dict, producer_idx: int = 0, generation_workers=None):
    # Wrap the custom API LLM in a LangChain-compatible chat model interface
    llm_client = CustomOpenAIAPILLM(cfg, producer_idx, generation_workers)
    llm = LangChainCustomLLM(llm_client)

    # Tools
    tools = [python]

    # Memory (optional)
    memory = MemorySaver()

    # Build ReAct agent
    agent = create_react_agent(llm, tools, checkpointer=memory)
    return agent

# -------------------------------
# 4. Run the agent on a math problem
# -------------------------------
if __name__ == "__main__":
    # NOTE: placeholders below. `cfg["model"]` must name a HF checkpoint with a chat template,
    # and `generation_workers` must be the Ray generation-worker actor handles created by the
    # surrounding training framework.
    agent = build_agent(
        cfg={"model": "your-hf-model-id"},
        producer_idx=0,
        generation_workers=None,  # replace with real Ray actor handles
    )

    # Example math question
    user_input = "What is the least common multiple of 18 and 24? Use Python if needed."

    config = {"configurable": {"thread_id": "math-1"}}
    for event in agent.stream({"messages": [("user", user_input)]}, config):
        if "agent" in event:
            print("Agent event:", event["agent"]["messages"][-1].content)
        elif "tools" in event:
            print("Tool event:", event["tools"]["messages"][-1].content)

    # get_state() returns a StateSnapshot; the channel values live under `.values`.
    final_state = agent.get_state(config)
    print("Final Answer:", final_state.values["messages"][-1].content)
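    # Alternative to streaming (a sketch): run the whole graph in one call and read the final message.
    #     result = agent.invoke({"messages": [("user", user_input)]}, config)
    #     print(result["messages"][-1].content)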