diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index 16fd385bad30..1283b9be0e36 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -4,6 +4,7 @@
Dataloader for sft, dpo, ppo
"""
+import copy
import os
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union
@@ -423,7 +424,9 @@ class RawConversationDataset(Dataset):
Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
"""
- def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
+ def __init__(
+ self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str, tokenize=True
+ ) -> None:
self.tokenizer = tokenizer
self.raw_texts = []
with jsonlines.open(input_file) as f:
@@ -432,30 +435,50 @@ def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length:
self.tokenized_texts = [None] * len(self.raw_texts)
self.max_length = max_length
self.system_prompt = system_prompt
+ self.tokenize = tokenize
def __len__(self) -> int:
return len(self.raw_texts)
def __getitem__(self, index: int):
- if self.tokenized_texts[index] is None:
- message = self.raw_texts[index]
- tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
- self.tokenized_texts[index] = dict(tokens)
- return self.tokenized_texts[index]
+ if self.tokenize:
+ if self.tokenized_texts[index] is None:
+ message = self.raw_texts[index]
+ tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
+ self.tokenized_texts[index] = dict(tokens)
+ return self.tokenized_texts[index]
+ else:
+ chat = copy.deepcopy(self.raw_texts[index])
+            chat["messages"] = [{"role": "system", "content": self.system_prompt}] + chat["messages"]
+ return chat
def collate_fn_grpo(batch):
- input_ids = [item["input_ids"] for item in batch]
- attention_mask = [item["attention_mask"] for item in batch]
- labels = [item["labels"] for item in batch]
- # Assume input_ids, attention_mask, labels are already of the same length,
- # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
- input_ids = torch.stack(input_ids)
- attention_mask = torch.stack(attention_mask)
- labels = torch.stack(labels)
- ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
- if "test_cases" in batch[0]:
- ret["test_cases"] = [item["test_cases"] for item in batch]
- if "gt_answer" in batch[0]:
- ret["gt_answer"] = [item["gt_answer"] for item in batch]
- return ret
+ if "input_ids" in batch[0]:
+ # tokenized format
+ input_ids = [item["input_ids"] for item in batch]
+ attention_mask = [item["attention_mask"] for item in batch]
+ labels = [item["labels"] for item in batch]
+ # Assume input_ids, attention_mask, labels are already of the same length,
+ # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+ input_ids = torch.stack(input_ids)
+ attention_mask = torch.stack(attention_mask)
+ labels = torch.stack(labels)
+ ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+ if "test_cases" in batch[0]:
+ ret["test_cases"] = [item["test_cases"] for item in batch]
+ if "gt_answer" in batch[0]:
+ ret["gt_answer"] = [item["gt_answer"] for item in batch]
+ return ret
+ elif "messages" in batch[0]:
+ # vllm format
+ ret = {
+ "messages": [item["messages"] for item in batch],
+ }
+ if "test_cases" in batch[0]:
+ ret["test_cases"] = [item["test_cases"] for item in batch]
+ if "gt_answer" in batch[0]:
+ ret["gt_answer"] = [item["gt_answer"] for item in batch]
+ return ret
+ else:
+ raise ValueError("Unsupported batch format")
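
Note: `collate_fn_grpo` now accepts either pre-tokenized items or raw-message items. A minimal sketch of the two item shapes it expects (the field values below are illustrative, not taken from the repository's datasets):

```python
# Hypothetical items illustrating the two formats accepted by collate_fn_grpo.
import torch

# Tokenized format (tokenize=True): per-item tensors of identical length.
tokenized_item = {
    "input_ids": torch.tensor([1, 2, 3, 0]),
    "attention_mask": torch.tensor([1, 1, 1, 0]),
    "labels": torch.tensor([1, 2, 3, -100]),
    "gt_answer": "42",
}

# Raw-message format (tokenize=False): the system prompt is already prepended.
message_item = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 6 x 7?"},
    ],
    "gt_answer": "42",
}

# collate_fn_grpo([tokenized_item]) -> stacked tensors plus a "gt_answer" list
# collate_fn_grpo([message_item])   -> {"messages": [...], "gt_answer": [...]}
```
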
diff --git a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py
new file mode 100644
index 000000000000..bd2f6e56d91f
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py
@@ -0,0 +1,277 @@
+import copy
+import re
+from typing import Any, Dict
+from uuid import uuid4
+
+import ray
+from coati.distributed.agent.base import BaseAgenticProducer
+
+DEFAULT_SYSTEM_MESSAGE = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""
+
+
+@ray.remote
+class AgenticProducer(BaseAgenticProducer):
+ """
+    Asynchronous version of the producer that uses vLLM for generation.
+    This class is designed to generate agentic (multi-turn, tool-calling) responses.
+
+ Please use the following SYSTEM message or a similar one for the agentic math model:
+ '''A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+ The Assistant first thinks about the reasoning process in the mind and then provides the user with
+    the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer>
+    tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.'''
+ """
+
+ def __init__(
+ self,
+ producer_idx,
+ num_producers,
+ num_consumer_procs,
+ num_episodes,
+ batch_size,
+ train_dataset_config,
+ model_config,
+ generate_config,
+ async_producers,
+ tool_workers=[],
+ tokenizer_config=None,
+ agentic_config=None,
+ microbatch_size=1,
+ backend="transformers",
+ num_generations: int = 8,
+ consumer_plugin_config=None,
+ eval_dataset_config=None,
+ eval_interval=-1, # disable evaluation
+ grpo_config: Dict[str, Any] = None,
+ eval_save_dir: str = "./eval",
+ eval_generation_config={},
+ project_name: str = None,
+ run_name: str = None,
+ wandb_group_name: str = None,
+ log_rollout_interval: int = 20,
+ rollout_log_file: str = "./rollout_log.jsonl",
+ enable_profiling: bool = False,
+ load_balancer=None,
+ n_behind: int = 0,
+ ):
+ assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer
+ assert batch_size == 1 # batch_size must be 1 for agentic producer
+ super().__init__(
+ producer_idx,
+ num_producers,
+ num_consumer_procs,
+ num_episodes,
+ batch_size,
+ train_dataset_config,
+ model_config,
+ generate_config,
+ async_producers,
+ tokenizer_config,
+ microbatch_size,
+ backend,
+ num_generations,
+ consumer_plugin_config,
+ eval_dataset_config=eval_dataset_config,
+ eval_interval=eval_interval,
+ grpo_config=grpo_config,
+ eval_save_dir=eval_save_dir,
+ eval_generation_config=eval_generation_config,
+ project_name=project_name,
+ run_name=run_name,
+ wandb_group_name=wandb_group_name,
+ log_rollout_interval=log_rollout_interval,
+ rollout_log_file=rollout_log_file,
+ enable_profiling=enable_profiling,
+ n_behind=n_behind,
+ )
+ self.load_balancer = load_balancer
+ self.tool_workers = tool_workers
+ self.agentic_config = model_config if not agentic_config else agentic_config
+ self.agentic_config.update({"model": model_config["path"]})
+ self.tools_schema = []
+ self.tool_call_budget = self.agentic_config.get("tool_call_budget", 3)
+ self.llm_call_budget = self.agentic_config.get("llm_call_budget", 10)
+ self.async_llm_engine_map = {}
+ self._get_tools()
+
+ def _get_tools(self):
+ """
+        Collect the tool schemas (name, description, argument schema) from the tool workers.
+        Reference for the agentic math setup: rStar2-Agent paper https://arxiv.org/pdf/2508.20722
+ """
+ tools = ray.get(self.tool_workers[0].list_tools.remote())
+ tool_descriptions = {tool: ray.get(self.tool_workers[0].get_tool_description.remote(tool)) for tool in tools}
+ tool_arg_schemas = {tool: ray.get(self.tool_workers[0].get_args_schema.remote(tool)) for tool in tools}
+ self.tools = []
+ for tool in tools:
+ tool_schema = {"name": tool, "description": tool_descriptions[tool], "parameters": tool_arg_schemas[tool]}
+ self.tools.append(tool_schema)
+
+ def _build_prompt(
+ self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt"
+ ) -> dict:
+ """
+ Build the prompt for the agentic math model.
+ """
+ return self.tokenizer.apply_chat_template(
+ messages,
+ tools=self.tools,
+ add_generation_prompt=add_generation_prompt,
+ return_dict=return_dict,
+ return_tensors=return_tensors,
+ )
+
+ def _parse_response(self, response: str) -> Dict[str, Any]:
+ """
+ Parse the response from the agentic math model.
+
+ Sample Assistant Response:
+ The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}<|im_end|>
+
+ Sample Assistant Response with Tool Call:
+        To answer this, I will check both the weather and the timezone for New York.\n<tool_call>\n{"name": "get_weather", "arguments": {"location": "New York"}}\n</tool_call>\n<tool_call>\n{"name": "get_timezone", "arguments": {"location": "New York"}}\n</tool_call>
+
+        Sample Output:
+ {
+ "role": "assistant",
+            "content": "To answer this, I will check both the weather and the timezone for New York.",
+ "tool_calls": [
+ {
+ "function": {
+ "name": "get_weather",
+ "arguments": {
+ "location": "New York"
+ }
+ }
+ },
+ {
+ "function": {
+ "name": "get_timezone",
+ "arguments": {
+ "location": "New York"
+ }
+ }
+ }
+ ]
+ },
+ {
+ "role": "assistant",
+ "content": "The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}"
+ }
+ """
+        # split by <|im_end|>
+        response_chunked = response.split("<|im_end|>")[0].strip()
+        if "<tool_call>" in response_chunked:
+            assistant_content = response_chunked.split("<tool_call>")[0].strip()
+            tool_call_sections = response_chunked[response_chunked.find("<tool_call>") :].strip()
+            # extract all tool calls
+            tool_calls = []
+            pattern = r"<tool_call>(.*?)</tool_call>"
+ matches = re.findall(pattern, tool_call_sections, re.DOTALL)
+ for match in matches:
+ try:
+ tool_call = eval(match.strip())
+ name = tool_call["name"]
+ arguments = tool_call["arguments"]
+ tool_calls.append({"function": {"name": name, "arguments": arguments}})
+ except Exception as e:
+ print(f"Failed to parse tool call: {match.strip()}. Error: {e}")
+ tool_calls.append({"function": {"name": "return_parsing_error", "arguments": {}}})
+ else:
+ assistant_content = response_chunked
+ tool_calls = []
+ assistant_message = {"role": "assistant", "content": assistant_content}
+ if tool_calls:
+ assistant_message["tool_calls"] = tool_calls
+ return assistant_message
+
+ def _select_tool_worker(self) -> int:
+ """
+ Select a tool worker based on the current load.
+ """
+ selected_idx, current_loads = ray.get(self.load_balancer.get_next_worker.remote("tool", amount=1))
+ return selected_idx
+
+ def _select_async_producer(self, request_id) -> int:
+ """
+ Select an async producer based on the current load.
+ """
+        # reuse the last used async producer for this request if one exists, to reuse its KV cache (as vLLM
+        # uses paged KV cache, most of the KV cache pages are reused without recomputation)
+ if request_id in self.async_llm_engine_map:
+ ray.get(self.load_balancer.increase_load.remote("async-llm", self.async_llm_engine_map[request_id], 1))
+ return self.async_llm_engine_map[request_id]
+ # otherwise select the least loaded async producer
+ selected_idx, current_loads = ray.get(self.load_balancer.get_next_worker.remote("async-llm", amount=1))
+ self.async_llm_engine_map[request_id] = selected_idx
+ return selected_idx
+
+ def _run_agentic_pipeline(self, messages):
+ """
+ Run the agentic pipeline to generate responses based on the input messages.
+ """
+ tool_call_count = 0
+ llm_call_count = 0
+ num_prompt_tokens = 0
+ request_id = str(uuid4())
+ logprobs = None
+ while True:
+ # tokenize the messages
+ if llm_call_count > self.llm_call_budget:
+ print(f"LLM call budget exceeded: {llm_call_count} > {self.llm_call_budget}. Stopping.")
+ del self.async_llm_engine_map[request_id]
+ return messages, response_input_ids, logprobs
+ inputs = self._build_prompt(messages, return_dict=True, return_tensors="pt")
+ if num_prompt_tokens == 0:
+ num_prompt_tokens = inputs["input_ids"].size(-1)
+ if inputs["input_ids"].size(-1) - num_prompt_tokens > self.generate_config["max_tokens"]:
+                print(
+                    f"Max tokens exceeded: currently generated {inputs['input_ids'].size(-1) - num_prompt_tokens} tokens > {self.generate_config['max_tokens']}. Stopping."
+                )
+ del self.async_llm_engine_map[request_id]
+ return messages, response_input_ids, logprobs
+ async_producer = self.async_producers[self._select_async_producer(request_id=request_id)]
+ agentic_generate_config = copy.deepcopy(self.generate_config)
+ agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048)
+ response = ray.get(
+ async_producer.generate.remote(
+ inputs["input_ids"],
+ inputs["attention_mask"],
+ request_id=request_id,
+ **agentic_generate_config,
+ )
+ )
+ llm_call_count += 1
+ ray.get(self.load_balancer.decrease_load.remote("async-llm", self.async_llm_engine_map[request_id], 1))
+ self.consumer_global_step = response.pop("consumer_global_step")
+ response_input_ids = response["input_ids"]
+ logprobs = response["action_log_probs"]
+ response_text = self.tokenizer.decode(
+ response_input_ids[0][0][inputs["input_ids"].size(-1) :], skip_special_tokens=False
+ )
+ assistant_message = self._parse_response(response_text)
+ messages.append(assistant_message)
+ if "tool_calls" in assistant_message:
+ if tool_call_count > self.tool_call_budget:
+ print(f"Tool call budget exceeded: {tool_call_count} > {self.tool_call_budget}. Stopping.")
+ del self.async_llm_engine_map[request_id]
+ return messages, response_input_ids, logprobs
+ tool_call_count += len(assistant_message["tool_calls"])
+ handlers = []
+ tool_workers_called = []
+ for tool_call in assistant_message["tool_calls"]:
+ # select a tool worker to execute the tool call
+ tool_worker_idx = self._select_tool_worker()
+ tool_workers_called.append(tool_worker_idx)
+ tool_worker = self.tool_workers[tool_worker_idx]
+ handler = tool_worker.call.remote(tool_call["function"]["name"], tool_call["function"]["arguments"])
+ handlers.append(handler)
+ tool_results = ray.get(handlers)
+ for idx in tool_workers_called:
+ ray.get(self.load_balancer.decrease_load.remote("tool", idx, 1))
+ for tool_call, tool_result in zip(assistant_message["tool_calls"], tool_results):
+ tool_message = {"role": "tool", "content": str(tool_result)}
+ messages.append(tool_message)
+ else:
+ # no further tool call, return the messages
+ del self.async_llm_engine_map[request_id]
+ return messages, response_input_ids, logprobs
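
Note: `_parse_response` relies on the `<tool_call> ... </tool_call>` markup shown in its docstring. A standalone sketch of that extraction (illustrative only; the implementation above uses `eval()` and adds budget and error handling):

```python
# Minimal re-implementation of the tool-call extraction, using json.loads for parsing.
import json
import re

def extract_tool_calls(response: str):
    response = response.split("<|im_end|>")[0].strip()
    calls = []
    for match in re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL):
        payload = json.loads(match.strip())  # {"name": ..., "arguments": {...}}
        calls.append({"function": {"name": payload["name"], "arguments": payload["arguments"]}})
    return calls

sample = (
    "Let me check.\n<tool_call>\n"
    '{"name": "get_weather", "arguments": {"location": "New York"}}\n'
    "</tool_call>"
)
print(extract_tool_calls(sample))
# [{'function': {'name': 'get_weather', 'arguments': {'location': 'New York'}}}]
```
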
diff --git a/applications/ColossalChat/coati/distributed/agent/base.py b/applications/ColossalChat/coati/distributed/agent/base.py
new file mode 100644
index 000000000000..e5ff9ffc588b
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/agent/base.py
@@ -0,0 +1,213 @@
+import copy
+import json
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict
+
+import ray
+import torch
+from coati.distributed.producer import BaseProducer
+from vllm import SamplingParams
+
+
+class BaseAgenticProducer(BaseProducer):
+ """
+    Base class for agentic producers. It delegates generation to the asynchronous vLLM producers
+    and assembles the resulting multi-turn (tool-calling) rollouts into training tensors.
+ """
+
+ def __init__(
+ self,
+ producer_idx,
+ num_producers,
+ num_consumer_procs,
+ num_episodes,
+ batch_size,
+ train_dataset_config,
+ model_config,
+ generate_config,
+ async_producers,
+ tokenizer_config=None,
+ microbatch_size=1,
+ backend="transformers",
+ num_generations: int = 8,
+ consumer_plugin_config=None,
+ eval_dataset_config=None,
+ eval_interval=-1, # disable evaluation
+ grpo_config: Dict[str, Any] = None,
+ eval_save_dir: str = "./eval",
+ eval_generation_config={},
+ project_name: str = None,
+ run_name: str = None,
+ wandb_group_name: str = None,
+ log_rollout_interval: int = 20,
+ rollout_log_file: str = "./rollout_log.jsonl",
+ enable_profiling: bool = False,
+ n_behind: int = 0,
+ ):
+ assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer
+ assert batch_size == 1 # batch_size must be 1 for agentic producer
+ super().__init__(
+ producer_idx,
+ num_producers,
+ num_consumer_procs,
+ num_episodes,
+ batch_size,
+ train_dataset_config,
+ model_config,
+ generate_config,
+ tokenizer_config,
+ microbatch_size,
+ backend,
+ consumer_plugin_config,
+ eval_dataset_config=eval_dataset_config,
+ eval_interval=eval_interval,
+ grpo_config=grpo_config,
+ eval_save_dir=eval_save_dir,
+ project_name=project_name,
+ run_name=run_name,
+ wandb_group_name=wandb_group_name,
+ log_rollout_interval=log_rollout_interval,
+ rollout_log_file=rollout_log_file,
+ enable_profiling=enable_profiling,
+ n_behind=n_behind,
+ enable_agentic=True,
+ )
+ self.eval_generation_config = copy.deepcopy(generate_config)
+ self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
+ self.eval_generation_config.update(eval_generation_config)
+ self.eval_sample_params = SamplingParams(**self.eval_generation_config)
+ self.async_producers = async_producers
+ self.num_generations = num_generations
+ self.generate_config = generate_config
+
+ def _run_agentic_pipeline(self, messages):
+ """
+ Run the agentic pipeline to generate responses based on the input messages.
+ This function should be implemented in subclasses.
+ """
+ raise NotImplementedError
+
+ def _build_prompt(
+ self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt"
+ ) -> dict:
+ """
+ Build the prompt from the input messages.
+ This function should be implemented in subclasses.
+ """
+ raise NotImplementedError
+
+ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
+ """
+        Generate `num_generations` agentic rollouts for a single prompt by running the agentic
+        pipeline in parallel threads, then pad and stack them into training tensors.
+ """
+ assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer"
+ messages = kwargs["messages"][0]
+ prompt_input_ids = self._build_prompt(
+ messages, return_dict=True, return_tensors="pt", add_generation_prompt=True
+ )["input_ids"]
+ # add left padding
+ prompt_length = prompt_input_ids.shape[1]
+ max_prompt_length = self.train_dataset_config["max_length"]
+ to_pad_left = max_prompt_length - prompt_length
+ rollouts = {
+ "input_ids": [],
+ "attention_mask": [],
+ "action_mask": [],
+ "action_log_probs": [],
+ "response_idx": [],
+ }
+ with ThreadPoolExecutor(max_workers=self.num_generations) as executor:
+ results = list(
+ executor.map(self._run_agentic_pipeline, [copy.deepcopy(messages) for _ in range(self.num_generations)])
+ )
+
+ for i in range(self.num_generations):
+ # due to the multiround feature, action_mask and attention_mask need to be recomputed
+ _messages, response_input_ids, logprobs = results[i]
+ # truncate if too long
+ response_input_ids = response_input_ids[0, :, : self.grpo_config["max_length"] - to_pad_left]
+            # add left and right padding
+ to_pad_right = self.grpo_config["max_length"] - response_input_ids.size(-1) - to_pad_left
+ input_ids = torch.nn.functional.pad(
+ response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id
+ ) # [1, max_length]
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).int() # [1, max_length]
+ action_mask = input_ids[:, max_prompt_length:].ne(self.tokenizer.pad_token_id).int()
+ response_length = action_mask.sum().item()
+ rollouts["attention_mask"].append(attention_mask)
+ rollouts["action_mask"].append(action_mask)
+ truncated_logprobs = logprobs[
+ 0, :, prompt_length : prompt_length + self.generate_config["max_tokens"]
+ ] # truncate to max_new_tokens
+ logprobs_padded = torch.nn.functional.pad(
+ truncated_logprobs,
+ (0, self.generate_config["max_tokens"] - truncated_logprobs.size(-1)),
+ "constant",
+ value=0.0,
+ ) # [1, max_new_tokens]
+ rollouts["action_log_probs"].append(logprobs_padded)
+ rollouts["response_idx"].append(
+ torch.tensor(
+ [
+ [
+ self.train_dataset_config["max_length"],
+ self.train_dataset_config["max_length"] + response_length,
+ ]
+ ]
+ )
+ ) # [1, 2]
+ rollouts["input_ids"].append(input_ids)
+ rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()} # [num_generations, ...]
+ rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)])
+ if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
+            # rollout logging happens here; the underlying AsyncSimpleProducer is created with rollout_log_file=None
+ if (
+ self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
+ or self.latest_rollout_log_step == -1
+ ):
+ new_record = (
+ json.dumps(
+ {
+ "train_step": self.consumer_global_step,
+ "rollout": self.tokenizer.batch_decode(
+ rollouts["input_ids"][:, 0], skip_special_tokens=True
+ ),
+ },
+ ensure_ascii=False,
+ )
+ + "\n"
+ )
+ self.rollout_log_file.write(new_record)
+ self.rollout_log_file.flush()
+ self.latest_rollout_log_step = self.consumer_global_step
+
+ if "gt_answer" in kwargs:
+ rollouts["gt_answer"] = kwargs["gt_answer"]
+ if "test_cases" in kwargs:
+ rollouts["test_cases"] = kwargs["test_cases"]
+ return rollouts
+
+ def sync_model(self, episode, step) -> None:
+ """
+        Trigger the consumer-to-producer model sync on self.async_producers.
+        AgenticProducer itself holds no model weights, so nothing needs to be loaded locally.
+ """
+ tasks = []
+ for proc in self.async_producers:
+ tasks.append(proc.async_sync_model.remote(episode, step, self.num_producers))
+ ray.get(tasks)
+ return
+
+ def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
+ """
+        Relay rollout data to the consumer through the async producers.
+ """
+ tasks = []
+ for idx, proc in enumerate(self.async_producers):
+ if idx == self.producer_idx % len(self.async_producers):
+ tasks.append(proc.async_sync_data.remote(data, self.num_producers))
+ else:
+ tasks.append(proc.async_sync_data.remote({}, self.num_producers))
+ ray.get(tasks)
+ return
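
Note: `BaseAgenticProducer.rollout` left-pads each sequence so the prompt ends at the train dataset `max_length`, then right-pads so every generation has shape `[1, grpo_config["max_length"]]`. A minimal sketch of that padding scheme with made-up sizes:

```python
# Minimal sketch of the left/right padding used in rollout() (made-up sizes).
import torch

pad_token_id = 0
max_prompt_length = 8   # stands in for train_dataset_config["max_length"]
grpo_max_length = 16    # stands in for grpo_config["max_length"]

sequence = torch.tensor([[5, 6, 7, 8, 9, 10]])  # prompt + response tokens
prompt_length = 4                               # first 4 tokens are the prompt
to_pad_left = max_prompt_length - prompt_length
to_pad_right = grpo_max_length - sequence.size(-1) - to_pad_left

input_ids = torch.nn.functional.pad(sequence, (to_pad_left, to_pad_right), "constant", value=pad_token_id)
attention_mask = input_ids.ne(pad_token_id).int()
action_mask = input_ids[:, max_prompt_length:].ne(pad_token_id).int()
print(input_ids.shape, attention_mask.sum().item(), action_mask.sum().item())  # torch.Size([1, 16]) 6 2
```
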
diff --git a/applications/ColossalChat/coati/distributed/agent/math_tools.py b/applications/ColossalChat/coati/distributed/agent/math_tools.py
new file mode 100644
index 000000000000..dba8b93b6519
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/agent/math_tools.py
@@ -0,0 +1,31 @@
+from langchain_core.tools import Tool
+from langchain_experimental.utilities import PythonREPL
+from pydantic import BaseModel, Field
+from pydantic.fields import FieldInfo
+
+
+def make_title(field_name: str, field_info: FieldInfo) -> str:
+ return field_name
+
+
+class PythonInput(BaseModel):
+ code: str = Field(description="The python code to execute", field_title_generator=make_title)
+
+
+python_repl = PythonREPL()
+
+
+def run_python_code(code: str) -> str:
+ if code.startswith("```python"):
+ code = code.replace("```python", "```", 1).strip()
+ if code.startswith("```py"): # qwen3 uses ```py
+ code = code.replace("```py", "```", 1).strip()
+ return python_repl.run(code, timeout=30)
+
+
+repl_tool = Tool(
+ name="python_repl",
+ description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.",
+ func=run_python_code,
+ args_schema=PythonInput,
+)
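
Usage sketch for the REPL tool above (assumes `langchain_experimental` is installed and the module is importable as `coati.distributed.agent.math_tools`):

```python
# The REPL executes the code string in-process and returns the captured stdout.
from coati.distributed.agent.math_tools import repl_tool, run_python_code

print(run_python_code("print(2 ** 10)"))       # "1024\n"
print(repl_tool.run("print(sum(range(10)))"))  # "45\n"
```
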
diff --git a/applications/ColossalChat/coati/distributed/agent/tool_worker.py b/applications/ColossalChat/coati/distributed/agent/tool_worker.py
new file mode 100644
index 000000000000..454d2adba1d0
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/agent/tool_worker.py
@@ -0,0 +1,64 @@
+from typing import Any, Dict, List, Optional, Union
+
+import ray
+from langchain.tools import BaseTool
+
+
+@ray.remote(concurrency_groups={"io": 1, "compute": 5})
+class ToolWorker:
+ """
+ A unified wrapper class for LangChain tools, enabling a standard
+ interface to call tools regardless of their internal differences.
+ """
+
+ def __init__(self, tools: List[BaseTool]):
+ """
+ Initialize ToolWorker with a list of LangChain tools.
+
+ Args:
+ tools (List[BaseTool]): List of LangChain tools to register.
+ """
+ self._tool_registry: Dict[str, BaseTool] = {tool.name: tool for tool in tools}
+
+ @ray.method(concurrency_group="io")
+ def list_tools(self) -> List[str]:
+ """Return the list of available tool names."""
+ return list(self._tool_registry.keys())
+
+ @ray.method(concurrency_group="io")
+ def get_tool_description(self, tool_name: str) -> Optional[str]:
+ """Return the description of a specific tool."""
+ tool = self._tool_registry.get(tool_name)
+ return tool.description if tool else None
+
+ @ray.method(concurrency_group="io")
+ def get_args_schema(self, tool_name: str):
+ """Return the argument schema of a specific tool."""
+ assert tool_name in self._tool_registry, f"Tool '{tool_name}' not found. Available: {self.list_tools()}"
+ tool = self._tool_registry.get(tool_name)
+ schema = tool.args_schema.model_json_schema(by_alias=False)
+ return schema
+
+ @ray.method(concurrency_group="compute")
+ def call(self, tool_name: str, input_data: Union[str, Dict[str, Any]], **kwargs) -> Any:
+ """
+ Call a tool by name with input data.
+
+ Args:
+ tool_name (str): Name of the tool to call.
+ input_data (Union[str, Dict[str, Any]]): Input to pass to the tool.
+ **kwargs: Extra keyword arguments for the tool.
+
+ Returns:
+ Any: The tool's output.
+ """
+ if tool_name == "return_parsing_error":
+ return "Error: Tool call parsing error. Please use the correct JSON format."
+ if tool_name not in self._tool_registry:
+ return f"Error: Tool {tool_name} not found. Available tools: {self.list_tools()}"
+ tool = self._tool_registry[tool_name]
+ try:
+ ret = tool.run(input_data, **kwargs)
+ except Exception as e:
+ ret = f"Error: Tool {tool_name} execution failed with error: {str(e)}"
+ return ret
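
Usage sketch for `ToolWorker` (assumes a running Ray runtime and that the agent modules above are importable):

```python
# Register the Python REPL tool with a ToolWorker actor and call it remotely.
import ray
from coati.distributed.agent.math_tools import repl_tool
from coati.distributed.agent.tool_worker import ToolWorker

ray.init(ignore_reinit_error=True)
worker = ToolWorker.remote([repl_tool])
print(ray.get(worker.list_tools.remote()))                         # ['python_repl']
print(ray.get(worker.call.remote("python_repl", "print(1 + 1)")))  # "2\n"
```
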
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 21da67161956..45840e7e488a 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -150,6 +150,7 @@ def loop(self) -> None:
self.profiler.enter("sync_model")
torch.cuda.empty_cache()
state_dict = self.state_dict()
+ print(f"[C{self.rank}]: Sync model before training")
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
ray_broadcast_tensor_dict(
@@ -179,7 +180,6 @@ def loop(self) -> None:
for step in pbar:
torch.cuda.reset_peak_memory_stats()
i = 0
-
self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
for _ in range(self.num_recv_per_update):
if self.n_behind > 0:
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 34827e4e2cf9..9f9d8e36ba68 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -1,8 +1,11 @@
+import asyncio
from typing import Any, Dict
+from uuid import uuid4
import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer
+from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from colossalai.utils import get_current_device
@@ -43,6 +46,27 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
pass
+class AsyncInferenceBackend(BaseInferenceBackend):
+ async def generate(
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
+ ) -> Dict[str, torch.Tensor]:
+ """Generate new tokens given input_ids and attention_mask.
+
+ Args:
+ input_ids (torch.Tensor): shape [B, S]
+ attention_mask (torch.Tensor): shape [B, S]
+
+ Returns:
+ Dict[str, torch.Tensor]: containing the
+ - input_ids (torch.Tensor): shape [B, S+N]
+ - attention_mask (torch.Tensor): shape [B, S+N]
+ - action_log_probs (torch.Tensor): shape [B, N]
+ - action_mask (torch.Tensor): shape [B, N]
+ where N is the number of generated tokens. And all tensors should be on CUDA.
+ """
+ raise NotImplementedError("Generate method must be implemented in subclass.")
+
+
class TransformersInferenceBackend(BaseInferenceBackend):
DEFAULT_MODEL_CONFIG = dict(
trust_remote_code=True,
@@ -59,6 +83,8 @@ def __init__(
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
+ microbatch_size: int = 1,
+ profiler=None,
):
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
model_config.update(self.FORCE_MODEL_CONFIG)
@@ -68,6 +94,7 @@ def __init__(
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.tokenizer = tokenizer
self.num_generations = num_generations
+ self.profiler = profiler
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -132,6 +159,8 @@ def __init__(
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
+ microbatch_size: int = 1,
+ profiler=None,
):
if sgl is None:
raise ImportError("sglang is not installed")
@@ -196,6 +225,8 @@ def __init__(
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
+ microbatch_size: int = 1,
+ profiler=None,
):
if LLM is None:
raise ImportError("vllm is not installed")
@@ -220,7 +251,12 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
micro_batch_input_ids_no_padding = [
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
]
- sample_params = kwargs.get("sample_params", self.sample_params)
+ sample_params = self.sample_params
+ if len(kwargs) > 0:
+ sample_params = self.generate_config.copy()
+ sample_params.update({k: v for k, v in kwargs.items() if k not in ["gt_answer", "test_cases", "labels"]})
+ sample_params.update(self.FORCE_GENERATE_CONFIG)
+ sample_params = SamplingParams(**sample_params)
outputs = self.llm.generate(
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
)
@@ -280,8 +316,149 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
self.llm.llm_engine.model_executor.driver_worker.model_runner.model.load_weights(state_dict.items())
+class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
+ DEFAULT_MODEL_CONFIG = dict(
+ trust_remote_code=True,
+ enable_sleep_mode=False,
+ )
+ FORCE_GENERATE_CONFIG = dict(
+ logprobs=0,
+ )
+
+ def __init__(
+ self,
+ model_config: Dict[str, Any],
+ generate_config: Dict[str, Any],
+ tokenizer: PreTrainedTokenizer,
+ num_generations: int = 8,
+ microbatch_size: int = 1,
+ profiler=None,
+ ):
+ if LLM is None:
+ raise ImportError("vllm is not installed")
+ model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
+ path = model_config.pop("path")
+ engine_args = AsyncEngineArgs(model=path, disable_log_stats=True, **model_config)
+ self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+ generate_config = generate_config.copy()
+ generate_config.update(self.FORCE_GENERATE_CONFIG)
+ if "n" not in generate_config:
+ generate_config.update({"n": num_generations})
+ self.generate_config = generate_config
+ self.sample_params = SamplingParams(**generate_config)
+ self.model_config = model_config
+ self.tokenizer = tokenizer
+ self.num_generations = num_generations
+ self.running_requests = []
+ self.microbatch_size = microbatch_size
+ self.profiler = profiler
+
+ @torch.no_grad()
+ async def generate(
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Generate text from the model asynchronously.
+ Args:
+ input_ids (torch.Tensor): shape [B, S], B=1
+ attention_mask (torch.Tensor): shape [B, S]
+ """
+ assert input_ids.size(0) == attention_mask.size(0) == 1, "AsyncVLLMInferenceBackend only supports batch size 1"
+ request_id = (
+            str(uuid4()) if "request_id" not in kwargs else kwargs.pop("request_id")
+ ) # use fixed request_id to reuse kv cache
+ response_start_idx = input_ids.size(1)
+ first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
+ input_ids_no_padding = [input_ids.tolist()[0][first_non_padding_token_idx[0] :]]
+ sample_params = self.sample_params
+ if len(kwargs) > 0:
+ sample_params = self.generate_config.copy()
+ sample_params.update({k: v for k, v in kwargs.items() if k not in ["gt_answer", "test_cases", "labels"]})
+ sample_params.update(self.FORCE_GENERATE_CONFIG)
+ sample_params = SamplingParams(**sample_params)
+ out_tokens = []
+ out_len = []
+ log_probs = []
+ response_idx = []
+ while len(self.running_requests) >= self.microbatch_size:
+ await asyncio.sleep(0.1)
+ self.running_requests.append(request_id) # enqueue
+ # pop the first input_ids and attention_mask
+ prompt_token_ids = input_ids_no_padding[0]
+ self.profiler.enter(f"vllm generate {request_id}")
+ outputs = self.engine.generate(
+ prompt={"prompt_token_ids": prompt_token_ids}, sampling_params=sample_params, request_id=request_id
+ )
+ async for chunk in outputs:
+            # drain the stream; after the loop, `chunk` holds the final RequestOutput
+ pass
+ self.running_requests.remove(request_id) # dequeue
+ if self.generate_config.get("prompt_logprobs", None) is not None:
+ # when prompt_logprobs is not None, vllm will return logprobs for the whole sequence
+ # for agentic producer, we return the logprobs of the whole sequence
+ log_probs = [
+ [m[t].logprob if m is not None else 0.0 for m, t in zip(chunk.prompt_logprobs, chunk.prompt_token_ids)]
+ ]
+ for _ in range(sample_params.n - 1):
+ log_probs.append([t for t in log_probs[0]]) # repeat the same logprobs for num_generations times
+ else:
+ log_probs = [[] for _ in range(sample_params.n)]
+
+ for generation_id, output_i in enumerate(chunk.outputs):
+ out_len.append(len(output_i.token_ids))
+ out_tokens.append(list(output_i.token_ids))
+ response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
+ assert len(output_i.logprobs) == len(output_i.token_ids)
+ p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
+ log_probs[generation_id].extend(p)
+ self.profiler.exit(f"vllm generate {request_id}")
+ # pad them
+ max_len = sample_params.max_tokens
+ action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
+
+ for i, new_token_ids in enumerate(out_tokens):
+ pad_len = max_len - out_len[i]
+ out_tokens[i] = new_token_ids + [self.tokenizer.pad_token_id] * pad_len
+ log_probs[i] = log_probs[i] + [0.0] * (max_len - len(log_probs[i]))
+ action_mask[i, out_len[i] :] = 0
+
+ out_tokens = torch.tensor(out_tokens)
+ log_probs = torch.tensor(log_probs)
+ response_idx = torch.tensor(response_idx)
+
+ if attention_mask.size(0) != action_mask.size(0):
+ assert action_mask.size(0) % attention_mask.size(0) == 0
+ num_returns = action_mask.size(0) // attention_mask.size(0)
+ attention_mask = attention_mask.repeat_interleave(num_returns, dim=0)
+ input_ids = input_ids.repeat_interleave(num_returns, dim=0)
+
+ out_tokens = torch.cat((input_ids, out_tokens), dim=1)
+ attention_mask = torch.cat((attention_mask, action_mask), dim=1)
+
+ data = {
+ "input_ids": out_tokens,
+ "attention_mask": attention_mask,
+ "action_log_probs": log_probs,
+ "action_mask": action_mask,
+ "response_idx": response_idx,
+ }
+
+ data = {k: v.view(1, -1, v.size(-1)) for k, v in data.items()}
+ data = {k: v.to(get_current_device()) for k, v in data.items()}
+ if "gt_answer" in kwargs:
+ data["gt_answer"] = kwargs["gt_answer"]
+ if "test_cases" in kwargs:
+ data["test_cases"] = kwargs["test_cases"]
+
+ return data
+
+ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
+ self.engine.engine.model_executor.driver_worker.model_runner.model.load_weights(state_dict.items())
+
+
BACKEND_MAP = {
"transformers": TransformersInferenceBackend,
# "sglang": SGLangInferenceBackend, # sglang backend will stuck the process due to unknown reason
"vllm": VLLMInferenceBackend,
+ "async-vllm": AsyncVLLMInferenceBackend,
}
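
Note: both vLLM backends rebuild `SamplingParams` per call whenever extra kwargs are passed, filtering out reward-related keys and re-applying the forced settings. A minimal sketch of that override path (assumes vLLM is installed; the config values are illustrative):

```python
# Per-call SamplingParams override, mirroring the kwargs handling in the vLLM backends.
from vllm import SamplingParams

FORCE_GENERATE_CONFIG = dict(logprobs=0)
generate_config = dict(temperature=1.0, top_p=0.95, max_tokens=512, n=8, logprobs=0)

default_params = SamplingParams(**generate_config)        # used when no per-call kwargs arrive

kwargs = {"max_tokens": 2048, "n": 1, "gt_answer": "42"}   # per-call overrides (illustrative)
cfg = generate_config.copy()
cfg.update({k: v for k, v in kwargs.items() if k not in ["gt_answer", "test_cases", "labels"]})
cfg.update(FORCE_GENERATE_CONFIG)
per_call_params = SamplingParams(**cfg)
print(per_call_params.max_tokens, per_call_params.n)       # 2048 1
```
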
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index d60312e2b0b1..f060104db0c1 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -4,10 +4,13 @@
from typing import Any, Dict, Optional
import ray
+from coati.distributed.agent.agentic_producer import AgenticProducer
+from coati.distributed.agent.tool_worker import ToolWorker
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
-from .producer import SimpleProducer
+from .producer import AsyncSimpleProducer, SimpleProducer
+from .utils import LoadBalancer
ALGO_MAP = {
"Simple": SimpleConsumer,
@@ -16,6 +19,10 @@
"REINFORCE_PPB": GRPOConsumer,
"RLOO": GRPOConsumer,
}
+Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncSimpleProducer}
+AGENTIC_PRODUCER_MAP = {
+ "Agentic": AgenticProducer,
+} # supported agentic producers
def get_jsonl_size_fast(path: str) -> int:
@@ -47,6 +54,7 @@ def launch_distributed(
generate_config: Dict[str, Any],
train_model_config: Dict[str, Any],
grpo_config: Dict[str, Any],
+ agentic_config: Optional[Dict[str, Any]],
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
@@ -61,7 +69,7 @@ def launch_distributed(
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
eval_generation_config: Optional[Dict[str, Any]] = None,
- log_rollout_interval: int = 20,
+ log_rollout_interval: int = 1,
rollout_save_dir: str = "./rollout",
enable_profiling: bool = False,
n_behind: int = 0,
@@ -79,7 +87,9 @@ def launch_distributed(
num_samples = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers
num_update_per_episode = num_samples // global_inference_batch_size
- num_recv_per_update = inference_batch_size // inference_microbatch_size
+ num_recv_per_update = (
+ inference_batch_size // inference_microbatch_size if "async-agentic" not in inference_backend else 1
+ )
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
wandb_group_name = str(uuid.uuid4())
@@ -110,6 +120,14 @@ def launch_distributed(
print(node_info)
producer_procs = []
+ if "async" in inference_backend:
+ core_producer = AsyncSimpleProducer
+ else:
+ core_producer = SimpleProducer
+ enable_agentic = "agentic" in inference_backend
+ if enable_agentic:
+ inference_backend = inference_backend.replace("agentic-", "")
+ inference_microbatch_size = inference_microbatch_size * num_generations
for i in range(num_producers):
node_id = gpu_to_node_id[0]
producer_ip_address = gpu_to_ip_address[0]
@@ -117,7 +135,7 @@ def launch_distributed(
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
- producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+ producer = core_producer.options(num_gpus=num_proc_per_producer).remote(
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,
@@ -140,12 +158,69 @@ def launch_distributed(
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
- rollout_log_file=rollout_log_file,
+ rollout_log_file=rollout_log_file if not enable_agentic else None,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
+
+ if enable_agentic:
+ from coati.distributed.agent.math_tools import repl_tool
+
+ # setup tool workers
+ tool_workers = []
+ if agentic_config["agentic_producer"] == "Agentic":
+            # 10 tool workers can handle 50 queries simultaneously (each ToolWorker's "compute" concurrency group allows 5 concurrent calls)
+            # note that the imported repl_tool is serialized and deserialized into each tool worker, so all workers can run in parallel
+ tool_workers = [ToolWorker.remote([repl_tool]) for _ in range(agentic_config.get("num_tool_workers", 10))]
+ # when agentic is enabled, we use core_producer as inference engine and
+ # AgenticProducer as the real producer
+ _producer_procs = producer_procs
+ assert (
+ "agentic_producer" in agentic_config
+ ), "Please specify the agentic producer through `agentic_producer` in agentic_config."
+ assert (
+ agentic_config["agentic_producer"] in AGENTIC_PRODUCER_MAP
+ ), f"Only {list(AGENTIC_PRODUCER_MAP.keys())} are supported as agentic producer so far."
+ load_balancer = LoadBalancer.remote({"tool": len(tool_workers), "async-llm": num_producers})
+ agentic_producer_cls = AGENTIC_PRODUCER_MAP[agentic_config["agentic_producer"]]
+ agentic_config.pop("agentic_producer")
+ producer_procs = [
+ agentic_producer_cls.options(num_cpus=1).remote(
+ producer_idx=producer_idx,
+ num_producers=num_producers * inference_batch_size,
+ num_consumer_procs=num_consumer_procs,
+ num_episodes=num_episodes,
+ batch_size=1, # batch_size must be 1 for agentic producer
+ train_dataset_config=train_dataset_config,
+ model_config=inference_model_config,
+ generate_config=generate_config,
+ async_producers=_producer_procs,
+ tool_workers=tool_workers,
+ tokenizer_config=tokenizer_config,
+ agentic_config=agentic_config,
+ microbatch_size=1, # microbatch_size must be 1 for agentic producer
+ backend=inference_backend,
+ num_generations=num_generations,
+ consumer_plugin_config=plugin_config,
+ eval_dataset_config=eval_dataset_config,
+ eval_interval=eval_interval,
+ grpo_config=grpo_config,
+ eval_save_dir=eval_save_dir,
+ eval_generation_config=eval_generation_config,
+ project_name=project_name,
+ run_name=run_name,
+ wandb_group_name=wandb_group_name,
+ log_rollout_interval=log_rollout_interval,
+ rollout_log_file=rollout_log_file,
+ enable_profiling=enable_profiling,
+ load_balancer=load_balancer,
+ n_behind=n_behind,
+ )
+ for producer_idx in range(num_producers * inference_batch_size)
+ ]
+
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(
dict(
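
Note on the backend naming used above: an `inference_backend` value such as `"async-agentic-vllm"` (assumed example) selects `AsyncSimpleProducer` as the core engine, enables the agentic wrapper, and is then reduced to the underlying `"async-vllm"` backend:

```python
# Illustration of the backend-name handling in launch_distributed.
inference_backend = "async-agentic-vllm"          # assumed example value

core_is_async = "async" in inference_backend      # -> AsyncSimpleProducer
enable_agentic = "agentic" in inference_backend   # -> wrap with an agentic producer
if enable_agentic:
    inference_backend = inference_backend.replace("agentic-", "")

print(core_is_async, enable_agentic, inference_backend)  # True True async-vllm
```
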
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index ab38f987f65a..7fcfdba31f4d 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -37,9 +37,9 @@ def forward(
total_effective_tokens_in_batch: torch.Tensor = None,
) -> torch.Tensor:
if action_mask is None:
- ratio = (log_probs - log_probs.detach()).exp()
+ ratio = (log_probs - old_log_probs.detach()).exp()
else:
- ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
+ ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
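
The loss change above matters because `(log_probs - log_probs.detach()).exp()` is identically 1, so the clipping term never engages; using `old_log_probs` restores the importance ratio between the current and the rollout-time policy. A minimal sketch of the corrected clipped surrogate with made-up values (the actual loss class may mask and normalize differently):

```python
# Clipped-surrogate objective with the corrected importance ratio (made-up values).
import torch

log_probs = torch.tensor([-0.9, -1.2, -0.5])      # current policy
old_log_probs = torch.tensor([-1.0, -1.0, -1.0])  # policy at rollout time
advantages = torch.tensor([1.0, -0.5, 2.0])
clip_eps_low, clip_eps_high = 0.2, 0.2

ratio = (log_probs - old_log_probs.detach()).exp()  # no longer identically 1
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
loss = -torch.min(surr1, surr2).mean()
print(ratio, loss.item())
```
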
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 38a85b9b1c4d..ed0faa9fdae6 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -1,3 +1,4 @@
+import asyncio
import copy
import json
import os
@@ -56,6 +57,7 @@ def __init__(
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
enable_profiling: bool = False,
+ enable_agentic: bool = False,
n_behind: int = 0,
):
self.producer_idx = producer_idx
@@ -64,8 +66,12 @@ def __init__(
self.num_episodes = num_episodes
self.batch_size = batch_size
self.microbatch_size = microbatch_size
- assert batch_size % microbatch_size == 0
- self.num_microbatches = batch_size // microbatch_size
+ if not isinstance(self, BaseAsyncProducer):
+ assert batch_size % microbatch_size == 0, "batch_size must be divisible by microbatch_size"
+ self.num_microbatches = batch_size // microbatch_size
+ else:
+ assert microbatch_size > 0, "microbatch_size must be positive"
+ self.num_microbatches = max(1, batch_size // microbatch_size)
self.latest_eval_step = -1
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
@@ -83,20 +89,28 @@ def __init__(
self.latest_rollout_log_step = -1
self.grpo_config = grpo_config
self.n_behind = n_behind
+ self.enable_agentic = enable_agentic
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
- if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
+ if k
+ in [
+ "soft_over_length_punishment",
+ "max_new_tokens",
+ "cache_length",
+ "code_verifier_api_url",
+ "forced_patterns",
+ ]
}
self.response_format_tags = grpo_config.get("response_format_tags", None)
- if producer_idx == 0:
+ if producer_idx == 0 and rollout_log_file is not None:
if os.path.exists(rollout_log_file):
raise ValueError(
f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
)
else:
os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
- self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
+ self.rollout_log_file = open(rollout_log_file, "a", encoding="utf8")
if self.producer_idx == 0:
self.wandb_run = wandb.init(
project=project_name,
@@ -123,7 +137,9 @@ def __init__(
# init dataloader
train_dataset_path = train_dataset_config.pop("path")
- self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
+ self.train_dataset = RawConversationDataset(
+ self.tokenizer, train_dataset_path, **train_dataset_config, tokenize=not self.enable_agentic
+ )
self.train_dataloader = DataLoader(
self.train_dataset,
batch_size=microbatch_size,
@@ -161,7 +177,10 @@ def __init__(
for eval_task_name in self.eval_dataset_config:
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
eval_dataset = RawConversationDataset(
- self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
+ self.tokenizer,
+ eval_dataset_path,
+ **eval_dataset_config[eval_task_name],
+ tokenize=not self.enable_agentic,
)
print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
self.eval_dataloaders[eval_task_name] = DataLoader(
@@ -209,18 +228,34 @@ def setup(self) -> None:
else:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+ @torch.no_grad()
+ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+ """
+ Generate responses by running inference on the input_ids and attention_mask.
+ """
+ return self.model.generate(input_ids, attention_mask, **kwargs)
+
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+ """
+ Rollout function to generate responses for the input, for example, using LLM or agentic pipeline.
+ This function should be implemented in subclasses.
+ """
raise NotImplementedError
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
raise NotImplementedError
- def loop(self) -> None:
-
+ def sync_model(self, episode, step) -> None:
+ """
+ Default implementation to sync model from consumer to producer.
+ """
torch.cuda.empty_cache()
self.profiler.enter("sync_model")
if self.consumer_pp_size > 1:
for pp_idx in range(self.consumer_pp_size):
+ print(
+ f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(step + 1) // self.num_microbatches - 1}"
+ )
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
@@ -228,17 +263,28 @@ def loop(self) -> None:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
+ print(f"[P{self.producer_idx}] Sync model episode {episode} step {(step + 1) // self.num_microbatches - 1}")
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
+ print(
+ f"[P{self.producer_idx}] Sync model episode {episode} step {(step + 1) // self.num_microbatches - 1} done"
+ )
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
self.profiler.exit("sync_model")
- print(f"[P{self.producer_idx}] Sync initial model done.")
del state_dict
torch.cuda.empty_cache()
+ def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
+ """
+ Default implementation to sync data from producer to consumer.
+ """
+ ray_broadcast_tensor_dict(data, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}")
+
+ def loop(self) -> None:
+ self.sync_model(0, 0)
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches
@@ -312,9 +358,10 @@ def loop(self) -> None:
self.profiler.enter("rollout")
outputs = self.rollout(**batch)
self.profiler.exit("rollout")
- outputs["temperature"] = torch.tensor(
- [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
- ).to(outputs["input_ids"].device)
+ if "temperature" not in outputs:
+ outputs["temperature"] = torch.tensor(
+ [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
+ ).to(outputs["input_ids"].device)
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
self.profiler.enter("calculate_reward")
if self.grpo_config["reward_fn_type"] == "code":
@@ -359,52 +406,16 @@ def loop(self) -> None:
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
self.profiler.enter("send_broadcast_data")
- ray_broadcast_tensor_dict(
- outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
- )
+ self.sync_data(outputs)
self.profiler.exit("send_broadcast_data")
if (
(i + 1) % self.num_microbatches == 0
and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
):
- if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
- "enable_sleep_mode", False
- ):
- self.model.llm.sleep() # revict KV_cache to avoid OOM
- # don't sync model for last iteration
- torch.cuda.empty_cache()
- self.profiler.enter("sync_model")
- if self.consumer_pp_size > 1:
- for pp_idx in range(self.consumer_pp_size):
- print(
- f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
- )
- state_dict = ray_broadcast_tensor_dict(
- None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
- )
- if "consumer_global_step" in state_dict:
- self.consumer_global_step = state_dict.pop("consumer_global_step").item()
- self.load_state_dict(state_dict)
- else:
- print(
- f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
- )
- state_dict = ray_broadcast_tensor_dict(
- None, self.num_producers, device=self.device, group_name="sync_model"
- )
- if "consumer_global_step" in state_dict:
- self.consumer_global_step = state_dict.pop("consumer_global_step").item()
- self.load_state_dict(state_dict)
- self.profiler.exit("sync_model")
- del state_dict
- torch.cuda.empty_cache()
- if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
- "enable_sleep_mode", False
- ):
- self.model.llm.wake_up()
+ self.sync_model(episode, i)
# linear annealing for 1 episode, temperature from initial to 0.9
- if episode <= 0:
+ if episode <= 0 and hasattr(self, "model"):
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
"temperature"
@@ -473,7 +484,9 @@ def __init__(
enable_profiling=enable_profiling,
n_behind=n_behind,
)
- self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
+ self.model = self.backend_cls(
+ model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size, profiler=self.profiler
+ )
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)
@@ -481,7 +494,7 @@ def __init__(
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
- rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
+ rollouts = self.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 0 and not self.eval_mode:
if (
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
@@ -494,7 +507,8 @@ def rollout(self, input_ids, attention_mask, **kwargs):
"rollout": self.tokenizer.batch_decode(
rollouts["input_ids"][:, 0], skip_special_tokens=True
),
- }
+ },
+ ensure_ascii=False,
)
+ "\n"
)
@@ -511,3 +525,356 @@ def __del__(self):
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
+
+
+class BaseAsyncProducer(BaseProducer):
+ """
+    Asynchronous version of the producer that uses vLLM for generation.
+ """
+
+ def __init__(
+ self,
+ producer_idx,
+ num_producers,
+ num_consumer_procs,
+ num_episodes,
+ batch_size,
+ train_dataset_config,
+ model_config,
+ generate_config,
+ tokenizer_config=None,
+ microbatch_size=1,
+ backend="transformers",
+ num_generations: int = 8,
+ consumer_plugin_config=None,
+ eval_dataset_config=None,
+ eval_interval=-1, # disable evaluation
+ grpo_config: Dict[str, Any] = None,
+ eval_save_dir: str = "./eval",
+ eval_generation_config={},
+ project_name: str = None,
+ run_name: str = None,
+ wandb_group_name: str = None,
+ log_rollout_interval: int = 20,
+ rollout_log_file: str = "./rollout_log.jsonl",
+ enable_profiling: bool = False,
+ n_behind: int = 0,
+ ):
+ super().__init__(
+ producer_idx,
+ num_producers,
+ num_consumer_procs,
+ num_episodes,
+ batch_size,
+ train_dataset_config,
+ model_config,
+ generate_config,
+ tokenizer_config,
+ microbatch_size,
+ backend,
+ consumer_plugin_config,
+ eval_dataset_config=eval_dataset_config,
+ eval_interval=eval_interval,
+ grpo_config=grpo_config,
+ eval_save_dir=eval_save_dir,
+ project_name=project_name,
+ run_name=run_name,
+ wandb_group_name=wandb_group_name,
+ log_rollout_interval=log_rollout_interval,
+ rollout_log_file=rollout_log_file,
+ enable_profiling=enable_profiling,
+ n_behind=n_behind,
+ )
+ assert backend == "async-vllm", f"AsyncProducer only supports async-vllm backend, got {backend}"
+ self.model = self.backend_cls(
+ model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size, profiler=self.profiler
+ )
+ self.eval_generation_config = copy.deepcopy(self.model.generate_config)
+ self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
+ self.eval_generation_config.update(eval_generation_config)
+ self.eval_sample_params = SamplingParams(**self.eval_generation_config)
+ self.ready_processes_sync_model = 0
+ self.ready_processes_sync_data = 0
+ self.sync_model_condition = asyncio.Condition()
+ self.sync_data_condition = asyncio.Condition()
+ self.data_ready_for_sending = []
+
+ @torch.no_grad()
+ async def generate(self, input_ids, attention_mask, **kwargs):
+ # naive rollout strategy
+ tasks = []
+ for prompt_id in range(input_ids.size(0)):
+ new_kwargs = copy.deepcopy(kwargs)
+ if "gt_answer" in new_kwargs:
+ new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id]
+ if "test_cases" in new_kwargs:
+ new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id]
+ tasks.append(
+ self.model.generate(
+ input_ids[prompt_id].unsqueeze(0),
+ attention_mask[prompt_id].unsqueeze(0),
+ **new_kwargs,
+ )
+ )
+ rollouts = await asyncio.gather(*tasks)
+ rollouts = {
+ k: (
+ torch.cat([r[k] for r in rollouts], dim=0).cpu()
+ if k not in ["gt_answer", "test_cases"]
+ else [r[k] for r in rollouts]
+ ) # CUDA tensor is not serializable by ray
+ for k in rollouts[0].keys()
+ }
+ rollouts["consumer_global_step"] = self.consumer_global_step
+ return rollouts
+
+ @torch.no_grad()
+ async def rollout(self, input_ids, attention_mask, **kwargs):
+ """
+ Advanced distributed rollout strategy that dispatches the generation tasks to different DP ranks.
+ Must be implemented in subclasses.
+ """
+ raise NotImplementedError("rollout must be implemented in subclasses")
+
+ async def async_sync_model(self, episode, step, num_processes: int = 1) -> None:
+ """
+        Asynchronous version of the consumer-to-producer model sync.
+        Called by another producer, such as the agentic producer.
+ """
+ async with self.sync_model_condition:
+ self.ready_processes_sync_model += 1
+ # Wait until all processes are ready
+ if self.ready_processes_sync_model < num_processes:
+ await self.sync_model_condition.wait()
+
+ # Only one process should reset `ready_processes_sync_model` and perform the sync
+ if self.ready_processes_sync_model == num_processes:
+ self.ready_processes_sync_model = 0
+ self.sync_model_condition.notify_all() # Notify all waiting processes
+ self.sync_model(episode, step)
+
+ async def async_sync_data(self, data: Dict[str, torch.Tensor], num_processes: int = 1) -> None:
+ # merge data dict
+ async with self.sync_data_condition:
+ self.ready_processes_sync_data += 1
+ if data:
+ self.data_ready_for_sending.append(data)
+
+ # Wait until all processes are ready
+ if self.ready_processes_sync_data < num_processes:
+ await self.sync_data_condition.wait()
+
+            # Only one process should reset `ready_processes_sync_data` and perform the sync
+ if self.ready_processes_sync_data == num_processes: # wait for all producers to join
+ self.ready_processes_sync_data = 0
+ self.sync_data_condition.notify_all()
+ # merge data for sending
+ if len(self.data_ready_for_sending) >= 1:
+ batch_rollout_data = {}
+ for key in self.data_ready_for_sending[0]:
+ batch_rollout_data[key] = torch.cat([d[key] for d in self.data_ready_for_sending], dim=0).to(
+ self.device
+ )
+ self.sync_data(batch_rollout_data)
+ self.data_ready_for_sending = [] # reset
+
+ async def loop(self) -> None:
+ self.sync_model(0, 0)
+ num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
+ num_valid_microbatches = num_update_per_episode * self.num_microbatches
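+        # trailing microbatches that do not form a complete training batch are skipped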
+
+ print(
+ f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
+ )
+ for episode in range(self.num_episodes):
+ self.train_dataloader.sampler.set_epoch(episode)
+ for i, batch in enumerate(self.train_dataloader):
+ if i >= num_valid_microbatches:
+ break
+ if self.eval_interval > 0 and self.eval_dataset_config is not None:
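+                    # evaluate once the consumer has advanced by eval_interval steps since the
+                    # last evaluation, or if no evaluation has run yet (latest_eval_step == -1)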
+ if (
+ self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+ and self.consumer_global_step > self.latest_eval_step
+ ) or self.latest_eval_step == -1:
+ to_log_msg = {}
+ self.eval_mode = True
+ for eval_task_name in self.eval_dataloaders:
+ if self.producer_idx == 0:
+ print(
+ f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
+ )
+ eval_results = []
+ eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
+ for eval_batch in tqdm.tqdm(
+ self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
+ ):
+ eval_outputs = await self.rollout(**eval_batch, sample_params=self.eval_sample_params)
+ eval_results = eval_results + [
+ self.evaluation_function(
+ eval_outputs["input_ids"][m][n],
+ eval_outputs[
+ (
+ "test_cases"
+ if self.grpo_config["reward_fn_type"] == "code"
+ else "gt_answer"
+ )
+ ][m],
+ eval_outputs["response_idx"][m][n],
+ tokenizer=self.tokenizer,
+ eval_mode=True,
+ tags=self.response_format_tags,
+ )
+ for m in range(eval_outputs["input_ids"].size(0))
+ for n in range(eval_outputs["input_ids"].size(1))
+ ]
+ eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results])
+ eval_statistics_tensor[1] += len(eval_results)
+ allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
+ to_log_msg[f"eval/{eval_task_name}"] = (
+ eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
+ )
+ if self.producer_idx == 0:
+ print(
+ f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
+ )
+ # save eval results
+ safe_append_to_jsonl_file(
+ os.path.join(
+ self.eval_save_dir,
+ f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
+ ),
+ eval_results,
+ )
+
+ if self.producer_idx == 0:
+ self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
+ self.eval_mode = False
+ self.latest_eval_step = self.consumer_global_step
+ self.profiler.enter("rollout")
+ outputs = await self.rollout(**batch)
+ outputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}
+ self.profiler.exit("rollout")
+ outputs["temperature"] = torch.tensor(
+ [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
+ ).to(outputs["input_ids"].device)
+ bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
+ self.profiler.enter("calculate_reward")
+ if self.grpo_config["reward_fn_type"] == "code":
+ test_cases = []
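+                    # repeat each prompt's test cases once per generated response so the reward
+                    # model receives one entry per flattened (prompt, generation) pair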
+ for prompt_id in range(bs):
+ test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
+ reward_model_output = self.reward_model(
+ outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
+ test_cases=test_cases,
+ response_idx=outputs["response_idx"].view((-1, 2)),
+ )
+ else:
+ gt_answer = []
+ for prompt_id in range(bs):
+ gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
+ reward_model_output = self.reward_model(
+ outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
+ gt_answer=gt_answer,
+ response_idx=outputs["response_idx"].view((-1, 2)),
+ )
+ outputs["reward"] = (
+ torch.tensor([value[0] for value in reward_model_output])
+ .to(outputs["input_ids"].device)
+ .view((bs, num_gen, 1))
+ )
+ outputs["format_acc"] = (
+ torch.tensor([value[1] for value in reward_model_output])
+ .to(outputs["input_ids"].device)
+ .view((bs, num_gen, 1))
+ )
+ outputs["ans_acc"] = (
+ torch.tensor([value[2] for value in reward_model_output])
+ .to(outputs["input_ids"].device)
+ .view((bs, num_gen, 1))
+ )
+ if "gt_answer" in outputs:
+ outputs.pop("gt_answer")
+ if "test_cases" in outputs:
+ outputs.pop("test_cases")
+ if "consumer_global_step" in outputs:
+ outputs.pop("consumer_global_step")
+ self.profiler.exit("calculate_reward")
+
+ print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
+ outputs = pre_send(outputs)
+ self.profiler.enter("send_broadcast_data")
+ self.sync_data(outputs)
+ self.profiler.exit("send_broadcast_data")
+ if (
+ (i + 1) % self.num_microbatches == 0
+ and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
+ and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
+ ):
+ if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+ "enable_sleep_mode", False
+ ):
+                        self.model.llm.sleep()  # evict the KV cache to avoid OOM
+ # don't sync model for last iteration
+ self.sync_model(episode, i)
+ if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+ "enable_sleep_mode", False
+ ):
+ self.model.llm.wake_up()
+                # linearly anneal the temperature from its initial value to 0.9 over the first episode
+ if episode <= 0:
+ ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
+ self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+ "temperature"
+ ] + ratio * 0.9
+ if isinstance(self.model, BACKEND_MAP["vllm"]):
+ self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
+ "temperature"
+ ] + ratio * 0.9
+
+ def __del__(self):
+ if self.producer_idx == 0:
+ self.wandb_run.finish()
+ if hasattr(self, "rollout_log_file"):
+ self.rollout_log_file.close()
+
+ def load_state_dict(self, state_dict):
+ self.model.load_state_dict(state_dict)
+
+
+@ray.remote
+class AsyncSimpleProducer(BaseAsyncProducer):
+ """
+    Asynchronous version of the producer that uses vLLM for generation.
+ This class is designed to handle multiple producer actors and distribute tasks among them.
+ """
+
+ @torch.no_grad()
+ async def rollout(self, input_ids, attention_mask, **kwargs):
+ # naive rollout strategy without load balancing
+ rollouts = await self.generate(input_ids, attention_mask, **kwargs)
+ if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
+            # in the agentic setup, AsyncSimpleProducer is not the main producer, so it does not log rollouts
+ if (
+ self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
+ or self.latest_rollout_log_step == -1
+ ):
+ new_record = (
+ json.dumps(
+ {
+ "train_step": self.consumer_global_step,
+ "rollout": self.tokenizer.batch_decode(
+ rollouts["input_ids"][:, 0], skip_special_tokens=True
+ ),
+ },
+ ensure_ascii=False,
+ )
+ + "\n"
+ )
+ self.rollout_log_file.write(new_record)
+ self.rollout_log_file.flush()
+ self.latest_rollout_log_step = self.consumer_global_step
+ return rollouts
+
+ async def generate(self, input_ids, attention_mask, **kwargs):
+ rollouts = await super().generate(input_ids, attention_mask, **kwargs)
+ return rollouts
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index f7a2fb89cadb..57217ab94d08 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -19,6 +19,7 @@
import json
+import re
import torch
from latex2sympy2_extended import NormalizationConfig
@@ -126,6 +127,12 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
format_valid = validate_response_structure(processed_str, kwargs["tags"])
+ if "forced_patterns" in kwargs and kwargs["forced_patterns"]:
+ forced_patterns = kwargs["forced_patterns"]
+ format_valid = format_valid and all(
+ [re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns]
+ )
+
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
if final_answer is not None:
if eval_mode or format_valid:
@@ -161,7 +168,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"]
eval_mode = kwargs.get("eval_mode", False)
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
- acc_score = 10.0
+ acc_score = 1.0
reward = torch.tensor(0.0)
format_acc = torch.tensor(0.0)
ans_acc = torch.tensor(0.0)
@@ -182,7 +189,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
raise ValueError("no gt_answer is provided, please check your training dataset.")
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
-
final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None
if "tags" in kwargs and kwargs["tags"]:
@@ -190,7 +196,11 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
format_valid = format_valid and all(
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
)
-
+ if "forced_patterns" in kwargs and kwargs["forced_patterns"]:
+ forced_patterns = kwargs["forced_patterns"]
+ format_valid = format_valid and all(
+ [re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns]
+ )
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
if final_answer is not None:
if eval_mode or format_valid:
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 466914cc0d4d..48b823cb0937 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -1,7 +1,9 @@
import json
import os
+import random
from typing import Any, Dict, List
+import ray
import torch
from filelock import FileLock
@@ -165,3 +167,25 @@ def safe_append_to_jsonl_file(file_path, data):
for entry in data:
json_line = json.dumps(entry, ensure_ascii=False)
f.write(json_line + "\n")
+
+
+@ray.remote
+class LoadBalancer:
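+    """
+    Minimal Ray actor that tracks an in-memory load counter per worker. `get_next_worker`
+    returns the least-loaded worker of the requested type (ties broken at random) and
+    increments its load; callers release capacity via `decrease_load`.
+    """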
+ def __init__(self, worker_counts):
+ self.load = {}
+        for worker_type in worker_counts:
+            self.load[worker_type] = {k: 0 for k in range(worker_counts[worker_type])}
+
+ def get_next_worker(self, worker_type, amount=1):
+ loads = [(k, v) for k, v in self.load[worker_type].items()]
+ min_load = min(loads, key=lambda x: x[1])
+ candidates = [k for k, v in loads if v == min_load[1]]
+ chosen = random.choice(candidates)
+ self.load[worker_type][chosen] += amount
+ return chosen, self.load[worker_type]
+
+ def increase_load(self, worker_type, worker_id, amount=1):
+ self.load[worker_type][worker_id] += amount
+
+ def decrease_load(self, worker_type, worker_id, amount=1):
+ self.load[worker_type][worker_id] -= amount
diff --git a/applications/ColossalChat/conversation_template/qwen3.json b/applications/ColossalChat/conversation_template/qwen3.json
new file mode 100644
index 000000000000..7c713d2b173c
--- /dev/null
+++ b/applications/ColossalChat/conversation_template/qwen3.json
@@ -0,0 +1,8 @@
+{
+ "chat_template": "{%- if tools %}\\n {{- \'<|im_start|>system\\\\n\' }}\\n {%- if messages[0].role == \'system\' %}\\n {{- messages[0].content + \'\\\\n\\\\n\' }}\\n {%- endif %}\\n {{- \\"# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within XML tags:\\\\n\\" }}\\n {%- for tool in tools %}\\n {{- \\"\\\\n\\" }}\\n {{- tool | tojson }}\\n {%- endfor %}\\n {{- \\"\\\\n\\\\n\\\\nFor each function call, return a json object with function name and arguments within XML tags:\\\\n\\\\n{\\\\\\"name\\\\\\": , \\\\\\"arguments\\\\\\": }\\\\n<|im_end|>\\\\n\\" }}\\n{%- else %}\\n {%- if messages[0].role == \'system\' %}\\n {{- \'<|im_start|>system\\\\n\' + messages[0].content + \'<|im_end|>\\\\n\' }}\\n {%- endif %}\\n{%- endif %}\\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\\n{%- for message in messages[::-1] %}\\n {%- set index = (messages|length - 1) - loop.index0 %}\\n {%- if ns.multi_step_tool and message.role == \\"user\\" and message.content is string and not(message.content.startswith(\'\') and message.content.endswith(\'\')) %}\\n {%- set ns.multi_step_tool = false %}\\n {%- set ns.last_query_index = index %}\\n {%- endif %}\\n{%- endfor %}\\n{%- for message in messages %}\\n {%- if message.content is string %}\\n {%- set content = message.content %}\\n {%- else %}\\n {%- set content = \'\' %}\\n {%- endif %}\\n {%- if (message.role == \\"user\\") or (message.role == \\"system\\" and not loop.first) %}\\n {{- \'<|im_start|>\' + message.role + \'\\\\n\' + content + \'<|im_end|>\' + \'\\\\n\' }}\\n {%- elif message.role == \\"assistant\\" %}\\n {{- \'<|im_start|>\' + message.role + \'\\\\n\' + content }}\\n {%- if message.tool_calls %}\\n {%- for tool_call in message.tool_calls %}\\n {%- if (loop.first and content) or (not loop.first) %}\\n {{- \'\\\\n\' }}\\n {%- endif %}\\n {%- if tool_call.function %}\\n {%- set tool_call = tool_call.function %}\\n {%- endif %}\\n {{- \'\\\\n{\\"name\\": \\"\' }}\\n {{- tool_call.name }}\\n {{- \'\\", \\"arguments\\": \' }}\\n {%- if tool_call.arguments is string %}\\n {{- tool_call.arguments }}\\n {%- else %}\\n {{- tool_call.arguments | tojson }}\\n {%- endif %}\\n {{- \'}\\\\n\' }}\\n {%- endfor %}\\n {%- endif %}\\n {{- \'<|im_end|>\\\\n\' }}\\n {%- elif message.role == \\"tool\\" %}\\n {%- if loop.first or (messages[loop.index0 - 1].role != \\"tool\\") %}\\n {{- \'<|im_start|>user\' }}\\n {%- endif %}\\n {{- \'\\\\n\\\\n\' }}\\n {{- content }}\\n {{- \'\\\\n\' }}\\n {%- if loop.last or (messages[loop.index0 + 1].role != \\"tool\\") %}\\n {{- \'<|im_end|>\\\\n\' }}\\n {%- endif %}\\n {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n {{- \'<|im_start|>assistant\\\\n\' }}\\n {%- if enable_thinking is defined and enable_thinking is false %}\\n {{- \'\\\\n\\\\n\\\\n\\\\n\' }}\\n {%- endif %}\\n{%- endif %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 7
+ ],
+ "end_of_assistant": "<|im_end|>"
+}
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index b584b940ccaa..5c798bdc2b0c 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -109,7 +109,13 @@
)
# Sampling parameters
- parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
+ parser.add_argument(
+ "-b",
+ "--backend",
+ type=str,
+ default="transformers",
+ choices=["transformers", "vllm", "async-vllm", "async-agentic-vllm"],
+ )
parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
parser.add_argument(
"-topk",
@@ -125,6 +131,7 @@
default=1.0,
help="Top p for sampling. Please check the generation arguments documentation for your backend.",
)
+ parser.add_argument("-ct", "--chat-template", type=str, default=None, help="Chat template to use for the model.")
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
@@ -135,6 +142,13 @@
default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
)
+ parser.add_argument(
+ "-pdp",
+ "--producer-data-parallel-size",
+ type=int,
+ default=1,
+ help="Data parallel size for the producer. Increase this value to scale up the data parallelism of the inference backend.",
+ )
# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
@@ -148,6 +162,13 @@
choices=["think_answer_tags", "boxed", "code"],
help="Reward type for GRPO.",
)
+ parser.add_argument(
+ "--agentic-type",
+ type=str,
+ default="Agentic",
+ choices=["Agentic"],
+ help="Agentic model type for agentic training.",
+ )
parser.add_argument(
"-cv",
"--code-verifier-api-url",
@@ -207,7 +228,7 @@
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
- "TOKENIZERS_PARALLELISM": "false"
+ "TOKENIZERS_PARALLELISM": "false",
},
},
)
@@ -220,7 +241,7 @@
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
- "TOKENIZERS_PARALLELISM": "false"
+ "TOKENIZERS_PARALLELISM": "false",
},
},
)
@@ -228,7 +249,7 @@
if args.top_k is None:
if args.backend == "transformers":
args.top_k = 50
- elif args.backend == "vllm":
+ elif "vllm" in args.backend:
args.top_k = -1
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
@@ -256,11 +277,14 @@
)
)
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
- elif args.backend == "vllm":
+ elif args.backend == "vllm" or args.backend == "async-vllm" or args.backend == "async-agentic-vllm":
+ # os.environ["VLLM_DP_SIZE"] = str(args.producer_data_parallel_size)
inference_model_config.update(
dict(
- gpu_memory_utilization=0.7,
- enforce_eager=True,
+ gpu_memory_utilization=0.8,
+ max_num_batched_tokens=4096,
+ max_num_seqs=1024,
+ enforce_eager=False,
enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=args.producer_tensor_parallel_size,
@@ -394,9 +418,34 @@
# Default system prompt
args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]
+ if "agentic" in args.backend:
+ assert "vllm" in args.backend, "Agentic backend only supports async-agentic-vllm backends."
+ generate_config["n"] = 1 # agentic producer use AsyncProducer which processes one request a time
+ if args.agentic_type == "Agentic":
+ generate_config["stop"] = ["<|im_end|>"]
+ generate_config["prompt_logprobs"] = 0
+ agentic_config = {
+ "agentic_producer": "Agentic",
+ "tool_call_budget": 5,
+ "llm_call_budget": 10,
+ "max_tokens": 2048,
+ }
+ grpo_config["forced_patterns"] = [
+ r"\n.+\n" # please modify based on your tool response format
+ ] # force at least one correct tool call
+ else:
+ raise ValueError(f"Unsupported agentic model type: {args.agentic_type}")
+ else:
+ agentic_config = None
+
+ tokenizer_config = {"path": args.model, "trust_remote_code": True}
+ if args.chat_template is not None:
+ tokenizer_config["chat_template"] = args.chat_template
+
launch_distributed(
num_producers=args.num_inferencer,
- num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
+ num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size)
+ * inference_model_config.get("data_parallel_size", args.producer_data_parallel_size),
num_consumer_procs=args.num_trainers,
num_episodes=args.num_episodes,
inference_batch_size=args.inference_batch_size,
@@ -413,6 +462,8 @@
num_generations=args.num_generations,
train_model_config=train_model_config,
grpo_config=grpo_config,
+ agentic_config=agentic_config,
+ tokenizer_config=tokenizer_config,
plugin_config={
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
@@ -440,7 +491,7 @@
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,
- log_rollout_interval=20,
+ log_rollout_interval=1,
rollout_save_dir=args.rollout_save_dir,
enable_profiling=args.enable_profiling,
n_behind=args.n_behind,