|
| 1 | +import copy |
| 2 | +import random |
| 3 | +import re |
| 4 | +from typing import Any, Dict |
| 5 | +from uuid import uuid4 |
| 6 | + |
| 7 | +import ray |
| 8 | +from coati.distributed.agent.base import BaseAgenticProducer |
| 9 | +from transformers import AutoTokenizer |
| 10 | + |
| 11 | +DEFAULT_SYSTEM_MESSAGE = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer> </answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>.""" |
| 12 | + |
| 13 | + |
| 14 | +@ray.remote |
| 15 | +class AgenticProducer(BaseAgenticProducer): |
| 16 | + """ |
| 17 | + Asyncronous version of the producer that uses vLLM for generation. |
| 18 | + This class is designed to generate agentic response |
| 19 | +
|
| 20 | + Please use the following SYSTEM message or a similar one for the agentic math model: |
| 21 | + '''A conversation between User and Assistant. The user asks a question, and the Assistant solves it. |
| 22 | + The Assistant first thinks about the reasoning process in the mind and then provides the user with |
| 23 | + the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer> |
| 24 | + </answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>.''' |
| 25 | + """ |
| 26 | + |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + producer_idx, |
| 30 | + num_producers, |
| 31 | + num_consumer_procs, |
| 32 | + num_episodes, |
| 33 | + batch_size, |
| 34 | + train_dataset_config, |
| 35 | + model_config, |
| 36 | + generate_config, |
| 37 | + async_producers, |
| 38 | + tool_workers=[], |
| 39 | + tokenizer_config=None, |
| 40 | + agentic_config=None, |
| 41 | + microbatch_size=1, |
| 42 | + backend="transformers", |
| 43 | + num_generations: int = 8, |
| 44 | + consumer_plugin_config=None, |
| 45 | + eval_dataset_config=None, |
| 46 | + eval_interval=-1, # disable evaluation |
| 47 | + grpo_config: Dict[str, Any] = None, |
| 48 | + eval_save_dir: str = "./eval", |
| 49 | + eval_generation_config={}, |
| 50 | + project_name: str = None, |
| 51 | + run_name: str = None, |
| 52 | + wandb_group_name: str = None, |
| 53 | + log_rollout_interval: int = 20, |
| 54 | + rollout_log_file: str = "./rollout_log.jsonl", |
| 55 | + enable_profiling: bool = False, |
| 56 | + n_behind: int = 0, |
| 57 | + ): |
| 58 | + assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer |
| 59 | + assert batch_size == 1 # batch_size must be 1 for agentic producer |
| 60 | + super().__init__( |
| 61 | + producer_idx, |
| 62 | + num_producers, |
| 63 | + num_consumer_procs, |
| 64 | + num_episodes, |
| 65 | + batch_size, |
| 66 | + train_dataset_config, |
| 67 | + model_config, |
| 68 | + generate_config, |
| 69 | + async_producers, |
| 70 | + tokenizer_config, |
| 71 | + microbatch_size, |
| 72 | + backend, |
| 73 | + num_generations, |
| 74 | + consumer_plugin_config, |
| 75 | + eval_dataset_config=eval_dataset_config, |
| 76 | + eval_interval=eval_interval, |
| 77 | + grpo_config=grpo_config, |
| 78 | + eval_save_dir=eval_save_dir, |
| 79 | + eval_generation_config=eval_generation_config, |
| 80 | + project_name=project_name, |
| 81 | + run_name=run_name, |
| 82 | + wandb_group_name=wandb_group_name, |
| 83 | + log_rollout_interval=log_rollout_interval, |
| 84 | + rollout_log_file=rollout_log_file, |
| 85 | + enable_profiling=enable_profiling, |
| 86 | + n_behind=n_behind, |
| 87 | + ) |
| 88 | + self.tool_workers = tool_workers |
| 89 | + self.agentic_config = model_config if not agentic_config else agentic_config |
| 90 | + self.agentic_config.update({"model": model_config["path"]}) |
| 91 | + tokenizer_path = None |
| 92 | + if tokenizer_config and "path" in tokenizer_config: |
| 93 | + tokenizer_path = tokenizer_config["path"] |
| 94 | + elif "path" in model_config: |
| 95 | + tokenizer_path = model_config["path"] |
| 96 | + assert tokenizer_path is not None, "Tokenizer path must be provided either in tokenizer_config or model_config." |
| 97 | + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) |
| 98 | + self.tools_schema = [] |
| 99 | + self.tool_call_budget = self.agentic_config.get("tool_call_budget", 3) |
| 100 | + self.llm_call_budget = self.agentic_config.get("llm_call_budget", 10) |
| 101 | + self.async_llm_engine_map = {} |
| 102 | + self._get_tools() |
| 103 | + |
| 104 | + def _get_tools(self): |
| 105 | + """ |
| 106 | + SYSTEM message for the agentic math model. Reference: r-start2 paper https://arxiv.org/pdf/2508.20722 |
| 107 | + """ |
| 108 | + tools = ray.get(self.tool_workers[0].list_tools.remote()) |
| 109 | + tool_descriptions = {tool: ray.get(self.tool_workers[0].get_tool_description.remote(tool)) for tool in tools} |
| 110 | + tool_arg_schemas = {tool: ray.get(self.tool_workers[0].get_args_schema.remote(tool)) for tool in tools} |
| 111 | + self.tools = [] |
| 112 | + for tool in tools: |
| 113 | + tool_schema = {"name": tool, "description": tool_descriptions[tool], "parameters": tool_arg_schemas[tool]} |
| 114 | + self.tools.append(tool_schema) |
| 115 | + |
| 116 | + def _build_prompt( |
| 117 | + self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt" |
| 118 | + ) -> dict: |
| 119 | + """ |
| 120 | + Build the prompt for the agentic math model. |
| 121 | + """ |
| 122 | + return self.tokenizer.apply_chat_template( |
| 123 | + messages, |
| 124 | + tools=self.tools, |
| 125 | + add_generation_prompt=add_generation_prompt, |
| 126 | + return_dict=return_dict, |
| 127 | + return_tensors=return_tensors, |
| 128 | + ) |
| 129 | + |
| 130 | + def _parse_response(self, response: str) -> Dict[str, Any]: |
| 131 | + """ |
| 132 | + Parse the response from the agentic math model. |
| 133 | +
|
| 134 | + Sample Assistant Response: |
| 135 | + The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}<|im_end|> |
| 136 | +
|
| 137 | + Sample Assistant Response with Tool Call: |
| 138 | + To answer this, I will check both the weather and the timezone for New York.\n<tool_call>\n{"name": "get_weather", "arguments": {"location": "New York"}}\n</tool_call>\n<tool_call>\n{"name": "get_timezone", "arguments": {"location": "New York"}}\n</tool_call> |
| 139 | +
|
| 140 | + Sample Ouput: |
| 141 | + { |
| 142 | + "role": "assistant", |
| 143 | + "content": "Let me check the current weather in Singapore by calling the weather tool.", |
| 144 | + "tool_calls": [ |
| 145 | + { |
| 146 | + "function": { |
| 147 | + "name": "get_weather", |
| 148 | + "arguments": { |
| 149 | + "location": "New York" |
| 150 | + } |
| 151 | + } |
| 152 | + }, |
| 153 | + { |
| 154 | + "function": { |
| 155 | + "name": "get_timezone", |
| 156 | + "arguments": { |
| 157 | + "location": "New York" |
| 158 | + } |
| 159 | + } |
| 160 | + } |
| 161 | + ] |
| 162 | + }, |
| 163 | + { |
| 164 | + "role": "assistant", |
| 165 | + "content": "The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}" |
| 166 | + } |
| 167 | + """ |
| 168 | + # split by <im_end|> |
| 169 | + response_chunked = response.split("<|im_end|>")[0].strip() |
| 170 | + if "<tool_call>" in response_chunked: |
| 171 | + assistant_content = response_chunked.split("<tool_call>")[0].strip() |
| 172 | + tool_call_sections = response_chunked[response_chunked.find("<tool_call>") :].strip() |
| 173 | + # extract all tool calls |
| 174 | + tool_calls = [] |
| 175 | + pattern = "<tool_call>(.*?)</tool_call>" |
| 176 | + matches = re.findall(pattern, tool_call_sections, re.DOTALL) |
| 177 | + for match in matches: |
| 178 | + try: |
| 179 | + tool_call = eval(match.strip()) |
| 180 | + name = tool_call["name"] |
| 181 | + arguments = tool_call["arguments"] |
| 182 | + tool_calls.append({"function": {"name": name, "arguments": arguments}}) |
| 183 | + except Exception as e: |
| 184 | + print(f"Failed to parse tool call: {match.strip()}. Error: {e}") |
| 185 | + tool_calls.append({"function": {"name": "return_parsing_error", "arguments": {}}}) |
| 186 | + else: |
| 187 | + assistant_content = response_chunked |
| 188 | + tool_calls = [] |
| 189 | + assistant_message = {"role": "assistant", "content": assistant_content} |
| 190 | + if tool_calls: |
| 191 | + assistant_message["tool_calls"] = tool_calls |
| 192 | + return assistant_message |
| 193 | + |
| 194 | + def _select_tool_worker(self) -> ray.actor.ActorHandle: |
| 195 | + """ |
| 196 | + Select a tool worker based on the current load. |
| 197 | + """ |
| 198 | + loads = ray.get([worker.get_load.remote() for worker in self.tool_workers]) |
| 199 | + min_load = min(loads) |
| 200 | + candidates = [i for i, l in enumerate(loads) if l == min_load] |
| 201 | + selected_idx = random.choice(candidates) # random tie break |
| 202 | + ray.get(self.tool_workers[selected_idx].increase_load.remote()) |
| 203 | + return self.tool_workers[selected_idx] |
| 204 | + |
| 205 | + def _select_async_producer(self, request_id) -> ray.actor.ActorHandle: |
| 206 | + """ |
| 207 | + Select an async producer based on the current load. |
| 208 | + """ |
| 209 | + # use the last used async producer if exists to reuse kv cache (as vllm use paged kv cache, |
| 210 | + # it will reuse most of the kv cache pages without recomputation) |
| 211 | + if request_id in self.async_llm_engine_map: |
| 212 | + return self.async_producers[self.async_llm_engine_map[request_id]] |
| 213 | + # otherwise select the least loaded async producer |
| 214 | + loads = ray.get([proc.get_producer_load.remote() for proc in self.async_producers]) |
| 215 | + min_load = min(loads) |
| 216 | + candidates = [i for i, l in enumerate(loads) if l == min_load] |
| 217 | + selected_idx = random.choice(candidates) # random tie break |
| 218 | + self.async_llm_engine_map[request_id] = selected_idx |
| 219 | + return self.async_producers[selected_idx] |
| 220 | + |
| 221 | + def _run_agentic_pipeline(self, messages): |
| 222 | + """ |
| 223 | + Run the agentic pipeline to generate responses based on the input messages. |
| 224 | + """ |
| 225 | + tool_call_count = 0 |
| 226 | + llm_call_count = 0 |
| 227 | + num_prompt_tokens = 0 |
| 228 | + request_id = str(uuid4()) |
| 229 | + logprobs = None |
| 230 | + while True: |
| 231 | + # tokenize the messages |
| 232 | + if llm_call_count > self.llm_call_budget: |
| 233 | + print(f"LLM call budget exceeded: {llm_call_count} > {self.llm_call_budget}. Stopping.") |
| 234 | + del self.async_llm_engine_map[request_id] |
| 235 | + while messages[-1]["role"] == "tool": |
| 236 | + messages.pop() |
| 237 | + return messages, logprobs |
| 238 | + inputs = self._build_prompt(messages, return_dict=True, return_tensors="pt") |
| 239 | + if num_prompt_tokens == 0: |
| 240 | + num_prompt_tokens = inputs["input_ids"].size(-1) |
| 241 | + if inputs["input_ids"].size(-1) - num_prompt_tokens > self.generate_config["max_tokens"]: |
| 242 | + print( |
| 243 | + f"Max tokens exceeded: Current have generated {inputs['input_ids'].size(-1) - num_prompt_tokens} tokens > {self.generate_config.get('max_tokens', 512)}. Stopping." |
| 244 | + ) |
| 245 | + del self.async_llm_engine_map[request_id] |
| 246 | + while messages[-1]["role"] == "tool": |
| 247 | + messages.pop() |
| 248 | + return messages, logprobs |
| 249 | + async_producer = self._select_async_producer(request_id=request_id) |
| 250 | + agentic_generate_config = copy.deepcopy(self.generate_config) |
| 251 | + agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048) |
| 252 | + response = ray.get( |
| 253 | + async_producer.generate.remote( |
| 254 | + inputs["input_ids"], |
| 255 | + inputs["attention_mask"], |
| 256 | + request_id=request_id, |
| 257 | + **agentic_generate_config, |
| 258 | + ) |
| 259 | + ) |
| 260 | + llm_call_count += 1 |
| 261 | + response_input_ids = response["input_ids"] |
| 262 | + logprobs = response["action_log_probs"] |
| 263 | + response_text = self.tokenizer.decode( |
| 264 | + response_input_ids[0][0][inputs["input_ids"].size(-1) :], skip_special_tokens=False |
| 265 | + ) |
| 266 | + assistant_message = self._parse_response(response_text) |
| 267 | + messages.append(assistant_message) |
| 268 | + if "tool_calls" in assistant_message: |
| 269 | + if tool_call_count > self.tool_call_budget: |
| 270 | + print(f"Tool call budget exceeded: {tool_call_count} > {self.tool_call_budget}. Stopping.") |
| 271 | + del self.async_llm_engine_map[request_id] |
| 272 | + return messages, logprobs |
| 273 | + tool_call_count += len(assistant_message["tool_calls"]) |
| 274 | + handlers = [] |
| 275 | + for tool_call in assistant_message["tool_calls"]: |
| 276 | + # select a tool worker to execute the tool call |
| 277 | + tool_worker = self._select_tool_worker() |
| 278 | + handler = tool_worker.call.remote(tool_call["function"]["name"], tool_call["function"]["arguments"]) |
| 279 | + handlers.append(handler) |
| 280 | + tool_results = ray.get(handlers) |
| 281 | + for tool_call, tool_result in zip(assistant_message["tool_calls"], tool_results): |
| 282 | + tool_message = {"role": "tool", "content": str(tool_result)} |
| 283 | + messages.append(tool_message) |
| 284 | + else: |
| 285 | + # no further tool call, return the messages |
| 286 | + del self.async_llm_engine_map[request_id] |
| 287 | + return messages, logprobs |
0 commit comments