Commit 62f82a7

add langgraph agent, still buggy

1 parent f315540 commit 62f82a7

File tree

9 files changed (+454, -151 lines)


applications/ColossalChat/coati/distributed/agent/agentic.py

Lines changed: 9 additions & 12 deletions
@@ -4,14 +4,11 @@
 
 import ray
 import torch
-from coati.distributed.agent.agentic_math_utils import TIR_SYSTEM, CustomTransformers
 from coati.distributed.producer import BaseProducer
-from qwen_agent.agents import TIRMathAgent
 from vllm import SamplingParams
 
 
-@ray.remote
-class AgenticProducer(BaseProducer):
+class BaseAgenticProducer(BaseProducer):
     """
     Asyncronous version of the producer that uses vLLM for generation.
     This class is designed to generate agentic response
@@ -29,7 +26,6 @@ def __init__(
         generate_config,
         async_producers,
         tokenizer_config=None,
-        agentic_config=None,
         microbatch_size=1,
         backend="transformers",
         num_generations: int = 8,
@@ -82,10 +78,13 @@ def __init__(
         self.async_producers = async_producers
         self.num_generations = num_generations
         self.generate_config = generate_config
-        self.agentic_config = model_config if not agentic_config else agentic_config
-        self.agentic_config.update({"model": model_config["path"]})
-        self.llm = CustomTransformers(self.agentic_config, self.producer_idx, generation_workers=self.async_producers)
-        self.bot = TIRMathAgent(llm=self.llm, name=model_config["path"], system_message=TIR_SYSTEM)
+
+    def _run_agentic_pipeline(self, messages):
+        """
+        Run the agentic pipeline to generate responses based on the input messages.
+        This function should be implemented in subclasses.
+        """
+        raise NotImplementedError
 
     def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
         """
@@ -110,9 +109,7 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
         }
         for i in range(self.num_generations):
             _messages = copy.deepcopy(messages)
-            for response in self.bot.run(messages):
-                continue
-            _messages.extend(response)
+            _messages = self._run_agentic_pipeline(_messages)
             response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True)
             # truncate if too long
             response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left]
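This refactor turns the old qwen-agent-specific `AgenticProducer` into a framework-agnostic `BaseAgenticProducer`: the `@ray.remote` decorator and the TIR agent construction move out of the base class, and `rollout` now delegates to a `_run_agentic_pipeline` hook that subclasses must implement and that has to return messages `tokenizer.apply_chat_template` can consume. A minimal, hypothetical subclass illustrating that contract (the class name and canned reply are made up, not part of the commit):

# Hypothetical subclass, for illustration only: it satisfies the new
# _run_agentic_pipeline contract by returning role/content dicts that
# rollout() can pass straight to tokenizer.apply_chat_template.
from typing import Dict, List

from coati.distributed.agent.agentic import BaseAgenticProducer


class EchoAgenticProducer(BaseAgenticProducer):
    def _run_agentic_pipeline(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        # A real subclass would run its agent loop here (see the LangGraph
        # producer added below); this stub just appends a canned assistant turn.
        return messages + [{"role": "assistant", "content": "<final answer goes here>"}]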
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
from typing import Any, Dict

import ray
from coati.distributed.agent.agentic import BaseAgenticProducer
from coati.distributed.agent.langgraph_math_agentic_utils import CustomOpenAIAPILLM, LangChainCustomLLM, python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent


@ray.remote
class LangGraphMathAgenticProducer(BaseAgenticProducer):
    """
    Asyncronous version of the producer that uses vLLM for generation.
    This class is designed to generate agentic response
    """

    def __init__(
        self,
        producer_idx,
        num_producers,
        num_consumer_procs,
        num_episodes,
        batch_size,
        train_dataset_config,
        model_config,
        generate_config,
        async_producers,
        tokenizer_config=None,
        agentic_config=None,
        microbatch_size=1,
        backend="transformers",
        num_generations: int = 8,
        consumer_plugin_config=None,
        eval_dataset_config=None,
        eval_interval=-1,  # disable evaluation
        grpo_config: Dict[str, Any] = None,
        eval_save_dir: str = "./eval",
        eval_generation_config={},
        project_name: str = None,
        run_name: str = None,
        wandb_group_name: str = None,
        log_rollout_interval: int = 20,
        rollout_log_file: str = "./rollout_log.jsonl",
        enable_profiling: bool = False,
        n_behind: int = 0,
    ):
        assert microbatch_size == 1  # microbatch_size must be 1 for agentic producer
        assert batch_size == 1  # batch_size must be 1 for agentic producer
        super().__init__(
            producer_idx,
            num_producers,
            num_consumer_procs,
            num_episodes,
            batch_size,
            train_dataset_config,
            model_config,
            generate_config,
            async_producers,
            tokenizer_config,
            microbatch_size,
            backend,
            num_generations,
            consumer_plugin_config,
            eval_dataset_config=eval_dataset_config,
            eval_interval=eval_interval,
            grpo_config=grpo_config,
            eval_save_dir=eval_save_dir,
            eval_generation_config=eval_generation_config,
            project_name=project_name,
            run_name=run_name,
            wandb_group_name=wandb_group_name,
            log_rollout_interval=log_rollout_interval,
            rollout_log_file=rollout_log_file,
            enable_profiling=enable_profiling,
            n_behind=n_behind,
        )
        self.agentic_config = agentic_config
        self.agentic_config.pop("agentic_type", None)
        self.llm_client = CustomOpenAIAPILLM({"model": model_config["path"]}, producer_idx, self.async_producers)
        self.llm = LangChainCustomLLM(self.llm_client)
        # self.python_repl = PythonREPL()
        # repl_tool = Tool(
        #     name="python_repl",
        #     description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.",
        #     func=self.python_repl.run,
        # )
        # self.tools = [repl_tool]
        self.tools = [python]
        self.memory = MemorySaver()
        self.bot = create_react_agent(self.llm, self.tools, checkpointer=self.memory)

    def _run_agentic_pipeline(self, messages):
        """
        Run the agentic pipeline to generate responses based on the input messages using the LangGraph.
        """
        assert (
            len(messages) == 2 and messages[0]["role"] == "system" and messages[1]["role"] == "user"
        ), "Only support 1 system message and 1 user message as input."
        # inputs = messages
        for event in self.bot.stream(
            {"messages": [("system", messages[0]["content"]), ("user", "calculate the 1000th Fibonacci number")]},
            self.agentic_config,
        ):
            continue
        breakpoint()

        final_state = self.bot.get_state(self.agentic_config)
        transformer_messages = []
        for message in final_state[0]["messages"]:
            tool_calls = None
            if isinstance(message, SystemMessage):
                message.content
            elif isinstance(message, HumanMessage):
                message.content
            elif isinstance(message, AIMessage):
                message.content
                tool_calls = message.get("tool_calls", None)  # [{"type": "function", "function": tool_call}]
            elif isinstance(message, ToolMessage):
                message.content

        return transformer_messages
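The new producer's `_run_agentic_pipeline` is where the commit message's "still buggy" caveat shows: the `stream` call hardcodes a Fibonacci question instead of passing `messages[1]["content"]`, a `breakpoint()` is left in place, and the final loop reads `message.content` without ever appending to `transformer_messages`, so the method returns an empty list. Below is a hedged sketch (not the committed fix) of how that conversion might be completed, assuming `self.bot.get_state(...)` exposes the accumulated LangGraph messages the way the commit already assumes:

# A hedged sketch only: one way the unfinished message conversion might look
# once the breakpoint() and the hardcoded prompt are removed.
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage


def convert_langgraph_messages(final_messages):
    """Map LangGraph message objects to chat-template style role/content dicts."""
    role_by_type = [
        (SystemMessage, "system"),
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (ToolMessage, "tool"),
    ]
    transformer_messages = []
    for message in final_messages:
        role = next((r for cls, r in role_by_type if isinstance(message, cls)), None)
        if role is None:
            continue  # skip message types we do not know how to serialize
        entry = {"role": role, "content": message.content}
        # AIMessage exposes parsed tool calls via the .tool_calls attribute
        # (the commit's message.get(...) would fail on a message object);
        # the dict layout mirrors the inline comment in the commit.
        if role == "assistant" and getattr(message, "tool_calls", None):
            entry["tool_calls"] = [{"type": "function", "function": tc} for tc in message.tool_calls]
        transformer_messages.append(entry)
    return transformer_messages

The user turn would presumably come from `messages[1]["content"]` in the `stream` call, mirroring how the system turn is already forwarded.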
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# -------------------------------
# 1. Define the Python tool
# -------------------------------
import copy
import io
import random
import sys
from typing import Dict, List

import ray
import torch
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.chat_result import ChatResult
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from tool_calling_llm import ToolCallingLLM
from transformers import AutoTokenizer

SYSTEM_PROMPT_TEMPLATE = """{task_description}. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {input}
Thought:{agent_scratchpad}"""

SYSTEM_PROMPT = PromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE)


class Capturing(list):
    """Capture stdout prints inside exec()"""

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = io.StringIO()
        return self

    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        sys.stdout = self._stdout


@tool
def python(code: str) -> str:
    """
    This function executes a string of Python code and returns the printed output.
    You need to print the output. Please import all libraries used in the code string.
    """
    local_vars = {}
    with Capturing() as output:
        exec(code, {}, local_vars)
    if output == []:
        return "Error: No output printed from the code. Please ensure you print the output."
    return "\n".join(output)


# -------------------------------
# 2. Define a Custom API LLM wrapper
# -------------------------------
class CustomOpenAIAPILLM:
    def __init__(self, cfg: dict, producer_idx, generation_workers=None):
        self.producer_idx = producer_idx
        self.generation_workers = generation_workers
        self.load_balancer_idx = producer_idx % len(self.generation_workers)
        assert "model" in cfg, "Please specify the model name in the config"
        self.tokenizer = AutoTokenizer.from_pretrained(cfg["model"])
        self.role_mapping = {
            "system": "system",
            "user": "user",
            "assistant": "assistant",
            "human": "user",
            "tool": "tool",
        }

    def invoke(self, messages: List[Dict[str, str]], **kwargs) -> str:
        """
        messages: list of {"role": "user"/"assistant"/"system", "content": "..."}
        """
        # load balancing
        load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers]
        min_load = min(load)
        candidates = [i for i, l in enumerate(load) if l == min_load]
        # random tie break
        self.load_balancer_idx = random.choice(candidates)
        generation_worker = self.generation_workers[self.load_balancer_idx]
        transformer_messages = []
        for message in messages:
            transformer_messages.append({"role": self.role_mapping[message.type], "content": message.content})
        input_ids = self.tokenizer.apply_chat_template(
            transformer_messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
        )
        attention_mask = torch.ones_like(input_ids)
        rollouts = ray.get(generation_worker.generate.remote(input_ids, attention_mask, **kwargs))
        response = self.tokenizer.batch_decode(
            rollouts["input_ids"][0][:, input_ids.size(-1) :], skip_special_tokens=True
        )[0]
        return response


class LangChainCustomLLM(ToolCallingLLM, BaseChatModel):
    client: CustomOpenAIAPILLM = None

    def __init__(self, client: CustomOpenAIAPILLM):
        super().__init__()
        self.client = client

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        # content = self.client.invoke([m.dict() for m in messages])
        # chat_result = ChatResult(
        #     generations=[ChatGeneration(message=AIMessage(content=content))]
        # )
        print("messages:", messages)
        breakpoint()
        system_message, functions = self._generate_system_message_and_functions(kwargs)
        sample_params = {"stop": stop} if stop is not None else {}
        sample_params.update({k: v for k, v in kwargs.items() if k in ["temperature", "top_p", "top_k", "max_tokens"]})
        messages_ = copy.deepcopy(messages)
        messages_[0].content = messages_[0].content + "\n" + system_message.content
        response_message = self.client.invoke(  # type: ignore[safe-super]
            [system_message] + messages, **{"sample_params": sample_params}
        )
        breakpoint()
        response = self._process_response(AIMessage(content=response_message), functions)
        return ChatResult(generations=[ChatGeneration(message=response)])

    @property
    def _llm_type(self) -> str:
        return "custom-api-llm"


# -------------------------------
# 3. Build a ReAct Agent with LangGraph
# -------------------------------
def build_agent():
    # Wrap custom API LLM in LangChain-compatible interface

    # Init LLM
    llm_client = CustomOpenAIAPILLM()
    llm = LangChainCustomLLM(llm_client)

    # Tools
    tools = [python]

    # Memory (optional)
    memory = MemorySaver()

    # Build ReAct agent
    agent = create_react_agent(llm, tools, checkpointer=memory)
    return agent


# -------------------------------
# 4. Run the agent on a math problem
# -------------------------------
if __name__ == "__main__":
    agent = build_agent()

    # Example math question
    user_input = "What is the least common multiple of 18 and 24? Use Python if needed."

    config = {"configurable": {"thread_id": "math-1"}}
    for event in agent.stream({"messages": [("user", user_input)]}, config):
        if "agent" in event:
            print("Agent event:", event["agent"]["messages"][-1].content)
        elif "tools" in event:
            print("Tool event:", event["tools"]["messages"][-1].content)

    final_state = agent.get_state(config)
    print("Final Answer:", final_state["messages"][-1].content)
