Commit 803ef5f

add custom agentic producer
1 parent 673046a commit 803ef5f

12 files changed with 482 additions and 378 deletions
Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
import copy
import random
import re
from typing import Any, Dict
from uuid import uuid4

import ray
from coati.distributed.agent.base import BaseAgenticProducer
from transformers import AutoTokenizer

DEFAULT_SYSTEM_MESSAGE = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer> </answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>."""

@ray.remote
class AgenticProducer(BaseAgenticProducer):
    """
    Asynchronous version of the producer that uses vLLM for generation.
    This class is designed to generate agentic responses.

    Please use the following SYSTEM message or a similar one for the agentic math model:
    '''A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
    The Assistant first thinks about the reasoning process in the mind and then provides the user with
    the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer>
    </answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>.'''
    """

    def __init__(
        self,
        producer_idx,
        num_producers,
        num_consumer_procs,
        num_episodes,
        batch_size,
        train_dataset_config,
        model_config,
        generate_config,
        async_producers,
        tool_workers=[],
        tokenizer_config=None,
        agentic_config=None,
        microbatch_size=1,
        backend="transformers",
        num_generations: int = 8,
        consumer_plugin_config=None,
        eval_dataset_config=None,
        eval_interval=-1,  # disable evaluation
        grpo_config: Dict[str, Any] = None,
        eval_save_dir: str = "./eval",
        eval_generation_config={},
        project_name: str = None,
        run_name: str = None,
        wandb_group_name: str = None,
        log_rollout_interval: int = 20,
        rollout_log_file: str = "./rollout_log.jsonl",
        enable_profiling: bool = False,
        n_behind: int = 0,
    ):
        assert microbatch_size == 1  # microbatch_size must be 1 for agentic producer
        assert batch_size == 1  # batch_size must be 1 for agentic producer
        super().__init__(
            producer_idx,
            num_producers,
            num_consumer_procs,
            num_episodes,
            batch_size,
            train_dataset_config,
            model_config,
            generate_config,
            async_producers,
            tokenizer_config,
            microbatch_size,
            backend,
            num_generations,
            consumer_plugin_config,
            eval_dataset_config=eval_dataset_config,
            eval_interval=eval_interval,
            grpo_config=grpo_config,
            eval_save_dir=eval_save_dir,
            eval_generation_config=eval_generation_config,
            project_name=project_name,
            run_name=run_name,
            wandb_group_name=wandb_group_name,
            log_rollout_interval=log_rollout_interval,
            rollout_log_file=rollout_log_file,
            enable_profiling=enable_profiling,
            n_behind=n_behind,
        )
        self.tool_workers = tool_workers
        self.agentic_config = model_config if not agentic_config else agentic_config
        self.agentic_config.update({"model": model_config["path"]})
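        # resolve the tokenizer path: prefer tokenizer_config["path"], fall back to model_config["path"]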
        tokenizer_path = None
        if tokenizer_config and "path" in tokenizer_config:
            tokenizer_path = tokenizer_config["path"]
        elif "path" in model_config:
            tokenizer_path = model_config["path"]
        assert tokenizer_path is not None, "Tokenizer path must be provided either in tokenizer_config or model_config."
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
        self.tools_schema = []
        self.tool_call_budget = self.agentic_config.get("tool_call_budget", 3)
        self.llm_call_budget = self.agentic_config.get("llm_call_budget", 10)
        self.async_llm_engine_map = {}
        self._get_tools()

    def _get_tools(self):
        """
        Collect the available tools from the tool workers and build their schemas
        (name, description, and argument schema). Reference for the agentic math
        SYSTEM message: rStar2 paper https://arxiv.org/pdf/2508.20722
        """
        tools = ray.get(self.tool_workers[0].list_tools.remote())
        tool_descriptions = {tool: ray.get(self.tool_workers[0].get_tool_description.remote(tool)) for tool in tools}
        tool_arg_schemas = {tool: ray.get(self.tool_workers[0].get_args_schema.remote(tool)) for tool in tools}
        self.tools = []
        for tool in tools:
            tool_schema = {"name": tool, "description": tool_descriptions[tool], "parameters": tool_arg_schemas[tool]}
            self.tools.append(tool_schema)

    def _build_prompt(
        self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt"
    ) -> dict:
        """
        Build the prompt for the agentic math model.
        """
        return self.tokenizer.apply_chat_template(
            messages,
            tools=self.tools,
            add_generation_prompt=add_generation_prompt,
            return_dict=return_dict,
            return_tensors=return_tensors,
        )

    def _parse_response(self, response: str) -> Dict[str, Any]:
        """
        Parse the response from the agentic math model.

        Sample Assistant Response:
            The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}<|im_end|>

        Sample Assistant Response with Tool Call:
            To answer this, I will check both the weather and the timezone for New York.\n<tool_call>\n{"name": "get_weather", "arguments": {"location": "New York"}}\n</tool_call>\n<tool_call>\n{"name": "get_timezone", "arguments": {"location": "New York"}}\n</tool_call>

        Sample Output (with tool calls):
            {
                "role": "assistant",
                "content": "To answer this, I will check both the weather and the timezone for New York.",
                "tool_calls": [
                    {
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "New York"
                            }
                        }
                    },
                    {
                        "function": {
                            "name": "get_timezone",
                            "arguments": {
                                "location": "New York"
                            }
                        }
                    }
                ]
            }

        Sample Output (without tool calls):
            {
                "role": "assistant",
                "content": "The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}"
            }
        """
        # split by <|im_end|> and keep only the first assistant turn
        response_chunked = response.split("<|im_end|>")[0].strip()
        if "<tool_call>" in response_chunked:
            assistant_content = response_chunked.split("<tool_call>")[0].strip()
            tool_call_sections = response_chunked[response_chunked.find("<tool_call>") :].strip()
            # extract all tool calls
            tool_calls = []
            pattern = "<tool_call>(.*?)</tool_call>"
            matches = re.findall(pattern, tool_call_sections, re.DOTALL)
            for match in matches:
                try:
                    tool_call = eval(match.strip())  # parse the tool-call payload as a dict literal
                    name = tool_call["name"]
                    arguments = tool_call["arguments"]
                    tool_calls.append({"function": {"name": name, "arguments": arguments}})
                except Exception as e:
                    print(f"Failed to parse tool call: {match.strip()}. Error: {e}")
                    tool_calls.append({"function": {"name": "return_parsing_error", "arguments": {}}})
        else:
            assistant_content = response_chunked
            tool_calls = []
        assistant_message = {"role": "assistant", "content": assistant_content}
        if tool_calls:
            assistant_message["tool_calls"] = tool_calls
        return assistant_message

    def _select_tool_worker(self) -> ray.actor.ActorHandle:
        """
        Select a tool worker based on the current load.
        """
        loads = ray.get([worker.get_load.remote() for worker in self.tool_workers])
        min_load = min(loads)
        candidates = [i for i, l in enumerate(loads) if l == min_load]
        selected_idx = random.choice(candidates)  # random tie break
        ray.get(self.tool_workers[selected_idx].increase_load.remote())
        return self.tool_workers[selected_idx]

    def _select_async_producer(self, request_id) -> ray.actor.ActorHandle:
        """
        Select an async producer based on the current load.
        """
        # reuse the async producer last used for this request to reuse its KV cache (as vLLM uses a
        # paged KV cache, most of the KV cache pages are reused without recomputation)
        if request_id in self.async_llm_engine_map:
            return self.async_producers[self.async_llm_engine_map[request_id]]
        # otherwise select the least loaded async producer
        loads = ray.get([proc.get_producer_load.remote() for proc in self.async_producers])
        min_load = min(loads)
        candidates = [i for i, l in enumerate(loads) if l == min_load]
        selected_idx = random.choice(candidates)  # random tie break
        self.async_llm_engine_map[request_id] = selected_idx
        return self.async_producers[selected_idx]

    def _run_agentic_pipeline(self, messages):
        """
        Run the agentic pipeline to generate responses based on the input messages.
        """
        tool_call_count = 0
        llm_call_count = 0
        num_prompt_tokens = 0
        request_id = str(uuid4())
        logprobs = None
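        # agentic loop: call the LLM, parse tool calls from the response, execute them on the tool
        # workers, append the results, and repeat until no tool call is made or a budget is exhausted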
        while True:
            # tokenize the messages
            if llm_call_count > self.llm_call_budget:
                print(f"LLM call budget exceeded: {llm_call_count} > {self.llm_call_budget}. Stopping.")
                del self.async_llm_engine_map[request_id]
                while messages[-1]["role"] == "tool":
                    messages.pop()
                return messages, logprobs
            inputs = self._build_prompt(messages, return_dict=True, return_tensors="pt")
            if num_prompt_tokens == 0:
                num_prompt_tokens = inputs["input_ids"].size(-1)
            if inputs["input_ids"].size(-1) - num_prompt_tokens > self.generate_config["max_tokens"]:
                print(
                    f"Max tokens exceeded: already generated {inputs['input_ids'].size(-1) - num_prompt_tokens} tokens > {self.generate_config['max_tokens']}. Stopping."
                )
                del self.async_llm_engine_map[request_id]
                while messages[-1]["role"] == "tool":
                    messages.pop()
                return messages, logprobs
            async_producer = self._select_async_producer(request_id=request_id)
            agentic_generate_config = copy.deepcopy(self.generate_config)
            agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048)
            response = ray.get(
                async_producer.generate.remote(
                    inputs["input_ids"],
                    inputs["attention_mask"],
                    request_id=request_id,
                    **agentic_generate_config,
                )
            )
            llm_call_count += 1
            response_input_ids = response["input_ids"]
            logprobs = response["action_log_probs"]
            response_text = self.tokenizer.decode(
                response_input_ids[0][0][inputs["input_ids"].size(-1) :], skip_special_tokens=False
            )
            assistant_message = self._parse_response(response_text)
            messages.append(assistant_message)
            if "tool_calls" in assistant_message:
                if tool_call_count > self.tool_call_budget:
                    print(f"Tool call budget exceeded: {tool_call_count} > {self.tool_call_budget}. Stopping.")
                    del self.async_llm_engine_map[request_id]
                    return messages, logprobs
                tool_call_count += len(assistant_message["tool_calls"])
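                # dispatch all tool calls of this turn to the tool workers in parallel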
                handlers = []
                for tool_call in assistant_message["tool_calls"]:
                    # select a tool worker to execute the tool call
                    tool_worker = self._select_tool_worker()
                    handler = tool_worker.call.remote(tool_call["function"]["name"], tool_call["function"]["arguments"])
                    handlers.append(handler)
                tool_results = ray.get(handlers)
                for tool_call, tool_result in zip(assistant_message["tool_calls"], tool_results):
                    tool_message = {"role": "tool", "content": str(tool_result)}
                    messages.append(tool_message)
            else:
                # no further tool call, return the messages
                del self.async_llm_engine_map[request_id]
                return messages, logprobs

applications/ColossalChat/coati/distributed/agent/agentic.py renamed to applications/ColossalChat/coati/distributed/agent/base.py

Lines changed: 30 additions & 10 deletions
@@ -1,5 +1,6 @@
 import copy
 import json
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict

 import ray
@@ -86,16 +87,25 @@ def _run_agentic_pipeline(self, messages):
         """
         raise NotImplementedError

+    def _build_prompt(
+        self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt"
+    ) -> dict:
+        """
+        Build the prompt from the input messages.
+        This function should be implemented in subclasses.
+        """
+        raise NotImplementedError
+
     def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
         """
         Rollout function to generate responses for the input, for example, using LLM or agentic pipeline.
         This function should be implemented in subclasses.
         """
         assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer"
         messages = kwargs["messages"][0]
-        prompt_input_ids = self.tokenizer.apply_chat_template(
-            messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
-        )
+        prompt_input_ids = self._build_prompt(
+            messages, return_dict=True, return_tensors="pt", add_generation_prompt=True
+        )["input_ids"]
         # add left padding
         prompt_length = prompt_input_ids.shape[1]
         max_prompt_length = self.train_dataset_config["max_length"]
@@ -107,10 +117,16 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
             "action_log_probs": [],
             "response_idx": [],
         }
+        with ThreadPoolExecutor(max_workers=self.num_generations) as executor:
+            results = list(
+                executor.map(self._run_agentic_pipeline, [copy.deepcopy(messages) for _ in range(self.num_generations)])
+            )
+
         for i in range(self.num_generations):
-            _messages = copy.deepcopy(messages)
-            _messages = self._run_agentic_pipeline(_messages)
-            response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True)
+            _messages, logprobs = results[i]
+            response_input_ids = self._build_prompt(
+                _messages, return_dict=True, return_tensors="pt", add_generation_prompt=False
+            )["input_ids"]
             # truncate if too long
             response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left]
             # add left right padding
@@ -127,9 +143,14 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
             )  # [1, max_length-prompt_length]
             rollouts["attention_mask"].append(attention_mask)
             rollouts["action_mask"].append(action_mask)
-            rollouts["action_log_probs"].append(
-                torch.ones(size=(1, self.grpo_config["max_length"] - max_prompt_length))
-            )  # dummy log probs
+            truncated_logprobs = logprobs[:, :, prompt_length : prompt_length + self.generate_config["max_tokens"]]
+            logprobs_padded = torch.nn.functional.pad(
+                truncated_logprobs,
+                (0, self.generate_config["max_tokens"] - truncated_logprobs.size(-1)),
+                "constant",
+                value=0.0,
+            )  # [1, max_new_tokens]
+            rollouts["action_log_probs"].append(logprobs_padded[0])
             rollouts["response_idx"].append(
                 torch.tensor(
                     [
@@ -141,7 +162,6 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
                 )
             )  # [1, 2]
             rollouts["input_ids"].append(input_ids)
-            # breakpoint()
         rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()}  # [num_generations, ...]
         rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)])
         if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:

0 commit comments