From ae51e5b2440c98ddaab9b7bc0960762cdd1d6838 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 18 Aug 2025 17:40:55 +0800 Subject: [PATCH 01/11] support asyncllm --- .../coati/distributed/agent/model.py | 149 ++++++++++++ .../coati/distributed/agent/tools.py | 112 +++++++++ .../coati/distributed/inference_backend.py | 145 +++++++++++ .../ColossalChat/coati/distributed/launch.py | 9 +- .../coati/distributed/producer.py | 225 +++++++++++++++++- .../coati/distributed/reward/reward_fn.py | 2 +- applications/ColossalChat/rl_example.py | 27 ++- 7 files changed, 656 insertions(+), 13 deletions(-) create mode 100644 applications/ColossalChat/coati/distributed/agent/model.py create mode 100644 applications/ColossalChat/coati/distributed/agent/tools.py diff --git a/applications/ColossalChat/coati/distributed/agent/model.py b/applications/ColossalChat/coati/distributed/agent/model.py new file mode 100644 index 000000000000..c52d0d99df12 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/model.py @@ -0,0 +1,149 @@ +""" +MIT License + +Copyright (c) 2025 LangChain + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from typing import Any, Dict, Iterator, List, Optional + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from pydantic import Field + + +class LangChainChatModel(BaseChatModel): + """A custom chat model that echoes the first `parrot_buffer_length` characters + of the input. + + When contributing an implementation to LangChain, carefully document + the model including the initialization parameters, include + an example of how to initialize the model and include any relevant + links to the underlying models documentation or API. + + Example: + + .. code-block:: python + + model = LangChainChatModel(parrot_buffer_length=2, model="bird-brain-001") + result = model.invoke([HumanMessage(content="hello")]) + result = model.batch([[HumanMessage(content="hello")], + [HumanMessage(content="world")]]) + """ + + model_name: str = Field(alias="model") + temperature: Optional[float] = None + max_tokens: Optional[int] = None + timeout: Optional[int] = None + stop: Optional[List[str]] = None + async_server_manager: Optional[Any] = None + max_retries: int = 2 + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Override the _generate method to implement the chat model logic. + + This can be a call to an API, a call to a local model, or any other + implementation that generates a response to the input prompt. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + run_manager: A run manager with callbacks for the LLM. + """ + self.async_server_manager.generate(messages, stop, run_manager, **kwargs) + tokens = last_message.content[: self.parrot_buffer_length] + ct_input_tokens = sum(len(message.content) for message in messages) + ct_output_tokens = len(tokens) + message = AIMessage( + content=tokens, + additional_kwargs={}, # Used to add additional payload to the message + response_metadata={ # Use for response metadata + "time_in_seconds": 3, + "model_name": self.model_name, + }, + usage_metadata={ + "input_tokens": ct_input_tokens, + "output_tokens": ct_output_tokens, + "total_tokens": ct_input_tokens + ct_output_tokens, + }, + ) + ## + + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """Stream the output of the model. + + This method should be implemented if the model can generate output + in a streaming fashion. If the model does not support streaming, + do not implement it. In that case streaming requests will be automatically + handled by the _generate method. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + run_manager: A run manager with callbacks for the LLM. + """ + raise NotImplementedError("Streaming is not implemented for this model. Please implement the _stream method.") + + @property + def _llm_type(self) -> str: + """Get the type of language model used by this chat model.""" + return "echoing-chat-model-advanced" + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Return a dictionary of identifying parameters. + + This information is used by the LangChain callback system, which + is used for tracing purposes make it possible to monitor LLMs. + """ + return { + # The model name allows users to specify custom token counting + # rules in LLM monitoring applications (e.g., in LangSmith users + # can provide per token pricing for their model and monitor + # costs for the given LLM.) + "model_name": self.model_name, + } diff --git a/applications/ColossalChat/coati/distributed/agent/tools.py b/applications/ColossalChat/coati/distributed/agent/tools.py new file mode 100644 index 000000000000..a39a32f335f6 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/tools.py @@ -0,0 +1,112 @@ +""" +MIT License + +Copyright (c) 2025 LangChain + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import builtins +import contextlib +import io +import math +from typing import Any + + +def eval(code: str, _locals: dict[str, Any]) -> tuple[str, dict[str, Any]]: + # Store original keys before execution + original_keys = set(_locals.keys()) + + try: + with contextlib.redirect_stdout(io.StringIO()) as f: + exec(code, builtins.__dict__, _locals) + result = f.getvalue() + if not result: + result = "" + except Exception as e: + result = f"Error during execution: {repr(e)}" + + # Determine new variables created during execution + new_keys = set(_locals.keys()) - original_keys + new_vars = {key: _locals[key] for key in new_keys} + return result, new_vars + + +def add(a: float, b: float) -> float: + """Add two numbers together.""" + return a + b + + +def multiply(a: float, b: float) -> float: + """Multiply two numbers together.""" + return a * b + + +def divide(a: float, b: float) -> float: + """Divide two numbers.""" + return a / b + + +def subtract(a: float, b: float) -> float: + """Subtract two numbers.""" + return a - b + + +def sin(a: float) -> float: + """Take the sine of a number.""" + return math.sin(a) + + +def cos(a: float) -> float: + """Take the cosine of a number.""" + return math.cos(a) + + +def radians(a: float) -> float: + """Convert degrees to radians.""" + return math.radians(a) + + +def exponentiation(a: float, b: float) -> float: + """Raise one number to the power of another.""" + return a**b + + +def sqrt(a: float) -> float: + """Take the square root of a number.""" + return math.sqrt(a) + + +def ceil(a: float) -> float: + """Round a number up to the nearest integer.""" + return math.ceil(a) + + +tools = [ + add, + multiply, + divide, + subtract, + sin, + cos, + radians, + exponentiation, + sqrt, + ceil, +] diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 34827e4e2cf9..dbf8b94b0673 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -1,8 +1,11 @@ +import asyncio from typing import Any, Dict +from uuid import uuid4 import torch import torch.nn.functional as F from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer +from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams from colossalai.utils import get_current_device @@ -43,6 +46,27 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: pass +class AsyncInferenceBackend(BaseInferenceBackend): + async def generate( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs + ) -> Dict[str, torch.Tensor]: + """Generate new tokens given input_ids and attention_mask. + + Args: + input_ids (torch.Tensor): shape [B, S] + attention_mask (torch.Tensor): shape [B, S] + + Returns: + Dict[str, torch.Tensor]: containing the + - input_ids (torch.Tensor): shape [B, S+N] + - attention_mask (torch.Tensor): shape [B, S+N] + - action_log_probs (torch.Tensor): shape [B, N] + - action_mask (torch.Tensor): shape [B, N] + where N is the number of generated tokens. And all tensors should be on CUDA. + """ + raise NotImplementedError("AsyncInferenceBackend does not support generate method.") + + class TransformersInferenceBackend(BaseInferenceBackend): DEFAULT_MODEL_CONFIG = dict( trust_remote_code=True, @@ -59,6 +83,7 @@ def __init__( generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, num_generations: int = 8, + microbatch_size: int = 1, ): model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) model_config.update(self.FORCE_MODEL_CONFIG) @@ -132,6 +157,7 @@ def __init__( generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, num_generations: int = 8, + microbatch_size: int = 1, ): if sgl is None: raise ImportError("sglang is not installed") @@ -196,6 +222,7 @@ def __init__( generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, num_generations: int = 8, + microbatch_size: int = 1, ): if LLM is None: raise ImportError("vllm is not installed") @@ -280,8 +307,126 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: self.llm.llm_engine.model_executor.driver_worker.model_runner.model.load_weights(state_dict.items()) +class AsyncVLLMInferenceBackend(AsyncInferenceBackend): + DEFAULT_MODEL_CONFIG = dict( + trust_remote_code=True, + enable_sleep_mode=False, + ) + FORCE_GENERATE_CONFIG = dict( + logprobs=0, + ) + + def __init__( + self, + model_config: Dict[str, Any], + generate_config: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + num_generations: int = 8, + microbatch_size: int = 1, + ): + if LLM is None: + raise ImportError("vllm is not installed") + model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) + path = model_config.pop("path") + engine_args = AsyncEngineArgs(model=path, disable_log_stats=True, **model_config) + self.engine = AsyncLLMEngine.from_engine_args(engine_args) + generate_config = generate_config.copy() + generate_config.update(self.FORCE_GENERATE_CONFIG) + generate_config.update({"n": num_generations}) + self.generate_config = generate_config + self.sample_params = SamplingParams(**generate_config) + self.model_config = model_config + self.tokenizer = tokenizer + self.num_generations = num_generations + self.queued_requests = [] + self.microbatch_size = microbatch_size + + @torch.no_grad() + async def generate( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs + ) -> Dict[str, torch.Tensor]: + """ + Generate text from the model asynchronously. + Args: + input_ids (torch.Tensor): shape [B, S], B=1 + attention_mask (torch.Tensor): shape [B, S] + """ + assert input_ids.size(0) == attention_mask.size(0) == 1 + response_start_idx = input_ids.size(1) + first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1) + input_ids_no_padding = [input_ids.tolist()[0][first_non_padding_token_idx[0] :]] + sample_params = kwargs.get("sample_params", self.sample_params) + out_tokens = [] + out_len = [] + log_probs = [] + response_idx = [] + while len(self.queued_requests) >= self.microbatch_size: + await asyncio.sleep(0.1) + request_id = str(uuid4()) + self.queued_requests.append(request_id) # enqueue + # pop the first input_ids and attention_mask + prompt_token_ids = input_ids_no_padding[0] + outputs = self.engine.generate( + prompt={"prompt_token_ids": prompt_token_ids}, sampling_params=sample_params, request_id=request_id + ) + async for chunk in outputs: + # generate the output tokens, can yield to avoid blocking + pass + self.queued_requests.remove(request_id) # dequeue + for output_i in chunk.outputs: + out_len.append(len(output_i.token_ids)) + out_tokens.append(list(output_i.token_ids)) + response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1)) + assert len(output_i.logprobs) == len(output_i.token_ids) + p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)] + log_probs.append(p) + # pad them + max_len = self.sample_params.max_tokens + action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype) + + for i, new_token_ids in enumerate(out_tokens): + pad_len = max_len - out_len[i] + out_tokens[i] = new_token_ids + [self.tokenizer.pad_token_id] * pad_len + log_probs[i] = log_probs[i] + [0.0] * pad_len + action_mask[i, out_len[i] :] = 0 + + out_tokens = torch.tensor(out_tokens) + log_probs = torch.tensor(log_probs) + response_idx = torch.tensor(response_idx) + + if attention_mask.size(0) != action_mask.size(0): + assert action_mask.size(0) % attention_mask.size(0) == 0 + num_returns = action_mask.size(0) // attention_mask.size(0) + attention_mask = attention_mask.repeat_interleave(num_returns, dim=0) + input_ids = input_ids.repeat_interleave(num_returns, dim=0) + + out_tokens = torch.cat((input_ids, out_tokens), dim=1) + attention_mask = torch.cat((attention_mask, action_mask), dim=1) + + data = { + "input_ids": out_tokens, + "attention_mask": attention_mask, + "action_log_probs": log_probs, + "action_mask": action_mask, + "response_idx": response_idx, + } + + data = {k: v.view(1, -1, v.size(-1)) for k, v in data.items()} + data = {k: v.to(get_current_device()) for k, v in data.items()} + if "gt_answer" in kwargs: + data["gt_answer"] = kwargs["gt_answer"] + if "test_cases" in kwargs: + data["test_cases"] = kwargs["test_cases"] + + return data + + def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: + self.engine.engine.model_executor.driver_worker.model_runner.model.load_weights(state_dict.items()) + + BACKEND_MAP = { "transformers": TransformersInferenceBackend, # "sglang": SGLangInferenceBackend, # sglang backend will stuck the process due to unknown reason "vllm": VLLMInferenceBackend, + "async-vllm": AsyncVLLMInferenceBackend, } diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index d60312e2b0b1..8795af51f31a 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -7,7 +7,7 @@ from .consumer import SimpleConsumer from .grpo_consumer import GRPOConsumer -from .producer import SimpleProducer +from .producer import AsyncProducer, SimpleProducer ALGO_MAP = { "Simple": SimpleConsumer, @@ -16,6 +16,7 @@ "REINFORCE_PPB": GRPOConsumer, "RLOO": GRPOConsumer, } +Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncProducer} def get_jsonl_size_fast(path: str) -> int: @@ -110,6 +111,10 @@ def launch_distributed( print(node_info) producer_procs = [] + if "async" in inference_backend: + core_producer = AsyncProducer + else: + core_producer = Producer_MAP.get("Simple", SimpleProducer) for i in range(num_producers): node_id = gpu_to_node_id[0] producer_ip_address = gpu_to_ip_address[0] @@ -117,7 +122,7 @@ def launch_distributed( gpu_to_node_id.pop(0) gpu_to_ip_address.pop(0) print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}") - producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote( + producer = core_producer.options(num_gpus=num_proc_per_producer).remote( producer_idx=i, num_producers=num_producers, num_consumer_procs=num_consumer_procs, diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 38a85b9b1c4d..7d3bbaec2c27 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -1,3 +1,4 @@ +import asyncio import copy import json import os @@ -310,7 +311,10 @@ def loop(self) -> None: self.eval_mode = False self.latest_eval_step = self.consumer_global_step self.profiler.enter("rollout") - outputs = self.rollout(**batch) + if isinstance(self.model, BACKEND_MAP["async-vllm"]): + outputs = asyncio.run(self.rollout(**batch)) + else: + outputs = self.rollout(**batch) self.profiler.exit("rollout") outputs["temperature"] = torch.tensor( [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) @@ -473,7 +477,9 @@ def __init__( enable_profiling=enable_profiling, n_behind=n_behind, ) - self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) + self.model = self.backend_cls( + model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size + ) self.eval_generation_config = copy.deepcopy(self.model.generate_config) self.eval_generation_config["n"] = 1 # use 1 generation for evaluation self.eval_generation_config.update(eval_generation_config) @@ -511,3 +517,218 @@ def __del__(self): def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict) + + +@ray.remote +class AsyncProducer(BaseProducer): + """ + Asyncronous version of the producer that uses vLLM for generation. + """ + + def __init__( + self, + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + tokenizer_config=None, + microbatch_size=1, + backend="transformers", + num_generations: int = 8, + consumer_plugin_config=None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + grpo_config: Dict[str, Any] = None, + eval_save_dir: str = "./eval", + eval_generation_config={}, + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", + enable_profiling: bool = False, + n_behind: int = 0, + ): + super().__init__( + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + tokenizer_config, + microbatch_size, + backend, + consumer_plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + grpo_config=grpo_config, + eval_save_dir=eval_save_dir, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, + enable_profiling=enable_profiling, + n_behind=n_behind, + ) + assert backend == "async-vllm", f"AsyncProducer only supports async-vllm backend, got {backend}" + self.model = self.backend_cls( + model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size + ) + self.eval_generation_config = copy.deepcopy(self.model.generate_config) + self.eval_generation_config["n"] = 1 # use 1 generation for evaluation + self.eval_generation_config.update(eval_generation_config) + self.eval_sample_params = SamplingParams(**self.eval_generation_config) + + @torch.no_grad() + async def rollout(self, input_ids, attention_mask, **kwargs): + tasks = [] + for prompt_id in range(input_ids.size(0)): + new_kwargs = copy.deepcopy(kwargs) + if "gt_answer" in new_kwargs: + new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id] + if "test_cases" in new_kwargs: + new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id] + tasks.append( + self.model.generate( + input_ids[prompt_id].unsqueeze(0), + attention_mask[prompt_id].unsqueeze(0), + **new_kwargs, + ) + ) + # print(f"Producer {self.producer_idx} running {len(tasks)} tasks") + rollouts = await asyncio.gather(*tasks) + rollouts = { + k: ( + torch.cat([r[k] for r in rollouts], dim=0) + if k not in ["gt_answer", "test_cases"] + else [r[k] for r in rollouts] + ) + for k in rollouts[0].keys() + } + if self.producer_idx == 0 and not self.eval_mode: + if ( + self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval + or self.latest_rollout_log_step == -1 + ): + new_record = ( + json.dumps( + { + "train_step": self.consumer_global_step, + "rollout": self.tokenizer.batch_decode( + rollouts["input_ids"][:, 0], skip_special_tokens=True + ), + } + ) + + "\n" + ) + self.rollout_log_file.write(new_record) + self.rollout_log_file.flush() + self.latest_rollout_log_step = self.consumer_global_step + return rollouts + + def __del__(self): + if self.producer_idx == 0: + self.wandb_run.finish() + if hasattr(self, "rollout_log_file"): + self.rollout_log_file.close() + + def load_state_dict(self, state_dict): + self.model.load_state_dict(state_dict) + + +@ray.remote +class AsyncServer: + """ + A async worker for inference only + """ + + def __init__( + self, + producer_idx, + num_producers, + model_config, + generate_config, + tokenizer_config=None, + microbatch_size=1, + backend="transformers", + num_generations: int = 8, + eval_generation_config={}, + ): + tokenizer_path = model_config["path"] + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config) + self.tokenizer.padding_side = "left" + self.microbatch_size = microbatch_size + self.producer_idx = producer_idx + self.num_producers = num_producers + assert backend == "async-vllm", f"AsyncProducer only supports async-vllm backend, got {backend}" + self.model = self.backend_cls( + model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size + ) + self.eval_generation_config = copy.deepcopy(self.model.generate_config) + self.eval_generation_config["n"] = 1 # use 1 generation for evaluation + self.eval_generation_config.update(eval_generation_config) + self.eval_sample_params = SamplingParams(**self.eval_generation_config) + + @torch.no_grad() + async def rollout(self, input_ids, attention_mask, **kwargs): + tasks = [] + for prompt_id in range(input_ids.size(0)): + new_kwargs = copy.deepcopy(kwargs) + if "gt_answer" in new_kwargs: + new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id] + if "test_cases" in new_kwargs: + new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id] + tasks.append( + self.model.generate( + input_ids[prompt_id].unsqueeze(0), + attention_mask[prompt_id].unsqueeze(0), + **new_kwargs, + ) + ) + # print(f"Producer {self.producer_idx} running {len(tasks)} tasks") + rollouts = await asyncio.gather(*tasks) + rollouts = { + k: ( + torch.cat([r[k] for r in rollouts], dim=0) + if k not in ["gt_answer", "test_cases"] + else [r[k] for r in rollouts] + ) + for k in rollouts[0].keys() + } + if self.producer_idx == 0 and not self.eval_mode: + if ( + self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval + or self.latest_rollout_log_step == -1 + ): + new_record = ( + json.dumps( + { + "train_step": self.consumer_global_step, + "rollout": self.tokenizer.batch_decode( + rollouts["input_ids"][:, 0], skip_special_tokens=True + ), + } + ) + + "\n" + ) + self.rollout_log_file.write(new_record) + self.rollout_log_file.flush() + self.latest_rollout_log_step = self.consumer_global_step + return rollouts + + def __del__(self): + if self.producer_idx == 0: + self.wandb_run.finish() + if hasattr(self, "rollout_log_file"): + self.rollout_log_file.close() + + def load_state_dict(self, state_dict): + self.model.load_state_dict(state_dict) diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index f7a2fb89cadb..9aa39788f8b3 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -182,7 +182,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): raise ValueError("no gt_answer is provided, please check your training dataset.") decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) - + # print(f"decoded_final_answer: {decoded_final_answer[-100:]}", gt_answer) final_answer = extract_boxed_solution(decoded_final_answer) format_valid = final_answer is not None if "tags" in kwargs and kwargs["tags"]: diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index b584b940ccaa..54ef4e303771 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -109,7 +109,9 @@ ) # Sampling parameters - parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) + parser.add_argument( + "-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm", "async-vllm"] + ) parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.") parser.add_argument( "-topk", @@ -135,6 +137,13 @@ default=1, help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.", ) + parser.add_argument( + "-pdp", + "--producer-data-parallel-size", + type=int, + default=1, + help="Data parallel size for the producer. Increase this value to scale up the data parallelism of the inference backend.", + ) # GRPO parameters parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"]) @@ -206,8 +215,8 @@ namespace="ray-example", runtime_env={ "env_vars": { - # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray - "TOKENIZERS_PARALLELISM": "false" + "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray + "TOKENIZERS_PARALLELISM": "false", }, }, ) @@ -219,8 +228,8 @@ _temp_dir=args.ray_dir, runtime_env={ "env_vars": { - # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray - "TOKENIZERS_PARALLELISM": "false" + "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray + "TOKENIZERS_PARALLELISM": "false", }, }, ) @@ -228,7 +237,7 @@ if args.top_k is None: if args.backend == "transformers": args.top_k = 50 - elif args.backend == "vllm": + elif args.backend == "vllm" or args.backend == "async-vllm": args.top_k = -1 os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock @@ -256,7 +265,8 @@ ) ) eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation - elif args.backend == "vllm": + elif args.backend == "vllm" or args.backend == "async-vllm": + # os.environ["VLLM_DP_SIZE"] = str(args.producer_data_parallel_size) inference_model_config.update( dict( gpu_memory_utilization=0.7, @@ -396,7 +406,8 @@ launch_distributed( num_producers=args.num_inferencer, - num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size), + num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size) + * inference_model_config.get("data_parallel_size", args.producer_data_parallel_size), num_consumer_procs=args.num_trainers, num_episodes=args.num_episodes, inference_batch_size=args.inference_batch_size, From f3155409b5bc1c2df6081a46402156ddd3cb5de9 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 3 Sep 2025 15:12:46 +0800 Subject: [PATCH 02/11] support agentic with asyncllm --- .../ColossalChat/coati/dataset/loader.py | 63 ++- .../coati/distributed/agent/=0.3, | 0 .../coati/distributed/agent/agentic.py | 199 ++++++++ .../distributed/agent/agentic_math_utils.py | 170 +++++++ .../coati/distributed/agent/model.py | 149 ------ .../distributed/agent/test_api_based_agent.py | 126 +++++ .../coati/distributed/agent/tools.py | 112 ----- .../coati/distributed/consumer.py | 4 +- .../coati/distributed/inference_backend.py | 14 +- .../ColossalChat/coati/distributed/launch.py | 74 ++- .../coati/distributed/producer.py | 440 ++++++++++++------ applications/ColossalChat/rl_example.py | 34 +- 12 files changed, 947 insertions(+), 438 deletions(-) create mode 100644 applications/ColossalChat/coati/distributed/agent/=0.3, create mode 100644 applications/ColossalChat/coati/distributed/agent/agentic.py create mode 100644 applications/ColossalChat/coati/distributed/agent/agentic_math_utils.py delete mode 100644 applications/ColossalChat/coati/distributed/agent/model.py create mode 100644 applications/ColossalChat/coati/distributed/agent/test_api_based_agent.py delete mode 100644 applications/ColossalChat/coati/distributed/agent/tools.py diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index 16fd385bad30..1283b9be0e36 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -4,6 +4,7 @@ Dataloader for sft, dpo, ppo """ +import copy import os from dataclasses import dataclass from typing import Dict, Iterator, List, Optional, Sequence, Union @@ -423,7 +424,9 @@ class RawConversationDataset(Dataset): Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`. """ - def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None: + def __init__( + self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str, tokenize=True + ) -> None: self.tokenizer = tokenizer self.raw_texts = [] with jsonlines.open(input_file) as f: @@ -432,30 +435,50 @@ def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: self.tokenized_texts = [None] * len(self.raw_texts) self.max_length = max_length self.system_prompt = system_prompt + self.tokenize = tokenize def __len__(self) -> int: return len(self.raw_texts) def __getitem__(self, index: int): - if self.tokenized_texts[index] is None: - message = self.raw_texts[index] - tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt) - self.tokenized_texts[index] = dict(tokens) - return self.tokenized_texts[index] + if self.tokenize: + if self.tokenized_texts[index] is None: + message = self.raw_texts[index] + tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt) + self.tokenized_texts[index] = dict(tokens) + return self.tokenized_texts[index] + else: + chat = copy.deepcopy(self.raw_texts[index]) + chat["messages"] = [{"role": "system", "content": self.system_prompt}, chat["messages"]] + return chat def collate_fn_grpo(batch): - input_ids = [item["input_ids"] for item in batch] - attention_mask = [item["attention_mask"] for item in batch] - labels = [item["labels"] for item in batch] - # Assume input_ids, attention_mask, labels are already of the same length, - # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) - input_ids = torch.stack(input_ids) - attention_mask = torch.stack(attention_mask) - labels = torch.stack(labels) - ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} - if "test_cases" in batch[0]: - ret["test_cases"] = [item["test_cases"] for item in batch] - if "gt_answer" in batch[0]: - ret["gt_answer"] = [item["gt_answer"] for item in batch] - return ret + if "input_ids" in batch[0]: + # tokenized format + input_ids = [item["input_ids"] for item in batch] + attention_mask = [item["attention_mask"] for item in batch] + labels = [item["labels"] for item in batch] + # Assume input_ids, attention_mask, labels are already of the same length, + # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + input_ids = torch.stack(input_ids) + attention_mask = torch.stack(attention_mask) + labels = torch.stack(labels) + ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} + if "test_cases" in batch[0]: + ret["test_cases"] = [item["test_cases"] for item in batch] + if "gt_answer" in batch[0]: + ret["gt_answer"] = [item["gt_answer"] for item in batch] + return ret + elif "messages" in batch[0]: + # vllm format + ret = { + "messages": [item["messages"] for item in batch], + } + if "test_cases" in batch[0]: + ret["test_cases"] = [item["test_cases"] for item in batch] + if "gt_answer" in batch[0]: + ret["gt_answer"] = [item["gt_answer"] for item in batch] + return ret + else: + raise ValueError("Unsupported batch format") diff --git a/applications/ColossalChat/coati/distributed/agent/=0.3, b/applications/ColossalChat/coati/distributed/agent/=0.3, new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalChat/coati/distributed/agent/agentic.py b/applications/ColossalChat/coati/distributed/agent/agentic.py new file mode 100644 index 000000000000..f348eb69d946 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/agentic.py @@ -0,0 +1,199 @@ +import copy +import json +from typing import Any, Dict + +import ray +import torch +from coati.distributed.agent.agentic_math_utils import TIR_SYSTEM, CustomTransformers +from coati.distributed.producer import BaseProducer +from qwen_agent.agents import TIRMathAgent +from vllm import SamplingParams + + +@ray.remote +class AgenticProducer(BaseProducer): + """ + Asyncronous version of the producer that uses vLLM for generation. + This class is designed to generate agentic response + """ + + def __init__( + self, + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + async_producers, + tokenizer_config=None, + agentic_config=None, + microbatch_size=1, + backend="transformers", + num_generations: int = 8, + consumer_plugin_config=None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + grpo_config: Dict[str, Any] = None, + eval_save_dir: str = "./eval", + eval_generation_config={}, + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", + enable_profiling: bool = False, + n_behind: int = 0, + ): + assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer + assert batch_size == 1 # batch_size must be 1 for agentic producer + super().__init__( + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + tokenizer_config, + microbatch_size, + backend, + consumer_plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + grpo_config=grpo_config, + eval_save_dir=eval_save_dir, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, + enable_profiling=enable_profiling, + n_behind=n_behind, + enable_agentic=True, + ) + self.eval_generation_config = copy.deepcopy(generate_config) + self.eval_generation_config["n"] = 1 # use 1 generation for evaluation + self.eval_generation_config.update(eval_generation_config) + self.eval_sample_params = SamplingParams(**self.eval_generation_config) + self.async_producers = async_producers + self.num_generations = num_generations + self.generate_config = generate_config + self.agentic_config = model_config if not agentic_config else agentic_config + self.agentic_config.update({"model": model_config["path"]}) + self.llm = CustomTransformers(self.agentic_config, self.producer_idx, generation_workers=self.async_producers) + self.bot = TIRMathAgent(llm=self.llm, name=model_config["path"], system_message=TIR_SYSTEM) + + def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: + """ + Rollout function to generate responses for the input, for example, using LLM or agentic pipeline. + This function should be implemented in subclasses. + """ + assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer" + messages = kwargs["messages"][0] + prompt_input_ids = self.tokenizer.apply_chat_template( + messages, return_tensors="pt", tokenize=True, add_generation_prompt=True + ) + # add left padding + prompt_length = prompt_input_ids.shape[1] + max_prompt_length = self.train_dataset_config["max_length"] + to_pad_left = max_prompt_length - prompt_length + rollouts = { + "input_ids": [], + "attention_mask": [], + "action_mask": [], + "action_log_probs": [], + "response_idx": [], + } + for i in range(self.num_generations): + _messages = copy.deepcopy(messages) + for response in self.bot.run(messages): + continue + _messages.extend(response) + response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True) + # truncate if too long + response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left] + # add left right padding + to_pad_right = self.grpo_config["max_length"] - response_input_ids.shape[1] - to_pad_left + response_length = response_input_ids.shape[1] - prompt_length + input_ids = torch.nn.functional.pad( + response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id + ) # [1, max_length] + attention_mask = torch.nn.functional.pad( + torch.ones_like(response_input_ids), (to_pad_left, to_pad_right), "constant", value=0 + ) # [1, max_length] + action_mask = torch.nn.functional.pad( + torch.ones(size=(1, response_length)), (0, to_pad_right), "constant", value=0 + ) # [1, max_length-prompt_length] + rollouts["attention_mask"].append(attention_mask) + rollouts["action_mask"].append(action_mask) + rollouts["action_log_probs"].append( + torch.ones(size=(1, self.grpo_config["max_length"] - max_prompt_length)) + ) # dummy log probs + rollouts["response_idx"].append( + torch.tensor( + [ + [ + self.train_dataset_config["max_length"], + self.train_dataset_config["max_length"] + response_length, + ] + ] + ) + ) # [1, 2] + rollouts["input_ids"].append(input_ids) + # breakpoint() + rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()} # [num_generations, ...] + rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)]) + if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode: + # for agentic producer, AsyncSimpleProducer is not the main producer, so we don't log rollouts + if ( + self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval + or self.latest_rollout_log_step == -1 + ): + new_record = ( + json.dumps( + { + "train_step": self.consumer_global_step, + "rollout": self.tokenizer.batch_decode( + rollouts["input_ids"][:, 0], skip_special_tokens=True + ), + } + ) + + "\n" + ) + self.rollout_log_file.write(new_record) + self.rollout_log_file.flush() + self.latest_rollout_log_step = self.consumer_global_step + + if "gt_answer" in kwargs: + rollouts["gt_answer"] = kwargs["gt_answer"] + if "test_cases" in kwargs: + rollouts["test_cases"] = kwargs["test_cases"] + return rollouts + + def sync_model(self, episode, step) -> None: + """ + sync model from consumer to self.async_producers + AgenticProducer does not hold any model weights, so no need to sync model to self.async_producers + """ + tasks = [] + for proc in self.async_producers: + tasks.append(proc.async_sync_model.remote(episode, step, self.num_producers)) + ray.get(tasks) + return + + def sync_data(self, data: Dict[str, torch.Tensor]) -> None: + """ + sync data from self to consumer + """ + tasks = [] + for idx, proc in enumerate(self.async_producers): + if idx == self.producer_idx % len(self.async_producers): + tasks.append(proc.async_sync_data.remote(data, self.num_producers)) + else: + tasks.append(proc.async_sync_data.remote({}, self.num_producers)) + ray.get(tasks) + return diff --git a/applications/ColossalChat/coati/distributed/agent/agentic_math_utils.py b/applications/ColossalChat/coati/distributed/agent/agentic_math_utils.py new file mode 100644 index 000000000000..eb44f8a93092 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/agentic_math_utils.py @@ -0,0 +1,170 @@ +# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A TIR(tool-integrated reasoning) math agent +```bash +python tir_math.py +``` +""" +import os +import random + +import ray +from qwen_agent.agents import TIRMathAgent +from qwen_agent.llm.base import register_llm +from qwen_agent.llm.function_calling import BaseFnCallModel +from qwen_agent.llm.transformers_llm import Transformers +from qwen_agent.log import logger +from transformers import AutoTokenizer + +ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource") + +# We use the following two systems to distinguish between COT mode and TIR mode +TIR_SYSTEM = """Please integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{}.""" +COT_SYSTEM = """Please reason step by step, and put your final answer within \\boxed{}.""" + +from transformers import StoppingCriteria + +tokenizer = AutoTokenizer.from_pretrained("/mnt/nfs/share/data/model/Qwen2.5-Math-7B-Instruct", trust_remote_code=True) + + +class StopOnTokens(StoppingCriteria): + def __init__(self, stop_token_ids): + self.stop_token_ids = stop_token_ids + + def __call__(self, input_ids, scores, **kwargs): + # Check if the last token is one of the stop tokens + if input_ids[0, -1].item() in self.stop_token_ids: + return True + return False + + +class LocalLLMFromGenerationWorkers: + """ + A class that wraps the Transformers model to support API-based text generation. + """ + + def __init__(self, generation_worker=None): + self.device = "cpu" + self.generation_worker = generation_worker + + def generate(self, **kwargs): + rollouts = ray.get(self.generation_worker.generate.remote(**kwargs)) + return rollouts["input_ids"] + + +@register_llm("api_based_transformers") +class CustomTransformers(Transformers): + """ + Transformers class that supports API-based text generation. + """ + + def __init__(self, cfg: dict, producer_idx, generation_workers=None): + BaseFnCallModel.__init__(self, cfg) # skip the super() init of Transformers to avoid loading hf model + ############ Setup logic from Transformers.__init__ ############### + if "model" not in cfg: + raise ValueError("Please provide the model id or directory through `model` in cfg.") + + try: + from transformers import AutoConfig, AutoProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast + except ImportError as e: + raise ImportError( + "Could not import classes from transformers. " "Please install it with `pip install -U transformers`" + ) from e + + self.hf_config = AutoConfig.from_pretrained(cfg["model"]) + arch = self.hf_config.architectures[0] + if len(self.hf_config.architectures) > 1: + logger.warning( + f"The config for the transformers model type contains more than one architecture, choosing the first: {arch}" + ) + + # try loading a processor, if got a tokenizer, regarding the model as text-only + processor = AutoProcessor.from_pretrained(cfg["model"]) + if isinstance(processor, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + logger.info(f"Regarding the transformers model as text-only since its processor is a tokenizer.") + self.tokenizer = processor + self._support_multimodal_input = False + else: + self.processor = processor + self.tokenizer = self.processor.tokenizer + self._support_multimodal_input = True + ################################################################ + self.generation_workers = generation_workers + self.hf_models = [ + LocalLLMFromGenerationWorkers(generation_worker=generation_worker) + for generation_worker in generation_workers + ] + self.producer_idx = producer_idx + self.load_balancer_idx = producer_idx % len(self.generation_workers) + + @property + def hf_model(self): + # Simple round-robin load balancing + model = self.hf_models[self.load_balancer_idx] + return model + + def _chat_stream( + self, + messages, + delta_stream: bool, + generate_cfg: dict, + ): + # overwrite streaming because streamer is not serializable + # determine load balancer idx based on producer load, refresh every generation + load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers] + min_load = min(load) + candidates = [i for i, l in enumerate(load) if l == min_load] + # random tie break + self.load_balancer_idx = random.choice(candidates) + response = self._chat_no_stream(messages=messages, generate_cfg=generate_cfg) + # if self.producer_idx == 0: + # print(response) + yield response + + +def init_agent_service(): + llm_cfg = { + # Use the OpenAI-compatible model service provided by DashScope: + "model": "/mnt/nfs/share/data/model/Qwen2.5-Math-7B-Instruct", + "model_type": "transformers", + "generate_cfg": { + # Using the API's native tool call interface + "top_k": 1, + }, + } + llm = CustomTransformers(llm_cfg) + bot = TIRMathAgent(llm=llm, name="Qwen2.5-Math", system_message=TIR_SYSTEM) + return bot + + +def app_tui(): + # Define the agent + bot = init_agent_service() + + # Chat + messages = [] + while True: + # Query example: 斐波那契数列前10个数字 + query = input("user question: ") + messages.append({"role": "user", "content": query}) + response = [] + for response in bot.run(messages): + print("bot response:", response) + messages.extend(response) + + +# if __name__ == '__main__': +# # Test the TIR math agent locally +# app_tui() diff --git a/applications/ColossalChat/coati/distributed/agent/model.py b/applications/ColossalChat/coati/distributed/agent/model.py deleted file mode 100644 index c52d0d99df12..000000000000 --- a/applications/ColossalChat/coati/distributed/agent/model.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -MIT License - -Copyright (c) 2025 LangChain - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -""" - -from typing import Any, Dict, Iterator, List, Optional - -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from pydantic import Field - - -class LangChainChatModel(BaseChatModel): - """A custom chat model that echoes the first `parrot_buffer_length` characters - of the input. - - When contributing an implementation to LangChain, carefully document - the model including the initialization parameters, include - an example of how to initialize the model and include any relevant - links to the underlying models documentation or API. - - Example: - - .. code-block:: python - - model = LangChainChatModel(parrot_buffer_length=2, model="bird-brain-001") - result = model.invoke([HumanMessage(content="hello")]) - result = model.batch([[HumanMessage(content="hello")], - [HumanMessage(content="world")]]) - """ - - model_name: str = Field(alias="model") - temperature: Optional[float] = None - max_tokens: Optional[int] = None - timeout: Optional[int] = None - stop: Optional[List[str]] = None - async_server_manager: Optional[Any] = None - max_retries: int = 2 - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """Override the _generate method to implement the chat model logic. - - This can be a call to an API, a call to a local model, or any other - implementation that generates a response to the input prompt. - - Args: - messages: the prompt composed of a list of messages. - stop: a list of strings on which the model should stop generating. - If generation stops due to a stop token, the stop token itself - SHOULD BE INCLUDED as part of the output. This is not enforced - across models right now, but it's a good practice to follow since - it makes it much easier to parse the output of the model - downstream and understand why generation stopped. - run_manager: A run manager with callbacks for the LLM. - """ - self.async_server_manager.generate(messages, stop, run_manager, **kwargs) - tokens = last_message.content[: self.parrot_buffer_length] - ct_input_tokens = sum(len(message.content) for message in messages) - ct_output_tokens = len(tokens) - message = AIMessage( - content=tokens, - additional_kwargs={}, # Used to add additional payload to the message - response_metadata={ # Use for response metadata - "time_in_seconds": 3, - "model_name": self.model_name, - }, - usage_metadata={ - "input_tokens": ct_input_tokens, - "output_tokens": ct_output_tokens, - "total_tokens": ct_input_tokens + ct_output_tokens, - }, - ) - ## - - generation = ChatGeneration(message=message) - return ChatResult(generations=[generation]) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - """Stream the output of the model. - - This method should be implemented if the model can generate output - in a streaming fashion. If the model does not support streaming, - do not implement it. In that case streaming requests will be automatically - handled by the _generate method. - - Args: - messages: the prompt composed of a list of messages. - stop: a list of strings on which the model should stop generating. - If generation stops due to a stop token, the stop token itself - SHOULD BE INCLUDED as part of the output. This is not enforced - across models right now, but it's a good practice to follow since - it makes it much easier to parse the output of the model - downstream and understand why generation stopped. - run_manager: A run manager with callbacks for the LLM. - """ - raise NotImplementedError("Streaming is not implemented for this model. Please implement the _stream method.") - - @property - def _llm_type(self) -> str: - """Get the type of language model used by this chat model.""" - return "echoing-chat-model-advanced" - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Return a dictionary of identifying parameters. - - This information is used by the LangChain callback system, which - is used for tracing purposes make it possible to monitor LLMs. - """ - return { - # The model name allows users to specify custom token counting - # rules in LLM monitoring applications (e.g., in LangSmith users - # can provide per token pricing for their model and monitor - # costs for the given LLM.) - "model_name": self.model_name, - } diff --git a/applications/ColossalChat/coati/distributed/agent/test_api_based_agent.py b/applications/ColossalChat/coati/distributed/agent/test_api_based_agent.py new file mode 100644 index 000000000000..5e63bb5a366c --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/test_api_based_agent.py @@ -0,0 +1,126 @@ +# ------------------------------- +# 1. Define the Python tool +# ------------------------------- +import io +import sys +from typing import Dict, List + +import requests +from langchain_core.tools import tool +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import create_react_agent + + +class Capturing(list): + """Capture stdout prints inside exec()""" + + def __enter__(self): + self._stdout = sys.stdout + sys.stdout = self._stringio = io.StringIO() + return self + + def __exit__(self, *args): + self.extend(self._stringio.getvalue().splitlines()) + sys.stdout = self._stdout + + +@tool +def python(code: str) -> str: + """ + This function executes a string of Python code and returns the printed output. + You need to print the output. Please import all libraries used in the code string. + """ + local_vars = {} + with Capturing() as output: + exec(code, {}, local_vars) + if output == []: + return "Error: No output printed from the code. Please ensure you print the output." + return "\n".join(output) + + +# ------------------------------- +# 2. Define a Custom API LLM wrapper +# ------------------------------- +class CustomAPILLM: + def __init__(self, api_url: str, api_key: str = None): + self.api_url = api_url + self.api_key = api_key + + def invoke(self, messages: List[Dict[str, str]]) -> str: + """ + messages: list of {"role": "user"/"assistant"/"system", "content": "..."} + """ + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + payload = { + "model": "custom-model", # depends on your API + "messages": messages, + "temperature": 0, + } + + response = requests.post(self.api_url, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + + # Adjust according to your API response format + return data["choices"][0]["message"]["content"] + + +# ------------------------------- +# 3. Build a ReAct Agent with LangGraph +# ------------------------------- +def build_agent(): + # Wrap custom API LLM in LangChain-compatible interface + from langchain_core.language_models import BaseChatModel + from langchain_core.messages import AIMessage + + class LangChainCustomLLM(BaseChatModel): + client: CustomAPILLM = None + + def __init__(self, client: CustomAPILLM): + super().__init__() + self.client = client + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + content = self.client.invoke([m.dict() for m in messages]) + return self._create_chat_result([AIMessage(content=content)]) + + @property + def _llm_type(self) -> str: + return "custom-api-llm" + + # Init LLM + llm_client = CustomAPILLM(api_url="http://localhost:8000/v1/chat/completions") + llm = LangChainCustomLLM(llm_client) + + # Tools + tools = [python] + + # Memory (optional) + memory = MemorySaver() + + # Build ReAct agent + agent = create_react_agent(llm, tools, checkpointer=memory) + return agent + + +# ------------------------------- +# 4. Run the agent on a math problem +# ------------------------------- +if __name__ == "__main__": + agent = build_agent() + + # Example math question + user_input = "What is the least common multiple of 18 and 24? Use Python if needed." + + config = {"configurable": {"thread_id": "math-1"}} + for event in agent.stream({"messages": [("user", user_input)]}, config): + if "agent" in event: + print("Agent event:", event["agent"]["messages"][-1].content) + elif "tools" in event: + print("Tool event:", event["tools"]["messages"][-1].content) + + final_state = agent.get_state(config) + print("Final Answer:", final_state["messages"][-1].content) diff --git a/applications/ColossalChat/coati/distributed/agent/tools.py b/applications/ColossalChat/coati/distributed/agent/tools.py deleted file mode 100644 index a39a32f335f6..000000000000 --- a/applications/ColossalChat/coati/distributed/agent/tools.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -MIT License - -Copyright (c) 2025 LangChain - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -""" - -import builtins -import contextlib -import io -import math -from typing import Any - - -def eval(code: str, _locals: dict[str, Any]) -> tuple[str, dict[str, Any]]: - # Store original keys before execution - original_keys = set(_locals.keys()) - - try: - with contextlib.redirect_stdout(io.StringIO()) as f: - exec(code, builtins.__dict__, _locals) - result = f.getvalue() - if not result: - result = "" - except Exception as e: - result = f"Error during execution: {repr(e)}" - - # Determine new variables created during execution - new_keys = set(_locals.keys()) - original_keys - new_vars = {key: _locals[key] for key in new_keys} - return result, new_vars - - -def add(a: float, b: float) -> float: - """Add two numbers together.""" - return a + b - - -def multiply(a: float, b: float) -> float: - """Multiply two numbers together.""" - return a * b - - -def divide(a: float, b: float) -> float: - """Divide two numbers.""" - return a / b - - -def subtract(a: float, b: float) -> float: - """Subtract two numbers.""" - return a - b - - -def sin(a: float) -> float: - """Take the sine of a number.""" - return math.sin(a) - - -def cos(a: float) -> float: - """Take the cosine of a number.""" - return math.cos(a) - - -def radians(a: float) -> float: - """Convert degrees to radians.""" - return math.radians(a) - - -def exponentiation(a: float, b: float) -> float: - """Raise one number to the power of another.""" - return a**b - - -def sqrt(a: float) -> float: - """Take the square root of a number.""" - return math.sqrt(a) - - -def ceil(a: float) -> float: - """Round a number up to the nearest integer.""" - return math.ceil(a) - - -tools = [ - add, - multiply, - divide, - subtract, - sin, - cos, - radians, - exponentiation, - sqrt, - ceil, -] diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 21da67161956..6a885f23b1e4 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -150,6 +150,7 @@ def loop(self) -> None: self.profiler.enter("sync_model") torch.cuda.empty_cache() state_dict = self.state_dict() + print(f"[C{self.rank}]: Sync model before training") if self.pp_size > 1: if self.tp_rank == 0 and self.dp_rank == 0: ray_broadcast_tensor_dict( @@ -164,6 +165,7 @@ def loop(self) -> None: state_dict, src=self.num_producers, device=self.device, group_name="sync_model" ) del state_dict + print(f"[C{self.rank}]: Sync model before training done") torch.cuda.empty_cache() self.profiler.exit("sync_model") @@ -323,7 +325,7 @@ def loop(self) -> None: ) # for setting start index when resuming training if self.rank == 0: print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}") - + # breakpoint() if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and ( episode != 0 or step >= self.n_behind ): diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index dbf8b94b0673..c35c45bddf39 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -64,7 +64,7 @@ async def generate( - action_mask (torch.Tensor): shape [B, N] where N is the number of generated tokens. And all tensors should be on CUDA. """ - raise NotImplementedError("AsyncInferenceBackend does not support generate method.") + raise NotImplementedError("Generate method must be implemented in subclass.") class TransformersInferenceBackend(BaseInferenceBackend): @@ -84,6 +84,7 @@ def __init__( tokenizer: PreTrainedTokenizer, num_generations: int = 8, microbatch_size: int = 1, + profiler=None, ): model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) model_config.update(self.FORCE_MODEL_CONFIG) @@ -93,6 +94,7 @@ def __init__( self.generate_config.update(self.FORCE_GENERATE_CONFIG) self.tokenizer = tokenizer self.num_generations = num_generations + self.profiler = profiler @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: @@ -158,6 +160,7 @@ def __init__( tokenizer: PreTrainedTokenizer, num_generations: int = 8, microbatch_size: int = 1, + profiler=None, ): if sgl is None: raise ImportError("sglang is not installed") @@ -223,6 +226,7 @@ def __init__( tokenizer: PreTrainedTokenizer, num_generations: int = 8, microbatch_size: int = 1, + profiler=None, ): if LLM is None: raise ImportError("vllm is not installed") @@ -323,6 +327,7 @@ def __init__( tokenizer: PreTrainedTokenizer, num_generations: int = 8, microbatch_size: int = 1, + profiler=None, ): if LLM is None: raise ImportError("vllm is not installed") @@ -332,7 +337,8 @@ def __init__( self.engine = AsyncLLMEngine.from_engine_args(engine_args) generate_config = generate_config.copy() generate_config.update(self.FORCE_GENERATE_CONFIG) - generate_config.update({"n": num_generations}) + if "n" not in generate_config: + generate_config.update({"n": num_generations}) self.generate_config = generate_config self.sample_params = SamplingParams(**generate_config) self.model_config = model_config @@ -340,6 +346,7 @@ def __init__( self.num_generations = num_generations self.queued_requests = [] self.microbatch_size = microbatch_size + self.profiler = profiler @torch.no_grad() async def generate( @@ -351,6 +358,7 @@ async def generate( input_ids (torch.Tensor): shape [B, S], B=1 attention_mask (torch.Tensor): shape [B, S] """ + # breakpoint() assert input_ids.size(0) == attention_mask.size(0) == 1 response_start_idx = input_ids.size(1) first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1) @@ -366,6 +374,7 @@ async def generate( self.queued_requests.append(request_id) # enqueue # pop the first input_ids and attention_mask prompt_token_ids = input_ids_no_padding[0] + self.profiler.enter(f"vllm generate {request_id}") outputs = self.engine.generate( prompt={"prompt_token_ids": prompt_token_ids}, sampling_params=sample_params, request_id=request_id ) @@ -380,6 +389,7 @@ async def generate( assert len(output_i.logprobs) == len(output_i.token_ids) p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)] log_probs.append(p) + self.profiler.exit(f"vllm generate {request_id}") # pad them max_len = self.sample_params.max_tokens action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype) diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 8795af51f31a..793cd932e92c 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -4,10 +4,11 @@ from typing import Any, Dict, Optional import ray +from coati.distributed.agent.agentic import AgenticProducer from .consumer import SimpleConsumer from .grpo_consumer import GRPOConsumer -from .producer import AsyncProducer, SimpleProducer +from .producer import AsyncSimpleProducer, SimpleProducer ALGO_MAP = { "Simple": SimpleConsumer, @@ -16,7 +17,7 @@ "REINFORCE_PPB": GRPOConsumer, "RLOO": GRPOConsumer, } -Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncProducer} +Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncSimpleProducer} def get_jsonl_size_fast(path: str) -> int: @@ -48,6 +49,7 @@ def launch_distributed( generate_config: Dict[str, Any], train_model_config: Dict[str, Any], grpo_config: Dict[str, Any], + agentic_config: Optional[Dict[str, Any]], plugin_config: Dict[str, Any], tokenizer_config: Optional[Dict[str, Any]] = None, inference_backend: str = "transformers", @@ -80,7 +82,7 @@ def launch_distributed( num_samples = get_jsonl_size_fast(dataset_path) global_inference_batch_size = inference_batch_size * num_producers num_update_per_episode = num_samples // global_inference_batch_size - num_recv_per_update = inference_batch_size // inference_microbatch_size + num_recv_per_update = inference_batch_size // inference_microbatch_size if "async" not in inference_backend else 1 run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" wandb_group_name = str(uuid.uuid4()) @@ -112,9 +114,12 @@ def launch_distributed( producer_procs = [] if "async" in inference_backend: - core_producer = AsyncProducer + core_producer = AsyncSimpleProducer else: - core_producer = Producer_MAP.get("Simple", SimpleProducer) + core_producer = SimpleProducer + enable_agentic = "agentic" in inference_backend + if enable_agentic: + inference_backend = inference_backend.replace("agentic-", "") for i in range(num_producers): node_id = gpu_to_node_id[0] producer_ip_address = gpu_to_ip_address[0] @@ -132,7 +137,11 @@ def launch_distributed( model_config=inference_model_config, generate_config=generate_config, tokenizer_config=tokenizer_config, - microbatch_size=inference_microbatch_size, + microbatch_size=( + inference_microbatch_size * num_generations + if "async" in inference_backend + else inference_microbatch_size + ), backend=inference_backend, num_generations=num_generations, consumer_plugin_config=plugin_config, @@ -145,12 +154,63 @@ def launch_distributed( run_name=run_name, wandb_group_name=wandb_group_name, log_rollout_interval=log_rollout_interval, - rollout_log_file=rollout_log_file, + rollout_log_file=rollout_log_file if not enable_agentic else None, enable_profiling=enable_profiling, n_behind=n_behind, ) producer_procs.append(producer) ray.get([p.setup.remote() for p in producer_procs]) + """ + # test async generate + import torch + import asyncio + import time + async def test(): + res_ref = producer_procs[0].generate.remote(torch.ones((2, 10), dtype=torch.int), torch.ones((2, 10), dtype=torch.int)) + res = await res_ref + return res + res = asyncio.run(test()) + print(res) + time.sleep(1000) + """ + + if enable_agentic: + # when agentic is enabled, we use core_producer as inference engine and + # AgenticProducer as the real producer + _producer_procs = producer_procs + producer_procs = [ + AgenticProducer.options(num_cpus=1).remote( + producer_idx=producer_idx, + num_producers=num_producers * train_batch_size, + num_consumer_procs=num_consumer_procs, + num_episodes=num_episodes, + batch_size=1, # batch_size must be 1 for agentic producer + train_dataset_config=train_dataset_config, + model_config=inference_model_config, + generate_config=generate_config, + async_producers=_producer_procs, + tokenizer_config=tokenizer_config, + agentic_config=agentic_config, + microbatch_size=1, # microbatch_size must be 1 for agentic producer + backend=inference_backend, + num_generations=num_generations, + consumer_plugin_config=plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + grpo_config=grpo_config, + eval_save_dir=eval_save_dir, + eval_generation_config=eval_generation_config, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, + enable_profiling=enable_profiling, + n_behind=n_behind, + ) + for producer_idx in range(num_producers * inference_batch_size) + ] + generate_config_consumer = copy.deepcopy(generate_config) generate_config_consumer.update( dict( diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 7d3bbaec2c27..44a6214ff6b0 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -57,6 +57,7 @@ def __init__( log_rollout_interval: int = 20, rollout_log_file: str = "./rollout_log.jsonl", enable_profiling: bool = False, + enable_agentic: bool = False, n_behind: int = 0, ): self.producer_idx = producer_idx @@ -65,8 +66,12 @@ def __init__( self.num_episodes = num_episodes self.batch_size = batch_size self.microbatch_size = microbatch_size - assert batch_size % microbatch_size == 0 - self.num_microbatches = batch_size // microbatch_size + if not isinstance(self, BaseAsyncProducer): + assert batch_size % microbatch_size == 0, "batch_size must be divisible by microbatch_size" + self.num_microbatches = batch_size // microbatch_size + else: + assert microbatch_size > 0, "microbatch_size must be positive" + self.num_microbatches = max(1, batch_size // microbatch_size) self.latest_eval_step = -1 self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling) @@ -84,13 +89,14 @@ def __init__( self.latest_rollout_log_step = -1 self.grpo_config = grpo_config self.n_behind = n_behind + self.enable_agentic = enable_agentic reward_model_kwargs = { k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"] } self.response_format_tags = grpo_config.get("response_format_tags", None) - if producer_idx == 0: + if producer_idx == 0 and rollout_log_file is not None: if os.path.exists(rollout_log_file): raise ValueError( f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name." @@ -124,7 +130,9 @@ def __init__( # init dataloader train_dataset_path = train_dataset_config.pop("path") - self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config) + self.train_dataset = RawConversationDataset( + self.tokenizer, train_dataset_path, **train_dataset_config, tokenize=not self.enable_agentic + ) self.train_dataloader = DataLoader( self.train_dataset, batch_size=microbatch_size, @@ -162,7 +170,10 @@ def __init__( for eval_task_name in self.eval_dataset_config: eval_dataset_path = eval_dataset_config[eval_task_name].pop("path") eval_dataset = RawConversationDataset( - self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name] + self.tokenizer, + eval_dataset_path, + **eval_dataset_config[eval_task_name], + tokenize=not self.enable_agentic, ) print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}") self.eval_dataloaders[eval_task_name] = DataLoader( @@ -210,18 +221,34 @@ def setup(self) -> None: else: cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model") + @torch.no_grad() + def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + """ + Generate responses by running inference on the input_ids and attention_mask. + """ + return self.model.generate(input_ids, attention_mask, **kwargs) + def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + """ + Rollout function to generate responses for the input, for example, using LLM or agentic pipeline. + This function should be implemented in subclasses. + """ raise NotImplementedError def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: raise NotImplementedError - def loop(self) -> None: - + def sync_model(self, episode, step) -> None: + """ + Default implementation to sync model from consumer to producer. + """ torch.cuda.empty_cache() self.profiler.enter("sync_model") if self.consumer_pp_size > 1: for pp_idx in range(self.consumer_pp_size): + print( + f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(step + 1) // self.num_microbatches - 1}" + ) state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}" ) @@ -229,6 +256,7 @@ def loop(self) -> None: self.consumer_global_step = state_dict.pop("consumer_global_step").item() self.load_state_dict(state_dict) else: + print(f"[P{self.producer_idx}] Sync model episode {episode} step {(step + 1) // self.num_microbatches - 1}") state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" ) @@ -236,10 +264,18 @@ def loop(self) -> None: self.consumer_global_step = state_dict.pop("consumer_global_step").item() self.load_state_dict(state_dict) self.profiler.exit("sync_model") - print(f"[P{self.producer_idx}] Sync initial model done.") del state_dict torch.cuda.empty_cache() + def sync_data(self, data: Dict[str, torch.Tensor]) -> None: + """ + Default implementation to sync data from producer to consumer. + """ + ray_broadcast_tensor_dict(data, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}") + + def loop(self) -> None: + # breakpoint() + self.sync_model(0, 0) num_update_per_episode = len(self.train_dataloader) // self.num_microbatches num_valid_microbatches = num_update_per_episode * self.num_microbatches @@ -311,14 +347,12 @@ def loop(self) -> None: self.eval_mode = False self.latest_eval_step = self.consumer_global_step self.profiler.enter("rollout") - if isinstance(self.model, BACKEND_MAP["async-vllm"]): - outputs = asyncio.run(self.rollout(**batch)) - else: - outputs = self.rollout(**batch) + outputs = self.rollout(**batch) self.profiler.exit("rollout") - outputs["temperature"] = torch.tensor( - [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) - ).to(outputs["input_ids"].device) + if "temperature" not in outputs: + outputs["temperature"] = torch.tensor( + [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) + ).to(outputs["input_ids"].device) bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1) self.profiler.enter("calculate_reward") if self.grpo_config["reward_fn_type"] == "code": @@ -363,52 +397,16 @@ def loop(self) -> None: print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs = pre_send(outputs) self.profiler.enter("send_broadcast_data") - ray_broadcast_tensor_dict( - outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}" - ) + self.sync_data(outputs) self.profiler.exit("send_broadcast_data") if ( (i + 1) % self.num_microbatches == 0 and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1) and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches) ): - if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( - "enable_sleep_mode", False - ): - self.model.llm.sleep() # revict KV_cache to avoid OOM - # don't sync model for last iteration - torch.cuda.empty_cache() - self.profiler.enter("sync_model") - if self.consumer_pp_size > 1: - for pp_idx in range(self.consumer_pp_size): - print( - f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}" - ) - state_dict = ray_broadcast_tensor_dict( - None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}" - ) - if "consumer_global_step" in state_dict: - self.consumer_global_step = state_dict.pop("consumer_global_step").item() - self.load_state_dict(state_dict) - else: - print( - f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" - ) - state_dict = ray_broadcast_tensor_dict( - None, self.num_producers, device=self.device, group_name="sync_model" - ) - if "consumer_global_step" in state_dict: - self.consumer_global_step = state_dict.pop("consumer_global_step").item() - self.load_state_dict(state_dict) - self.profiler.exit("sync_model") - del state_dict - torch.cuda.empty_cache() - if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( - "enable_sleep_mode", False - ): - self.model.llm.wake_up() + self.sync_model(episode, i) # linear annealing for 1 episode, temperature from initial to 0.9 - if episode <= 0: + if episode <= 0 and hasattr(self, "model"): ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader) self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[ "temperature" @@ -478,7 +476,7 @@ def __init__( n_behind=n_behind, ) self.model = self.backend_cls( - model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size + model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size, profiler=self.profiler ) self.eval_generation_config = copy.deepcopy(self.model.generate_config) self.eval_generation_config["n"] = 1 # use 1 generation for evaluation @@ -487,7 +485,7 @@ def __init__( @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): - rollouts = self.model.generate(input_ids, attention_mask, **kwargs) + rollouts = self.generate(input_ids, attention_mask, **kwargs) if self.producer_idx == 0 and not self.eval_mode: if ( self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval @@ -519,8 +517,7 @@ def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict) -@ray.remote -class AsyncProducer(BaseProducer): +class BaseAsyncProducer(BaseProducer): """ Asyncronous version of the producer that uses vLLM for generation. """ @@ -580,15 +577,39 @@ def __init__( ) assert backend == "async-vllm", f"AsyncProducer only supports async-vllm backend, got {backend}" self.model = self.backend_cls( - model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size + model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size, profiler=self.profiler ) self.eval_generation_config = copy.deepcopy(self.model.generate_config) self.eval_generation_config["n"] = 1 # use 1 generation for evaluation self.eval_generation_config.update(eval_generation_config) self.eval_sample_params = SamplingParams(**self.eval_generation_config) + self.ready_processes = 0 + self.condition = asyncio.Condition() + self.data_ready_for_sending = [] + + # @torch.no_grad() + # async def generate(self, input_ids, attention_mask, **kwargs): + # tasks = [] + # print("input_ids:", input_ids) + # for prompt_id in range(input_ids.size(0)): + # new_kwargs = copy.deepcopy(kwargs) + # if "gt_answer" in new_kwargs: + # new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id] + # if "test_cases" in new_kwargs: + # new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id] + # tasks.append( + # self.model.generate( + # input_ids[prompt_id].unsqueeze(0), + # attention_mask[prompt_id].unsqueeze(0), + # **new_kwargs, + # ) + # ) + # rollouts = await asyncio.gather(*tasks) + # return rollouts @torch.no_grad() - async def rollout(self, input_ids, attention_mask, **kwargs): + async def generate(self, input_ids, attention_mask, **kwargs): + # naive rollout strategy tasks = [] for prompt_id in range(input_ids.size(0)): new_kwargs = copy.deepcopy(kwargs) @@ -603,36 +624,223 @@ async def rollout(self, input_ids, attention_mask, **kwargs): **new_kwargs, ) ) - # print(f"Producer {self.producer_idx} running {len(tasks)} tasks") rollouts = await asyncio.gather(*tasks) rollouts = { k: ( torch.cat([r[k] for r in rollouts], dim=0) if k not in ["gt_answer", "test_cases"] else [r[k] for r in rollouts] - ) + ).cpu() # CUDA tensor is not serializable by ray for k in rollouts[0].keys() } - if self.producer_idx == 0 and not self.eval_mode: - if ( - self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval - or self.latest_rollout_log_step == -1 - ): - new_record = ( - json.dumps( - { - "train_step": self.consumer_global_step, - "rollout": self.tokenizer.batch_decode( - rollouts["input_ids"][:, 0], skip_special_tokens=True - ), - } + return rollouts + + @torch.no_grad() + async def rollout(self, input_ids, attention_mask, **kwargs): + """ + Advanced distributed rollout strategy that dispatches the generation tasks to different DP ranks. + Must be implemented in subclasses. + """ + raise NotImplementedError("rollout must be implemented in subclasses") + + async def get_producer_load(self): + """ + Get the load of each producer. + """ + return len(self.model.queued_requests) + + async def async_sync_model(self, episode, step, num_processes: int = 1) -> None: + """ + Asyncronous version to sync model from consumer to producer. + called by another producer, such as agentic producer. + """ + async with self.condition: + self.ready_processes += 1 + # Wait until all processes are ready + if self.ready_processes < num_processes: + await self.condition.wait() + + # Only one process should reset `ready_processes` and perform the sync + if self.ready_processes == num_processes: + self.ready_processes = 0 + self.condition.notify_all() # Notify all waiting processes + self.sync_model(episode, step) + + async def async_sync_data(self, data: Dict[str, torch.Tensor], num_processes: int = 1) -> None: + # merge data dict + async with self.condition: + self.ready_processes += 1 + if data: + self.data_ready_for_sending.append(data) + + # Wait until all processes are ready + if self.ready_processes < num_processes: + await self.condition.wait() + + # Only one process should reset `ready_processes` and perform the sync + if self.ready_processes == num_processes: # wait for all producers to join + self.ready_processes = 0 + self.condition.notify_all() + # merge data for sending + if len(self.data_ready_for_sending) >= 1: + batch_rollout_data = {} + for key in self.data_ready_for_sending[0]: + batch_rollout_data[key] = torch.cat([d[key] for d in self.data_ready_for_sending], dim=0).to( + self.device + ) + self.sync_data(batch_rollout_data) + self.data_ready_for_sending = [] # reset + + async def loop(self) -> None: + self.sync_model(0, 0) + num_update_per_episode = len(self.train_dataloader) // self.num_microbatches + num_valid_microbatches = num_update_per_episode * self.num_microbatches + + print( + f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}" + ) + for episode in range(self.num_episodes): + self.train_dataloader.sampler.set_epoch(episode) + for i, batch in enumerate(self.train_dataloader): + if i >= num_valid_microbatches: + break + if self.eval_interval > 0 and self.eval_dataset_config is not None: + if ( + self.consumer_global_step - self.latest_eval_step >= self.eval_interval + and self.consumer_global_step > self.latest_eval_step + ) or self.latest_eval_step == -1: + to_log_msg = {} + self.eval_mode = True + for eval_task_name in self.eval_dataloaders: + if self.producer_idx == 0: + print( + f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}" + ) + eval_results = [] + eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device) + for eval_batch in tqdm.tqdm( + self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0 + ): + eval_outputs = await self.rollout(**eval_batch, sample_params=self.eval_sample_params) + eval_results = eval_results + [ + self.evaluation_function( + eval_outputs["input_ids"][m][n], + eval_outputs[ + ( + "test_cases" + if self.grpo_config["reward_fn_type"] == "code" + else "gt_answer" + ) + ][m], + eval_outputs["response_idx"][m][n], + tokenizer=self.tokenizer, + eval_mode=True, + tags=self.response_format_tags, + ) + for m in range(eval_outputs["input_ids"].size(0)) + for n in range(eval_outputs["input_ids"].size(1)) + ] + eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results]) + eval_statistics_tensor[1] += len(eval_results) + allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group") + to_log_msg[f"eval/{eval_task_name}"] = ( + eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item() + ) + if self.producer_idx == 0: + print( + f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}" + ) + # save eval results + safe_append_to_jsonl_file( + os.path.join( + self.eval_save_dir, + f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl", + ), + eval_results, + ) + + if self.producer_idx == 0: + self.wandb_run.log(to_log_msg, step=self.consumer_global_step) + self.eval_mode = False + self.latest_eval_step = self.consumer_global_step + self.profiler.enter("rollout") + # breakpoint() + outputs = await self.rollout(**batch) + self.profiler.exit("rollout") + outputs["temperature"] = torch.tensor( + [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) + ).to(outputs["input_ids"].device) + bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1) + self.profiler.enter("calculate_reward") + if self.grpo_config["reward_fn_type"] == "code": + test_cases = [] + for prompt_id in range(bs): + test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen) + reward_model_output = self.reward_model( + outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))), + test_cases=test_cases, + response_idx=outputs["response_idx"].view((-1, 2)), ) - + "\n" + else: + gt_answer = [] + for prompt_id in range(bs): + gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen) + reward_model_output = self.reward_model( + outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))), + gt_answer=gt_answer, + response_idx=outputs["response_idx"].view((-1, 2)), + ) + outputs["reward"] = ( + torch.tensor([value[0] for value in reward_model_output]) + .to(outputs["input_ids"].device) + .view((bs, num_gen, 1)) ) - self.rollout_log_file.write(new_record) - self.rollout_log_file.flush() - self.latest_rollout_log_step = self.consumer_global_step - return rollouts + outputs["format_acc"] = ( + torch.tensor([value[1] for value in reward_model_output]) + .to(outputs["input_ids"].device) + .view((bs, num_gen, 1)) + ) + outputs["ans_acc"] = ( + torch.tensor([value[2] for value in reward_model_output]) + .to(outputs["input_ids"].device) + .view((bs, num_gen, 1)) + ) + if "gt_answer" in outputs: + outputs.pop("gt_answer") + if "test_cases" in outputs: + outputs.pop("test_cases") + self.profiler.exit("calculate_reward") + + print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") + outputs = pre_send(outputs) + self.profiler.enter("send_broadcast_data") + self.sync_data(outputs) + self.profiler.exit("send_broadcast_data") + if ( + (i + 1) % self.num_microbatches == 0 + and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1) + and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches) + ): + if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( + "enable_sleep_mode", False + ): + self.model.llm.sleep() # revict KV_cache to avoid OOM + # don't sync model for last iteration + self.sync_model(episode, i) + if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( + "enable_sleep_mode", False + ): + self.model.llm.wake_up() + # linear annealing for 1 episode, temperature from initial to 0.9 + if episode <= 0: + ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader) + self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.9 + if isinstance(self.model, BACKEND_MAP["vllm"]): + self.model.sample_params.temperature = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.9 def __del__(self): if self.producer_idx == 0: @@ -645,65 +853,18 @@ def load_state_dict(self, state_dict): @ray.remote -class AsyncServer: +class AsyncSimpleProducer(BaseAsyncProducer): """ - A async worker for inference only + Asyncronous version of the producer that uses vLLM for generation. + This class is designed to handle multiple producer actors and distribute tasks among them. """ - def __init__( - self, - producer_idx, - num_producers, - model_config, - generate_config, - tokenizer_config=None, - microbatch_size=1, - backend="transformers", - num_generations: int = 8, - eval_generation_config={}, - ): - tokenizer_path = model_config["path"] - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config) - self.tokenizer.padding_side = "left" - self.microbatch_size = microbatch_size - self.producer_idx = producer_idx - self.num_producers = num_producers - assert backend == "async-vllm", f"AsyncProducer only supports async-vllm backend, got {backend}" - self.model = self.backend_cls( - model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size - ) - self.eval_generation_config = copy.deepcopy(self.model.generate_config) - self.eval_generation_config["n"] = 1 # use 1 generation for evaluation - self.eval_generation_config.update(eval_generation_config) - self.eval_sample_params = SamplingParams(**self.eval_generation_config) - @torch.no_grad() async def rollout(self, input_ids, attention_mask, **kwargs): - tasks = [] - for prompt_id in range(input_ids.size(0)): - new_kwargs = copy.deepcopy(kwargs) - if "gt_answer" in new_kwargs: - new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id] - if "test_cases" in new_kwargs: - new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id] - tasks.append( - self.model.generate( - input_ids[prompt_id].unsqueeze(0), - attention_mask[prompt_id].unsqueeze(0), - **new_kwargs, - ) - ) - # print(f"Producer {self.producer_idx} running {len(tasks)} tasks") - rollouts = await asyncio.gather(*tasks) - rollouts = { - k: ( - torch.cat([r[k] for r in rollouts], dim=0) - if k not in ["gt_answer", "test_cases"] - else [r[k] for r in rollouts] - ) - for k in rollouts[0].keys() - } - if self.producer_idx == 0 and not self.eval_mode: + # naive rollout strategy without load balancing + rollouts = await self.generate(input_ids, attention_mask, **kwargs) + if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode: + # for agentic producer, AsyncSimpleProducer is not the main producer, so we don't log rollouts if ( self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval or self.latest_rollout_log_step == -1 @@ -724,11 +885,6 @@ async def rollout(self, input_ids, attention_mask, **kwargs): self.latest_rollout_log_step = self.consumer_global_step return rollouts - def __del__(self): - if self.producer_idx == 0: - self.wandb_run.finish() - if hasattr(self, "rollout_log_file"): - self.rollout_log_file.close() - - def load_state_dict(self, state_dict): - self.model.load_state_dict(state_dict) + async def generate(self, input_ids, attention_mask, **kwargs): + rollouts = await super().generate(input_ids, attention_mask, **kwargs) + return rollouts diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 54ef4e303771..46c75cdba02e 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -110,7 +110,11 @@ # Sampling parameters parser.add_argument( - "-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm", "async-vllm"] + "-b", + "--backend", + type=str, + default="transformers", + choices=["transformers", "vllm", "async-vllm", "async-agentic-vllm"], ) parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.") parser.add_argument( @@ -215,7 +219,7 @@ namespace="ray-example", runtime_env={ "env_vars": { - "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray + # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray "TOKENIZERS_PARALLELISM": "false", }, }, @@ -228,7 +232,7 @@ _temp_dir=args.ray_dir, runtime_env={ "env_vars": { - "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray + # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray "TOKENIZERS_PARALLELISM": "false", }, }, @@ -237,7 +241,7 @@ if args.top_k is None: if args.backend == "transformers": args.top_k = 50 - elif args.backend == "vllm" or args.backend == "async-vllm": + elif "vllm" in args.backend: args.top_k = -1 os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock @@ -265,7 +269,7 @@ ) ) eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation - elif args.backend == "vllm" or args.backend == "async-vllm": + elif args.backend == "vllm" or args.backend == "async-vllm" or args.backend == "async-agentic-vllm": # os.environ["VLLM_DP_SIZE"] = str(args.producer_data_parallel_size) inference_model_config.update( dict( @@ -404,6 +408,25 @@ # Default system prompt args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type] + if "agentic" in args.backend: + assert "vllm" in args.backend, "Agentic backend only supports async-agentic-vllm backends." + generate_config["n"] = 1 # agentic producer use AsyncProducer which processes one request a time + generate_config["max_tokens"] = ( + 2048 # max new tokens for each agentic step, usually smaller than max_new_tokens as agentic model will generate multiple steps + ) + agentic_config = { + "model": args.model, + "model_type": "transformers", + "generate_cfg": { + "max_input_tokens": args.max_new_tokens + args.max_prompt_tokens, + }, + } + agentic_config["generate_cfg"].update( + {k: v for k, v in generate_config.items() if k in ["top_k", "top_p", "temperature"]} + ) + else: + agentic_config = None + launch_distributed( num_producers=args.num_inferencer, num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size) @@ -424,6 +447,7 @@ num_generations=args.num_generations, train_model_config=train_model_config, grpo_config=grpo_config, + agentic_config=agentic_config, plugin_config={ "tp_size": args.tensor_parallel_size, "pp_size": args.pipeline_parallel_size, From 62f82a75ae768e805661011c772e4bc0d14289a3 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 8 Sep 2025 11:26:33 +0800 Subject: [PATCH 03/11] add langgraph agent, still buggy --- .../coati/distributed/agent/agentic.py | 21 +- .../agent/langgraph_math_agentic.py | 122 ++++++++++++ .../agent/langgraph_math_agentic_utils.py | 185 ++++++++++++++++++ .../distributed/agent/qwen_math_agentic.py | 88 +++++++++ ...th_utils.py => qwen_math_agentic_utils.py} | 6 + .../distributed/agent/test_api_based_agent.py | 126 ------------ .../coati/distributed/inference_backend.py | 7 +- .../ColossalChat/coati/distributed/launch.py | 17 +- applications/ColossalChat/rl_example.py | 33 +++- 9 files changed, 454 insertions(+), 151 deletions(-) create mode 100644 applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic.py create mode 100644 applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic_utils.py create mode 100644 applications/ColossalChat/coati/distributed/agent/qwen_math_agentic.py rename applications/ColossalChat/coati/distributed/agent/{agentic_math_utils.py => qwen_math_agentic_utils.py} (96%) delete mode 100644 applications/ColossalChat/coati/distributed/agent/test_api_based_agent.py diff --git a/applications/ColossalChat/coati/distributed/agent/agentic.py b/applications/ColossalChat/coati/distributed/agent/agentic.py index f348eb69d946..12de6164a832 100644 --- a/applications/ColossalChat/coati/distributed/agent/agentic.py +++ b/applications/ColossalChat/coati/distributed/agent/agentic.py @@ -4,14 +4,11 @@ import ray import torch -from coati.distributed.agent.agentic_math_utils import TIR_SYSTEM, CustomTransformers from coati.distributed.producer import BaseProducer -from qwen_agent.agents import TIRMathAgent from vllm import SamplingParams -@ray.remote -class AgenticProducer(BaseProducer): +class BaseAgenticProducer(BaseProducer): """ Asyncronous version of the producer that uses vLLM for generation. This class is designed to generate agentic response @@ -29,7 +26,6 @@ def __init__( generate_config, async_producers, tokenizer_config=None, - agentic_config=None, microbatch_size=1, backend="transformers", num_generations: int = 8, @@ -82,10 +78,13 @@ def __init__( self.async_producers = async_producers self.num_generations = num_generations self.generate_config = generate_config - self.agentic_config = model_config if not agentic_config else agentic_config - self.agentic_config.update({"model": model_config["path"]}) - self.llm = CustomTransformers(self.agentic_config, self.producer_idx, generation_workers=self.async_producers) - self.bot = TIRMathAgent(llm=self.llm, name=model_config["path"], system_message=TIR_SYSTEM) + + def _run_agentic_pipeline(self, messages): + """ + Run the agentic pipeline to generate responses based on the input messages. + This function should be implemented in subclasses. + """ + raise NotImplementedError def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: """ @@ -110,9 +109,7 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: } for i in range(self.num_generations): _messages = copy.deepcopy(messages) - for response in self.bot.run(messages): - continue - _messages.extend(response) + _messages = self._run_agentic_pipeline(_messages) response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True) # truncate if too long response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left] diff --git a/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic.py b/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic.py new file mode 100644 index 000000000000..90a8d3fbd98b --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic.py @@ -0,0 +1,122 @@ +from typing import Any, Dict + +import ray +from coati.distributed.agent.agentic import BaseAgenticProducer +from coati.distributed.agent.langgraph_math_agentic_utils import CustomOpenAIAPILLM, LangChainCustomLLM, python +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import create_react_agent + + +@ray.remote +class LangGraphMathAgenticProducer(BaseAgenticProducer): + """ + Asyncronous version of the producer that uses vLLM for generation. + This class is designed to generate agentic response + """ + + def __init__( + self, + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + async_producers, + tokenizer_config=None, + agentic_config=None, + microbatch_size=1, + backend="transformers", + num_generations: int = 8, + consumer_plugin_config=None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + grpo_config: Dict[str, Any] = None, + eval_save_dir: str = "./eval", + eval_generation_config={}, + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", + enable_profiling: bool = False, + n_behind: int = 0, + ): + assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer + assert batch_size == 1 # batch_size must be 1 for agentic producer + super().__init__( + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + async_producers, + tokenizer_config, + microbatch_size, + backend, + num_generations, + consumer_plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + grpo_config=grpo_config, + eval_save_dir=eval_save_dir, + eval_generation_config=eval_generation_config, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, + enable_profiling=enable_profiling, + n_behind=n_behind, + ) + self.agentic_config = agentic_config + self.agentic_config.pop("agentic_type", None) + self.llm_client = CustomOpenAIAPILLM({"model": model_config["path"]}, producer_idx, self.async_producers) + self.llm = LangChainCustomLLM(self.llm_client) + # self.python_repl = PythonREPL() + # repl_tool = Tool( + # name="python_repl", + # description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.", + # func=self.python_repl.run, + # ) + # self.tools = [repl_tool] + self.tools = [python] + self.memory = MemorySaver() + self.bot = create_react_agent(self.llm, self.tools, checkpointer=self.memory) + + def _run_agentic_pipeline(self, messages): + """ + Run the agentic pipeline to generate responses based on the input messages using the LangGraph. + """ + assert ( + len(messages) == 2 and messages[0]["role"] == "system" and messages[1]["role"] == "user" + ), "Only support 1 system message and 1 user message as input." + # inputs = messages + for event in self.bot.stream( + {"messages": [("system", messages[0]["content"]), ("user", "calculate the 1000th Fibonacci number")]}, + self.agentic_config, + ): + continue + breakpoint() + + final_state = self.bot.get_state(self.agentic_config) + transformer_messages = [] + for message in final_state[0]["messages"]: + tool_calls = None + if isinstance(message, SystemMessage): + message.content + elif isinstance(message, HumanMessage): + message.content + elif isinstance(message, AIMessage): + message.content + tool_calls = message.get("tool_calls", None) # [{"type": "function", "function": tool_call}] + elif isinstance(message, ToolMessage): + message.content + + return transformer_messages diff --git a/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic_utils.py b/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic_utils.py new file mode 100644 index 000000000000..cb587e0d6849 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic_utils.py @@ -0,0 +1,185 @@ +# ------------------------------- +# 1. Define the Python tool +# ------------------------------- +import copy +import io +import random +import sys +from typing import Dict, List + +import ray +import torch +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage +from langchain_core.outputs.chat_generation import ChatGeneration +from langchain_core.outputs.chat_result import ChatResult +from langchain_core.prompts import PromptTemplate +from langchain_core.tools import tool +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import create_react_agent +from tool_calling_llm import ToolCallingLLM +from transformers import AutoTokenizer + +SYSTEM_PROMPT_TEMPLATE = """{task_description}. You have access to the following tools: + +{tools} + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: the action to take, should be one of [{tool_names}] +Action Input: the input to the action +Observation: the result of the action +... (this Thought/Action/Action Input/Observation can repeat N times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin! + +Question: {input} +Thought:{agent_scratchpad}""" + +SYSTEM_PROMPT = PromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE) + + +class Capturing(list): + """Capture stdout prints inside exec()""" + + def __enter__(self): + self._stdout = sys.stdout + sys.stdout = self._stringio = io.StringIO() + return self + + def __exit__(self, *args): + self.extend(self._stringio.getvalue().splitlines()) + sys.stdout = self._stdout + + +@tool +def python(code: str) -> str: + """ + This function executes a string of Python code and returns the printed output. + You need to print the output. Please import all libraries used in the code string. + """ + local_vars = {} + with Capturing() as output: + exec(code, {}, local_vars) + if output == []: + return "Error: No output printed from the code. Please ensure you print the output." + return "\n".join(output) + + +# ------------------------------- +# 2. Define a Custom API LLM wrapper +# ------------------------------- +class CustomOpenAIAPILLM: + def __init__(self, cfg: dict, producer_idx, generation_workers=None): + self.producer_idx = producer_idx + self.generation_workers = generation_workers + self.load_balancer_idx = producer_idx % len(self.generation_workers) + assert "model" in cfg, "Please specify the model name in the config" + self.tokenizer = AutoTokenizer.from_pretrained(cfg["model"]) + self.role_mapping = { + "system": "system", + "user": "user", + "assistant": "assistant", + "human": "user", + "tool": "tool", + } + + def invoke(self, messages: List[Dict[str, str]], **kwargs) -> str: + """ + messages: list of {"role": "user"/"assistant"/"system", "content": "..."} + """ + # load balancing + load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers] + min_load = min(load) + candidates = [i for i, l in enumerate(load) if l == min_load] + # random tie break + self.load_balancer_idx = random.choice(candidates) + generation_worker = self.generation_workers[self.load_balancer_idx] + transformer_messages = [] + for message in messages: + transformer_messages.append({"role": self.role_mapping[message.type], "content": message.content}) + input_ids = self.tokenizer.apply_chat_template( + transformer_messages, return_tensors="pt", tokenize=True, add_generation_prompt=True + ) + attention_mask = torch.ones_like(input_ids) + rollouts = ray.get(generation_worker.generate.remote(input_ids, attention_mask, **kwargs)) + response = self.tokenizer.batch_decode( + rollouts["input_ids"][0][:, input_ids.size(-1) :], skip_special_tokens=True + )[0] + return response + + +class LangChainCustomLLM(ToolCallingLLM, BaseChatModel): + client: CustomOpenAIAPILLM = None + + def __init__(self, client: CustomOpenAIAPILLM): + super().__init__() + self.client = client + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + # content = self.client.invoke([m.dict() for m in messages]) + # chat_result = ChatResult( + # generations=[ChatGeneration(message=AIMessage(content=content))] + # ) + print("messages:", messages) + breakpoint() + system_message, functions = self._generate_system_message_and_functions(kwargs) + sample_params = {"stop": stop} if stop is not None else {} + sample_params.update({k: v for k, v in kwargs.items() if k in ["temperature", "top_p", "top_k", "max_tokens"]}) + messages_ = copy.deepcopy(messages) + messages_[0].content = messages_[0].content + "\n" + system_message.content + response_message = self.client.invoke( # type: ignore[safe-super] + [system_message] + messages, **{"sample_params": sample_params} + ) + breakpoint() + response = self._process_response(AIMessage(content=response_message), functions) + return ChatResult(generations=[ChatGeneration(message=response)]) + + @property + def _llm_type(self) -> str: + return "custom-api-llm" + + +# ------------------------------- +# 3. Build a ReAct Agent with LangGraph +# ------------------------------- +def build_agent(): + # Wrap custom API LLM in LangChain-compatible interface + + # Init LLM + llm_client = CustomOpenAIAPILLM() + llm = LangChainCustomLLM(llm_client) + + # Tools + tools = [python] + + # Memory (optional) + memory = MemorySaver() + + # Build ReAct agent + agent = create_react_agent(llm, tools, checkpointer=memory) + return agent + + +# ------------------------------- +# 4. Run the agent on a math problem +# ------------------------------- +if __name__ == "__main__": + agent = build_agent() + + # Example math question + user_input = "What is the least common multiple of 18 and 24? Use Python if needed." + + config = {"configurable": {"thread_id": "math-1"}} + for event in agent.stream({"messages": [("user", user_input)]}, config): + if "agent" in event: + print("Agent event:", event["agent"]["messages"][-1].content) + elif "tools" in event: + print("Tool event:", event["tools"]["messages"][-1].content) + + final_state = agent.get_state(config) + print("Final Answer:", final_state["messages"][-1].content) diff --git a/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic.py b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic.py new file mode 100644 index 000000000000..273544b2a555 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic.py @@ -0,0 +1,88 @@ +from typing import Any, Dict + +import ray +from coati.distributed.agent.agentic import BaseAgenticProducer +from coati.distributed.agent.qwen_math_agentic_utils import TIR_SYSTEM, CustomTransformers +from qwen_agent.agents import TIRMathAgent + + +@ray.remote +class QwenMathAgenticProducer(BaseAgenticProducer): + """ + Asyncronous version of the producer that uses vLLM for generation. + This class is designed to generate agentic response + """ + + def __init__( + self, + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + async_producers, + tokenizer_config=None, + agentic_config=None, + microbatch_size=1, + backend="transformers", + num_generations: int = 8, + consumer_plugin_config=None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + grpo_config: Dict[str, Any] = None, + eval_save_dir: str = "./eval", + eval_generation_config={}, + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", + enable_profiling: bool = False, + n_behind: int = 0, + ): + assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer + assert batch_size == 1 # batch_size must be 1 for agentic producer + super().__init__( + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + async_producers, + tokenizer_config, + microbatch_size, + backend, + num_generations, + consumer_plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + grpo_config=grpo_config, + eval_save_dir=eval_save_dir, + eval_generation_config=eval_generation_config, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, + enable_profiling=enable_profiling, + n_behind=n_behind, + ) + self.agentic_config = model_config if not agentic_config else agentic_config + self.agentic_config.update({"model": model_config["path"]}) + self.llm = CustomTransformers(self.agentic_config, self.producer_idx, generation_workers=self.async_producers) + self.bot = TIRMathAgent(llm=self.llm, name=model_config["path"], system_message=TIR_SYSTEM) + + def _run_agentic_pipeline(self, messages): + """ + Run the agentic pipeline to generate responses based on the input messages using the TIRMathAgent. + """ + for response in self.bot.run(messages): + continue + messages.extend(response) + return messages diff --git a/applications/ColossalChat/coati/distributed/agent/agentic_math_utils.py b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py similarity index 96% rename from applications/ColossalChat/coati/distributed/agent/agentic_math_utils.py rename to applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py index eb44f8a93092..7787a99aabfe 100644 --- a/applications/ColossalChat/coati/distributed/agent/agentic_math_utils.py +++ b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py @@ -60,6 +60,11 @@ def __init__(self, generation_worker=None): self.generation_worker = generation_worker def generate(self, **kwargs): + breakpoint() + if "max_new_tokens" in kwargs: + # we use VLLM backend for generation, which uses `max_tokens` + kwargs["max_tokens"] = kwargs["max_new_tokens"] + del kwargs["max_new_tokens"] rollouts = ray.get(self.generation_worker.generate.remote(**kwargs)) return rollouts["input_ids"] @@ -131,6 +136,7 @@ def _chat_stream( response = self._chat_no_stream(messages=messages, generate_cfg=generate_cfg) # if self.producer_idx == 0: # print(response) + breakpoint() yield response diff --git a/applications/ColossalChat/coati/distributed/agent/test_api_based_agent.py b/applications/ColossalChat/coati/distributed/agent/test_api_based_agent.py deleted file mode 100644 index 5e63bb5a366c..000000000000 --- a/applications/ColossalChat/coati/distributed/agent/test_api_based_agent.py +++ /dev/null @@ -1,126 +0,0 @@ -# ------------------------------- -# 1. Define the Python tool -# ------------------------------- -import io -import sys -from typing import Dict, List - -import requests -from langchain_core.tools import tool -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import create_react_agent - - -class Capturing(list): - """Capture stdout prints inside exec()""" - - def __enter__(self): - self._stdout = sys.stdout - sys.stdout = self._stringio = io.StringIO() - return self - - def __exit__(self, *args): - self.extend(self._stringio.getvalue().splitlines()) - sys.stdout = self._stdout - - -@tool -def python(code: str) -> str: - """ - This function executes a string of Python code and returns the printed output. - You need to print the output. Please import all libraries used in the code string. - """ - local_vars = {} - with Capturing() as output: - exec(code, {}, local_vars) - if output == []: - return "Error: No output printed from the code. Please ensure you print the output." - return "\n".join(output) - - -# ------------------------------- -# 2. Define a Custom API LLM wrapper -# ------------------------------- -class CustomAPILLM: - def __init__(self, api_url: str, api_key: str = None): - self.api_url = api_url - self.api_key = api_key - - def invoke(self, messages: List[Dict[str, str]]) -> str: - """ - messages: list of {"role": "user"/"assistant"/"system", "content": "..."} - """ - headers = {"Content-Type": "application/json"} - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - payload = { - "model": "custom-model", # depends on your API - "messages": messages, - "temperature": 0, - } - - response = requests.post(self.api_url, headers=headers, json=payload) - response.raise_for_status() - data = response.json() - - # Adjust according to your API response format - return data["choices"][0]["message"]["content"] - - -# ------------------------------- -# 3. Build a ReAct Agent with LangGraph -# ------------------------------- -def build_agent(): - # Wrap custom API LLM in LangChain-compatible interface - from langchain_core.language_models import BaseChatModel - from langchain_core.messages import AIMessage - - class LangChainCustomLLM(BaseChatModel): - client: CustomAPILLM = None - - def __init__(self, client: CustomAPILLM): - super().__init__() - self.client = client - - def _generate(self, messages, stop=None, run_manager=None, **kwargs): - content = self.client.invoke([m.dict() for m in messages]) - return self._create_chat_result([AIMessage(content=content)]) - - @property - def _llm_type(self) -> str: - return "custom-api-llm" - - # Init LLM - llm_client = CustomAPILLM(api_url="http://localhost:8000/v1/chat/completions") - llm = LangChainCustomLLM(llm_client) - - # Tools - tools = [python] - - # Memory (optional) - memory = MemorySaver() - - # Build ReAct agent - agent = create_react_agent(llm, tools, checkpointer=memory) - return agent - - -# ------------------------------- -# 4. Run the agent on a math problem -# ------------------------------- -if __name__ == "__main__": - agent = build_agent() - - # Example math question - user_input = "What is the least common multiple of 18 and 24? Use Python if needed." - - config = {"configurable": {"thread_id": "math-1"}} - for event in agent.stream({"messages": [("user", user_input)]}, config): - if "agent" in event: - print("Agent event:", event["agent"]["messages"][-1].content) - elif "tools" in event: - print("Tool event:", event["tools"]["messages"][-1].content) - - final_state = agent.get_state(config) - print("Final Answer:", final_state["messages"][-1].content) diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index c35c45bddf39..0c25eac272eb 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -363,7 +363,12 @@ async def generate( response_start_idx = input_ids.size(1) first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1) input_ids_no_padding = [input_ids.tolist()[0][first_non_padding_token_idx[0] :]] - sample_params = kwargs.get("sample_params", self.sample_params) + sample_params = self.sample_params + if len(kwargs) > 0: + sample_params = self.generate_config.copy() + sample_params.update(kwargs) + sample_params.update(self.FORCE_GENERATE_CONFIG) + sample_params = SamplingParams(**sample_params) out_tokens = [] out_len = [] log_probs = [] diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 793cd932e92c..926eca4ff2e1 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -4,7 +4,8 @@ from typing import Any, Dict, Optional import ray -from coati.distributed.agent.agentic import AgenticProducer +from coati.distributed.agent.langgraph_math_agentic import LangGraphMathAgenticProducer +from coati.distributed.agent.qwen_math_agentic import QwenMathAgenticProducer from .consumer import SimpleConsumer from .grpo_consumer import GRPOConsumer @@ -18,6 +19,10 @@ "RLOO": GRPOConsumer, } Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncSimpleProducer} +AGENTIC_PRODUCER_MAP = { + "QwenMathAgent": QwenMathAgenticProducer, + "LangGraphMathAgent": LangGraphMathAgenticProducer, +} # supported agentic producers def get_jsonl_size_fast(path: str) -> int: @@ -178,8 +183,16 @@ async def test(): # when agentic is enabled, we use core_producer as inference engine and # AgenticProducer as the real producer _producer_procs = producer_procs + assert ( + "agentic_producer" in agentic_config + ), "Please specify the agentic producer through `agentic_producer` in agentic_config." + assert ( + agentic_config["agentic_producer"] in AGENTIC_PRODUCER_MAP + ), f"Only {list(AGENTIC_PRODUCER_MAP.keys())} are supported as agentic producer so far." + agentic_producer_cls = AGENTIC_PRODUCER_MAP[agentic_config["agentic_producer"]] + agentic_config.pop("agentic_producer") producer_procs = [ - AgenticProducer.options(num_cpus=1).remote( + agentic_producer_cls.options(num_cpus=1).remote( producer_idx=producer_idx, num_producers=num_producers * train_batch_size, num_consumer_procs=num_consumer_procs, diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 46c75cdba02e..bd92f64333b1 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -161,6 +161,13 @@ choices=["think_answer_tags", "boxed", "code"], help="Reward type for GRPO.", ) + parser.add_argument( + "--agentic-type", + type=str, + default="QwenMathAgent", + choices=["QwenMathAgent", "LangGraphMathAgent"], + help="Agentic model type for agentic training.", + ) parser.add_argument( "-cv", "--code-verifier-api-url", @@ -414,16 +421,22 @@ generate_config["max_tokens"] = ( 2048 # max new tokens for each agentic step, usually smaller than max_new_tokens as agentic model will generate multiple steps ) - agentic_config = { - "model": args.model, - "model_type": "transformers", - "generate_cfg": { - "max_input_tokens": args.max_new_tokens + args.max_prompt_tokens, - }, - } - agentic_config["generate_cfg"].update( - {k: v for k, v in generate_config.items() if k in ["top_k", "top_p", "temperature"]} - ) + if args.agentic_type == "QwenMathAgent": + agentic_config = { + "agentic_producer": "QwenMathAgent", + "model": args.model, + "model_type": "transformers", + "generate_cfg": { + "max_input_tokens": args.max_new_tokens + args.max_prompt_tokens, + }, + } + agentic_config["generate_cfg"].update( + {k: v for k, v in generate_config.items() if k in ["top_k", "top_p", "temperature"]} + ) + elif args.agentic_type == "LangGraphMathAgent": + agentic_config = {"configurable": {"thread_id": "math-1"}, "agentic_producer": "LangGraphMathAgent"} + else: + raise ValueError(f"Unsupported agentic model type: {args.agentic_type}") else: agentic_config = None From edcef9edaff712882e8f5dc152cda5121bce62cd Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Tue, 16 Sep 2025 16:23:46 +0800 Subject: [PATCH 04/11] add custom agentic producer --- .../distributed/agent/agentic_producer.py | 287 ++++++++++++++++++ .../distributed/agent/{agentic.py => base.py} | 40 ++- .../agent/langgraph_math_agentic.py | 122 -------- .../agent/langgraph_math_agentic_utils.py | 185 ----------- .../coati/distributed/agent/math_tools.py | 31 ++ ...entic.py => qwen_math_agentic_producer.py} | 4 +- .../agent/qwen_math_agentic_utils.py | 12 +- .../coati/distributed/agent/tool_worker.py | 77 +++++ .../coati/distributed/inference_backend.py | 31 +- .../ColossalChat/coati/distributed/launch.py | 31 +- .../coati/distributed/producer.py | 22 +- applications/ColossalChat/rl_example.py | 18 +- 12 files changed, 482 insertions(+), 378 deletions(-) create mode 100644 applications/ColossalChat/coati/distributed/agent/agentic_producer.py rename applications/ColossalChat/coati/distributed/agent/{agentic.py => base.py} (84%) delete mode 100644 applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic.py delete mode 100644 applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic_utils.py create mode 100644 applications/ColossalChat/coati/distributed/agent/math_tools.py rename applications/ColossalChat/coati/distributed/agent/{qwen_math_agentic.py => qwen_math_agentic_producer.py} (96%) create mode 100644 applications/ColossalChat/coati/distributed/agent/tool_worker.py diff --git a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py new file mode 100644 index 000000000000..234734c41527 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py @@ -0,0 +1,287 @@ +import copy +import random +import re +from typing import Any, Dict +from uuid import uuid4 + +import ray +from coati.distributed.agent.base import BaseAgenticProducer +from transformers import AutoTokenizer + +DEFAULT_SYSTEM_MESSAGE = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here .""" + + +@ray.remote +class AgenticProducer(BaseAgenticProducer): + """ + Asyncronous version of the producer that uses vLLM for generation. + This class is designed to generate agentic response + + Please use the following SYSTEM message or a similar one for the agentic math model: + '''A conversation between User and Assistant. The user asks a question, and the Assistant solves it. + The Assistant first thinks about the reasoning process in the mind and then provides the user with + the answer. The reasoning process and answer are enclosed within and + tags, respectively, i.e., reasoning process here answer here .''' + """ + + def __init__( + self, + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + async_producers, + tool_workers=[], + tokenizer_config=None, + agentic_config=None, + microbatch_size=1, + backend="transformers", + num_generations: int = 8, + consumer_plugin_config=None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + grpo_config: Dict[str, Any] = None, + eval_save_dir: str = "./eval", + eval_generation_config={}, + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", + enable_profiling: bool = False, + n_behind: int = 0, + ): + assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer + assert batch_size == 1 # batch_size must be 1 for agentic producer + super().__init__( + producer_idx, + num_producers, + num_consumer_procs, + num_episodes, + batch_size, + train_dataset_config, + model_config, + generate_config, + async_producers, + tokenizer_config, + microbatch_size, + backend, + num_generations, + consumer_plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + grpo_config=grpo_config, + eval_save_dir=eval_save_dir, + eval_generation_config=eval_generation_config, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, + enable_profiling=enable_profiling, + n_behind=n_behind, + ) + self.tool_workers = tool_workers + self.agentic_config = model_config if not agentic_config else agentic_config + self.agentic_config.update({"model": model_config["path"]}) + tokenizer_path = None + if tokenizer_config and "path" in tokenizer_config: + tokenizer_path = tokenizer_config["path"] + elif "path" in model_config: + tokenizer_path = model_config["path"] + assert tokenizer_path is not None, "Tokenizer path must be provided either in tokenizer_config or model_config." + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) + self.tools_schema = [] + self.tool_call_budget = self.agentic_config.get("tool_call_budget", 3) + self.llm_call_budget = self.agentic_config.get("llm_call_budget", 10) + self.async_llm_engine_map = {} + self._get_tools() + + def _get_tools(self): + """ + SYSTEM message for the agentic math model. Reference: r-start2 paper https://arxiv.org/pdf/2508.20722 + """ + tools = ray.get(self.tool_workers[0].list_tools.remote()) + tool_descriptions = {tool: ray.get(self.tool_workers[0].get_tool_description.remote(tool)) for tool in tools} + tool_arg_schemas = {tool: ray.get(self.tool_workers[0].get_args_schema.remote(tool)) for tool in tools} + self.tools = [] + for tool in tools: + tool_schema = {"name": tool, "description": tool_descriptions[tool], "parameters": tool_arg_schemas[tool]} + self.tools.append(tool_schema) + + def _build_prompt( + self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt" + ) -> dict: + """ + Build the prompt for the agentic math model. + """ + return self.tokenizer.apply_chat_template( + messages, + tools=self.tools, + add_generation_prompt=add_generation_prompt, + return_dict=return_dict, + return_tensors=return_tensors, + ) + + def _parse_response(self, response: str) -> Dict[str, Any]: + """ + Parse the response from the agentic math model. + + Sample Assistant Response: + The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}<|im_end|> + + Sample Assistant Response with Tool Call: + To answer this, I will check both the weather and the timezone for New York.\n\n{"name": "get_weather", "arguments": {"location": "New York"}}\n\n\n{"name": "get_timezone", "arguments": {"location": "New York"}}\n + + Sample Ouput: + { + "role": "assistant", + "content": "Let me check the current weather in Singapore by calling the weather tool.", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": { + "location": "New York" + } + } + }, + { + "function": { + "name": "get_timezone", + "arguments": { + "location": "New York" + } + } + } + ] + }, + { + "role": "assistant", + "content": "The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}" + } + """ + # split by + response_chunked = response.split("<|im_end|>")[0].strip() + if "" in response_chunked: + assistant_content = response_chunked.split("")[0].strip() + tool_call_sections = response_chunked[response_chunked.find("") :].strip() + # extract all tool calls + tool_calls = [] + pattern = "(.*?)" + matches = re.findall(pattern, tool_call_sections, re.DOTALL) + for match in matches: + try: + tool_call = eval(match.strip()) + name = tool_call["name"] + arguments = tool_call["arguments"] + tool_calls.append({"function": {"name": name, "arguments": arguments}}) + except Exception as e: + print(f"Failed to parse tool call: {match.strip()}. Error: {e}") + tool_calls.append({"function": {"name": "return_parsing_error", "arguments": {}}}) + else: + assistant_content = response_chunked + tool_calls = [] + assistant_message = {"role": "assistant", "content": assistant_content} + if tool_calls: + assistant_message["tool_calls"] = tool_calls + return assistant_message + + def _select_tool_worker(self) -> ray.actor.ActorHandle: + """ + Select a tool worker based on the current load. + """ + loads = ray.get([worker.get_load.remote() for worker in self.tool_workers]) + min_load = min(loads) + candidates = [i for i, l in enumerate(loads) if l == min_load] + selected_idx = random.choice(candidates) # random tie break + ray.get(self.tool_workers[selected_idx].increase_load.remote()) + return self.tool_workers[selected_idx] + + def _select_async_producer(self, request_id) -> ray.actor.ActorHandle: + """ + Select an async producer based on the current load. + """ + # use the last used async producer if exists to reuse kv cache (as vllm use paged kv cache, + # it will reuse most of the kv cache pages without recomputation) + if request_id in self.async_llm_engine_map: + return self.async_producers[self.async_llm_engine_map[request_id]] + # otherwise select the least loaded async producer + loads = ray.get([proc.get_producer_load.remote() for proc in self.async_producers]) + min_load = min(loads) + candidates = [i for i, l in enumerate(loads) if l == min_load] + selected_idx = random.choice(candidates) # random tie break + self.async_llm_engine_map[request_id] = selected_idx + return self.async_producers[selected_idx] + + def _run_agentic_pipeline(self, messages): + """ + Run the agentic pipeline to generate responses based on the input messages. + """ + tool_call_count = 0 + llm_call_count = 0 + num_prompt_tokens = 0 + request_id = str(uuid4()) + logprobs = None + while True: + # tokenize the messages + if llm_call_count > self.llm_call_budget: + print(f"LLM call budget exceeded: {llm_call_count} > {self.llm_call_budget}. Stopping.") + del self.async_llm_engine_map[request_id] + while messages[-1]["role"] == "tool": + messages.pop() + return messages, logprobs + inputs = self._build_prompt(messages, return_dict=True, return_tensors="pt") + if num_prompt_tokens == 0: + num_prompt_tokens = inputs["input_ids"].size(-1) + if inputs["input_ids"].size(-1) - num_prompt_tokens > self.generate_config["max_tokens"]: + print( + f"Max tokens exceeded: Current have generated {inputs['input_ids'].size(-1) - num_prompt_tokens} tokens > {self.generate_config.get('max_tokens', 512)}. Stopping." + ) + del self.async_llm_engine_map[request_id] + while messages[-1]["role"] == "tool": + messages.pop() + return messages, logprobs + async_producer = self._select_async_producer(request_id=request_id) + agentic_generate_config = copy.deepcopy(self.generate_config) + agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048) + response = ray.get( + async_producer.generate.remote( + inputs["input_ids"], + inputs["attention_mask"], + request_id=request_id, + **agentic_generate_config, + ) + ) + llm_call_count += 1 + response_input_ids = response["input_ids"] + logprobs = response["action_log_probs"] + response_text = self.tokenizer.decode( + response_input_ids[0][0][inputs["input_ids"].size(-1) :], skip_special_tokens=False + ) + assistant_message = self._parse_response(response_text) + messages.append(assistant_message) + if "tool_calls" in assistant_message: + if tool_call_count > self.tool_call_budget: + print(f"Tool call budget exceeded: {tool_call_count} > {self.tool_call_budget}. Stopping.") + del self.async_llm_engine_map[request_id] + return messages, logprobs + tool_call_count += len(assistant_message["tool_calls"]) + handlers = [] + for tool_call in assistant_message["tool_calls"]: + # select a tool worker to execute the tool call + tool_worker = self._select_tool_worker() + handler = tool_worker.call.remote(tool_call["function"]["name"], tool_call["function"]["arguments"]) + handlers.append(handler) + tool_results = ray.get(handlers) + for tool_call, tool_result in zip(assistant_message["tool_calls"], tool_results): + tool_message = {"role": "tool", "content": str(tool_result)} + messages.append(tool_message) + else: + # no further tool call, return the messages + del self.async_llm_engine_map[request_id] + return messages, logprobs diff --git a/applications/ColossalChat/coati/distributed/agent/agentic.py b/applications/ColossalChat/coati/distributed/agent/base.py similarity index 84% rename from applications/ColossalChat/coati/distributed/agent/agentic.py rename to applications/ColossalChat/coati/distributed/agent/base.py index 12de6164a832..b290e9e44d70 100644 --- a/applications/ColossalChat/coati/distributed/agent/agentic.py +++ b/applications/ColossalChat/coati/distributed/agent/base.py @@ -1,5 +1,6 @@ import copy import json +from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict import ray @@ -86,6 +87,15 @@ def _run_agentic_pipeline(self, messages): """ raise NotImplementedError + def _build_prompt( + self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt" + ) -> dict: + """ + Build the prompt from the input messages. + This function should be implemented in subclasses. + """ + raise NotImplementedError + def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: """ Rollout function to generate responses for the input, for example, using LLM or agentic pipeline. @@ -93,9 +103,9 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: """ assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer" messages = kwargs["messages"][0] - prompt_input_ids = self.tokenizer.apply_chat_template( - messages, return_tensors="pt", tokenize=True, add_generation_prompt=True - ) + prompt_input_ids = self._build_prompt( + messages, return_dict=True, return_tensors="pt", add_generation_prompt=True + )["input_ids"] # add left padding prompt_length = prompt_input_ids.shape[1] max_prompt_length = self.train_dataset_config["max_length"] @@ -107,10 +117,16 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: "action_log_probs": [], "response_idx": [], } + with ThreadPoolExecutor(max_workers=self.num_generations) as executor: + results = list( + executor.map(self._run_agentic_pipeline, [copy.deepcopy(messages) for _ in range(self.num_generations)]) + ) + for i in range(self.num_generations): - _messages = copy.deepcopy(messages) - _messages = self._run_agentic_pipeline(_messages) - response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True) + _messages, logprobs = results[i] + response_input_ids = self._build_prompt( + _messages, return_dict=True, return_tensors="pt", add_generation_prompt=False + )["input_ids"] # truncate if too long response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left] # add left right padding @@ -127,9 +143,14 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: ) # [1, max_length-prompt_length] rollouts["attention_mask"].append(attention_mask) rollouts["action_mask"].append(action_mask) - rollouts["action_log_probs"].append( - torch.ones(size=(1, self.grpo_config["max_length"] - max_prompt_length)) - ) # dummy log probs + truncated_logprobs = logprobs[:, :, prompt_length : prompt_length + self.generate_config["max_tokens"]] + logprobs_padded = torch.nn.functional.pad( + truncated_logprobs, + (0, self.generate_config["max_tokens"] - truncated_logprobs.size(-1)), + "constant", + value=0.0, + ) # [1, max_new_tokens] + rollouts["action_log_probs"].append(logprobs_padded[0]) rollouts["response_idx"].append( torch.tensor( [ @@ -141,7 +162,6 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: ) ) # [1, 2] rollouts["input_ids"].append(input_ids) - # breakpoint() rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()} # [num_generations, ...] rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)]) if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode: diff --git a/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic.py b/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic.py deleted file mode 100644 index 90a8d3fbd98b..000000000000 --- a/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import Any, Dict - -import ray -from coati.distributed.agent.agentic import BaseAgenticProducer -from coati.distributed.agent.langgraph_math_agentic_utils import CustomOpenAIAPILLM, LangChainCustomLLM, python -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import create_react_agent - - -@ray.remote -class LangGraphMathAgenticProducer(BaseAgenticProducer): - """ - Asyncronous version of the producer that uses vLLM for generation. - This class is designed to generate agentic response - """ - - def __init__( - self, - producer_idx, - num_producers, - num_consumer_procs, - num_episodes, - batch_size, - train_dataset_config, - model_config, - generate_config, - async_producers, - tokenizer_config=None, - agentic_config=None, - microbatch_size=1, - backend="transformers", - num_generations: int = 8, - consumer_plugin_config=None, - eval_dataset_config=None, - eval_interval=-1, # disable evaluation - grpo_config: Dict[str, Any] = None, - eval_save_dir: str = "./eval", - eval_generation_config={}, - project_name: str = None, - run_name: str = None, - wandb_group_name: str = None, - log_rollout_interval: int = 20, - rollout_log_file: str = "./rollout_log.jsonl", - enable_profiling: bool = False, - n_behind: int = 0, - ): - assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer - assert batch_size == 1 # batch_size must be 1 for agentic producer - super().__init__( - producer_idx, - num_producers, - num_consumer_procs, - num_episodes, - batch_size, - train_dataset_config, - model_config, - generate_config, - async_producers, - tokenizer_config, - microbatch_size, - backend, - num_generations, - consumer_plugin_config, - eval_dataset_config=eval_dataset_config, - eval_interval=eval_interval, - grpo_config=grpo_config, - eval_save_dir=eval_save_dir, - eval_generation_config=eval_generation_config, - project_name=project_name, - run_name=run_name, - wandb_group_name=wandb_group_name, - log_rollout_interval=log_rollout_interval, - rollout_log_file=rollout_log_file, - enable_profiling=enable_profiling, - n_behind=n_behind, - ) - self.agentic_config = agentic_config - self.agentic_config.pop("agentic_type", None) - self.llm_client = CustomOpenAIAPILLM({"model": model_config["path"]}, producer_idx, self.async_producers) - self.llm = LangChainCustomLLM(self.llm_client) - # self.python_repl = PythonREPL() - # repl_tool = Tool( - # name="python_repl", - # description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.", - # func=self.python_repl.run, - # ) - # self.tools = [repl_tool] - self.tools = [python] - self.memory = MemorySaver() - self.bot = create_react_agent(self.llm, self.tools, checkpointer=self.memory) - - def _run_agentic_pipeline(self, messages): - """ - Run the agentic pipeline to generate responses based on the input messages using the LangGraph. - """ - assert ( - len(messages) == 2 and messages[0]["role"] == "system" and messages[1]["role"] == "user" - ), "Only support 1 system message and 1 user message as input." - # inputs = messages - for event in self.bot.stream( - {"messages": [("system", messages[0]["content"]), ("user", "calculate the 1000th Fibonacci number")]}, - self.agentic_config, - ): - continue - breakpoint() - - final_state = self.bot.get_state(self.agentic_config) - transformer_messages = [] - for message in final_state[0]["messages"]: - tool_calls = None - if isinstance(message, SystemMessage): - message.content - elif isinstance(message, HumanMessage): - message.content - elif isinstance(message, AIMessage): - message.content - tool_calls = message.get("tool_calls", None) # [{"type": "function", "function": tool_call}] - elif isinstance(message, ToolMessage): - message.content - - return transformer_messages diff --git a/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic_utils.py b/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic_utils.py deleted file mode 100644 index cb587e0d6849..000000000000 --- a/applications/ColossalChat/coati/distributed/agent/langgraph_math_agentic_utils.py +++ /dev/null @@ -1,185 +0,0 @@ -# ------------------------------- -# 1. Define the Python tool -# ------------------------------- -import copy -import io -import random -import sys -from typing import Dict, List - -import ray -import torch -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage -from langchain_core.outputs.chat_generation import ChatGeneration -from langchain_core.outputs.chat_result import ChatResult -from langchain_core.prompts import PromptTemplate -from langchain_core.tools import tool -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import create_react_agent -from tool_calling_llm import ToolCallingLLM -from transformers import AutoTokenizer - -SYSTEM_PROMPT_TEMPLATE = """{task_description}. You have access to the following tools: - -{tools} - -Use the following format: - -Question: the input question you must answer -Thought: you should always think about what to do -Action: the action to take, should be one of [{tool_names}] -Action Input: the input to the action -Observation: the result of the action -... (this Thought/Action/Action Input/Observation can repeat N times) -Thought: I now know the final answer -Final Answer: the final answer to the original input question - -Begin! - -Question: {input} -Thought:{agent_scratchpad}""" - -SYSTEM_PROMPT = PromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE) - - -class Capturing(list): - """Capture stdout prints inside exec()""" - - def __enter__(self): - self._stdout = sys.stdout - sys.stdout = self._stringio = io.StringIO() - return self - - def __exit__(self, *args): - self.extend(self._stringio.getvalue().splitlines()) - sys.stdout = self._stdout - - -@tool -def python(code: str) -> str: - """ - This function executes a string of Python code and returns the printed output. - You need to print the output. Please import all libraries used in the code string. - """ - local_vars = {} - with Capturing() as output: - exec(code, {}, local_vars) - if output == []: - return "Error: No output printed from the code. Please ensure you print the output." - return "\n".join(output) - - -# ------------------------------- -# 2. Define a Custom API LLM wrapper -# ------------------------------- -class CustomOpenAIAPILLM: - def __init__(self, cfg: dict, producer_idx, generation_workers=None): - self.producer_idx = producer_idx - self.generation_workers = generation_workers - self.load_balancer_idx = producer_idx % len(self.generation_workers) - assert "model" in cfg, "Please specify the model name in the config" - self.tokenizer = AutoTokenizer.from_pretrained(cfg["model"]) - self.role_mapping = { - "system": "system", - "user": "user", - "assistant": "assistant", - "human": "user", - "tool": "tool", - } - - def invoke(self, messages: List[Dict[str, str]], **kwargs) -> str: - """ - messages: list of {"role": "user"/"assistant"/"system", "content": "..."} - """ - # load balancing - load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers] - min_load = min(load) - candidates = [i for i, l in enumerate(load) if l == min_load] - # random tie break - self.load_balancer_idx = random.choice(candidates) - generation_worker = self.generation_workers[self.load_balancer_idx] - transformer_messages = [] - for message in messages: - transformer_messages.append({"role": self.role_mapping[message.type], "content": message.content}) - input_ids = self.tokenizer.apply_chat_template( - transformer_messages, return_tensors="pt", tokenize=True, add_generation_prompt=True - ) - attention_mask = torch.ones_like(input_ids) - rollouts = ray.get(generation_worker.generate.remote(input_ids, attention_mask, **kwargs)) - response = self.tokenizer.batch_decode( - rollouts["input_ids"][0][:, input_ids.size(-1) :], skip_special_tokens=True - )[0] - return response - - -class LangChainCustomLLM(ToolCallingLLM, BaseChatModel): - client: CustomOpenAIAPILLM = None - - def __init__(self, client: CustomOpenAIAPILLM): - super().__init__() - self.client = client - - def _generate(self, messages, stop=None, run_manager=None, **kwargs): - # content = self.client.invoke([m.dict() for m in messages]) - # chat_result = ChatResult( - # generations=[ChatGeneration(message=AIMessage(content=content))] - # ) - print("messages:", messages) - breakpoint() - system_message, functions = self._generate_system_message_and_functions(kwargs) - sample_params = {"stop": stop} if stop is not None else {} - sample_params.update({k: v for k, v in kwargs.items() if k in ["temperature", "top_p", "top_k", "max_tokens"]}) - messages_ = copy.deepcopy(messages) - messages_[0].content = messages_[0].content + "\n" + system_message.content - response_message = self.client.invoke( # type: ignore[safe-super] - [system_message] + messages, **{"sample_params": sample_params} - ) - breakpoint() - response = self._process_response(AIMessage(content=response_message), functions) - return ChatResult(generations=[ChatGeneration(message=response)]) - - @property - def _llm_type(self) -> str: - return "custom-api-llm" - - -# ------------------------------- -# 3. Build a ReAct Agent with LangGraph -# ------------------------------- -def build_agent(): - # Wrap custom API LLM in LangChain-compatible interface - - # Init LLM - llm_client = CustomOpenAIAPILLM() - llm = LangChainCustomLLM(llm_client) - - # Tools - tools = [python] - - # Memory (optional) - memory = MemorySaver() - - # Build ReAct agent - agent = create_react_agent(llm, tools, checkpointer=memory) - return agent - - -# ------------------------------- -# 4. Run the agent on a math problem -# ------------------------------- -if __name__ == "__main__": - agent = build_agent() - - # Example math question - user_input = "What is the least common multiple of 18 and 24? Use Python if needed." - - config = {"configurable": {"thread_id": "math-1"}} - for event in agent.stream({"messages": [("user", user_input)]}, config): - if "agent" in event: - print("Agent event:", event["agent"]["messages"][-1].content) - elif "tools" in event: - print("Tool event:", event["tools"]["messages"][-1].content) - - final_state = agent.get_state(config) - print("Final Answer:", final_state["messages"][-1].content) diff --git a/applications/ColossalChat/coati/distributed/agent/math_tools.py b/applications/ColossalChat/coati/distributed/agent/math_tools.py new file mode 100644 index 000000000000..c1a5bfb436ad --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/math_tools.py @@ -0,0 +1,31 @@ +from langchain_core.tools import Tool +from langchain_experimental.utilities import PythonREPL +from pydantic import BaseModel, Field +from pydantic.fields import FieldInfo + + +def make_title(field_name: str, field_info: FieldInfo) -> str: + return field_name + + +class PythonInput(BaseModel): + code: str = Field(description="The python code to execute", field_title_generator=make_title) + + +python_repl = PythonREPL() + + +def run_python_code(code: str) -> str: + if code.startswith("```python"): + code = code.replace("```python", "```", 1).strip() + if code.startswith("```py"): # qwen3 uses ```py + code = code.replace("```py", "```", 1).strip() + return python_repl.run(code, timeout=20) + + +repl_tool = Tool( + name="python_repl", + description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.", + func=run_python_code, + args_schema=PythonInput, +) diff --git a/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic.py b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_producer.py similarity index 96% rename from applications/ColossalChat/coati/distributed/agent/qwen_math_agentic.py rename to applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_producer.py index 273544b2a555..8631da9c2c13 100644 --- a/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic.py +++ b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_producer.py @@ -1,7 +1,7 @@ from typing import Any, Dict import ray -from coati.distributed.agent.agentic import BaseAgenticProducer +from coati.distributed.agent.base import BaseAgenticProducer from coati.distributed.agent.qwen_math_agentic_utils import TIR_SYSTEM, CustomTransformers from qwen_agent.agents import TIRMathAgent @@ -24,6 +24,7 @@ def __init__( model_config, generate_config, async_producers, + tool_workers=[], tokenizer_config=None, agentic_config=None, microbatch_size=1, @@ -85,4 +86,5 @@ def _run_agentic_pipeline(self, messages): for response in self.bot.run(messages): continue messages.extend(response) + # breakpoint() return messages diff --git a/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py index 7787a99aabfe..d8a97ccf619b 100644 --- a/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py +++ b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py @@ -55,17 +55,18 @@ class LocalLLMFromGenerationWorkers: A class that wraps the Transformers model to support API-based text generation. """ - def __init__(self, generation_worker=None): + def __init__(self, generation_worker=None, tokenizer=None): self.device = "cpu" self.generation_worker = generation_worker + self.tokenizer = tokenizer def generate(self, **kwargs): - breakpoint() if "max_new_tokens" in kwargs: # we use VLLM backend for generation, which uses `max_tokens` kwargs["max_tokens"] = kwargs["max_new_tokens"] del kwargs["max_new_tokens"] rollouts = ray.get(self.generation_worker.generate.remote(**kwargs)) + # breakpoint() return rollouts["input_ids"] @@ -108,7 +109,7 @@ def __init__(self, cfg: dict, producer_idx, generation_workers=None): ################################################################ self.generation_workers = generation_workers self.hf_models = [ - LocalLLMFromGenerationWorkers(generation_worker=generation_worker) + LocalLLMFromGenerationWorkers(generation_worker=generation_worker, tokenizer=self.tokenizer) for generation_worker in generation_workers ] self.producer_idx = producer_idx @@ -133,10 +134,9 @@ def _chat_stream( candidates = [i for i, l in enumerate(load) if l == min_load] # random tie break self.load_balancer_idx = random.choice(candidates) + # breakpoint() response = self._chat_no_stream(messages=messages, generate_cfg=generate_cfg) - # if self.producer_idx == 0: - # print(response) - breakpoint() + # breakpoint() yield response diff --git a/applications/ColossalChat/coati/distributed/agent/tool_worker.py b/applications/ColossalChat/coati/distributed/agent/tool_worker.py new file mode 100644 index 000000000000..ae148af864c0 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/agent/tool_worker.py @@ -0,0 +1,77 @@ +from typing import Any, Dict, List, Optional, Union + +import ray +from langchain.tools import BaseTool + + +@ray.remote(concurrency_groups={"io": 1, "compute": 5}) +class ToolWorker: + """ + A unified wrapper class for LangChain tools, enabling a standard + interface to call tools regardless of their internal differences. + """ + + def __init__(self, tools: List[BaseTool]): + """ + Initialize ToolWorker with a list of LangChain tools. + + Args: + tools (List[BaseTool]): List of LangChain tools to register. + """ + self._tool_registry: Dict[str, BaseTool] = {tool.name: tool for tool in tools} + self.pending = 0 + + @ray.method(concurrency_group="io") + def get_load(self) -> int: + """Return the current load of the worker.""" + return self.pending + + @ray.method(concurrency_group="io") + def increase_load(self): + """Increase the load counter.""" + self.pending += 1 + + @ray.method(concurrency_group="io") + def list_tools(self) -> List[str]: + """Return the list of available tool names.""" + return list(self._tool_registry.keys()) + + @ray.method(concurrency_group="io") + def get_tool_description(self, tool_name: str) -> Optional[str]: + """Return the description of a specific tool.""" + tool = self._tool_registry.get(tool_name) + return tool.description if tool else None + + @ray.method(concurrency_group="io") + def get_args_schema(self, tool_name: str): + """Return the argument schema of a specific tool.""" + assert tool_name in self._tool_registry, f"Tool '{tool_name}' not found. Available: {self.list_tools()}" + tool = self._tool_registry.get(tool_name) + schema = tool.args_schema.model_json_schema(by_alias=False) + return schema + + @ray.method(concurrency_group="compute") + def call(self, tool_name: str, input_data: Union[str, Dict[str, Any]], **kwargs) -> Any: + """ + Call a tool by name with input data. + + Args: + tool_name (str): Name of the tool to call. + input_data (Union[str, Dict[str, Any]]): Input to pass to the tool. + **kwargs: Extra keyword arguments for the tool. + + Returns: + Any: The tool's output. + """ + if tool_name == "return_parsing_error": + self.pending -= 1 + return "Error: Tool call parsing error. Please use the correct JSON format." + if tool_name not in self._tool_registry: + return f"Error: Tool {tool_name} not found. Available tools: {self.list_tools()}" + tool = self._tool_registry[tool_name] + try: + ret = tool.run(input_data, **kwargs) + except Exception as e: + ret = f"Error: Tool {tool_name} execution failed with error: {str(e)}" + self.pending -= 1 + return ret diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 0c25eac272eb..d01edd042037 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -344,7 +344,7 @@ def __init__( self.model_config = model_config self.tokenizer = tokenizer self.num_generations = num_generations - self.queued_requests = [] + self.running_requests = [] self.microbatch_size = microbatch_size self.profiler = profiler @@ -358,8 +358,10 @@ async def generate( input_ids (torch.Tensor): shape [B, S], B=1 attention_mask (torch.Tensor): shape [B, S] """ - # breakpoint() assert input_ids.size(0) == attention_mask.size(0) == 1 + request_id = ( + str(uuid4()) if not "request_id" in kwargs else kwargs.pop("request_id") + ) # use fixed request_id to reuse kv cache response_start_idx = input_ids.size(1) first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1) input_ids_no_padding = [input_ids.tolist()[0][first_non_padding_token_idx[0] :]] @@ -373,10 +375,10 @@ async def generate( out_len = [] log_probs = [] response_idx = [] - while len(self.queued_requests) >= self.microbatch_size: + while len(self.running_requests) >= self.microbatch_size: + # print(f"Current running {len(self.running_requests)}/{self.microbatch_size} requests, waiting...") await asyncio.sleep(0.1) - request_id = str(uuid4()) - self.queued_requests.append(request_id) # enqueue + self.running_requests.append(request_id) # enqueue # pop the first input_ids and attention_mask prompt_token_ids = input_ids_no_padding[0] self.profiler.enter(f"vllm generate {request_id}") @@ -386,14 +388,25 @@ async def generate( async for chunk in outputs: # generate the output tokens, can yield to avoid blocking pass - self.queued_requests.remove(request_id) # dequeue - for output_i in chunk.outputs: + self.running_requests.remove(request_id) # dequeue + if self.generate_config.get("prompt_logprobs", None) is not None: + # when prompt_logprobs is not None, vllm will return logprobs for the whole sequence + # for agentic producer, we return the logprobs of the whole sequence + log_probs = [ + [m[t].logprob if m is not None else 0.0 for m, t in zip(chunk.prompt_logprobs, chunk.prompt_token_ids)] + ] + for _ in range(sample_params.n - 1): + log_probs.append([t for t in log_probs[0]]) # repeat the same logprobs for num_generations times + else: + log_probs = [[] for _ in range(sample_params.n)] + + for generation_id, output_i in enumerate(chunk.outputs): out_len.append(len(output_i.token_ids)) out_tokens.append(list(output_i.token_ids)) response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1)) assert len(output_i.logprobs) == len(output_i.token_ids) p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)] - log_probs.append(p) + log_probs[generation_id].extend(p) self.profiler.exit(f"vllm generate {request_id}") # pad them max_len = self.sample_params.max_tokens @@ -402,7 +415,7 @@ async def generate( for i, new_token_ids in enumerate(out_tokens): pad_len = max_len - out_len[i] out_tokens[i] = new_token_ids + [self.tokenizer.pad_token_id] * pad_len - log_probs[i] = log_probs[i] + [0.0] * pad_len + log_probs[i] = log_probs[i] + [0.0] * (max_len - len(log_probs[i])) action_mask[i, out_len[i] :] = 0 out_tokens = torch.tensor(out_tokens) diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 926eca4ff2e1..a22dd7856b49 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -4,8 +4,9 @@ from typing import Any, Dict, Optional import ray -from coati.distributed.agent.langgraph_math_agentic import LangGraphMathAgenticProducer -from coati.distributed.agent.qwen_math_agentic import QwenMathAgenticProducer +from coati.distributed.agent.agentic_producer import AgenticProducer +from coati.distributed.agent.qwen_math_agentic_producer import QwenMathAgenticProducer +from coati.distributed.agent.tool_worker import ToolWorker from .consumer import SimpleConsumer from .grpo_consumer import GRPOConsumer @@ -21,7 +22,7 @@ Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncSimpleProducer} AGENTIC_PRODUCER_MAP = { "QwenMathAgent": QwenMathAgenticProducer, - "LangGraphMathAgent": LangGraphMathAgenticProducer, + "Agentic": AgenticProducer, } # supported agentic producers @@ -165,21 +166,16 @@ def launch_distributed( ) producer_procs.append(producer) ray.get([p.setup.remote() for p in producer_procs]) - """ - # test async generate - import torch - import asyncio - import time - async def test(): - res_ref = producer_procs[0].generate.remote(torch.ones((2, 10), dtype=torch.int), torch.ones((2, 10), dtype=torch.int)) - res = await res_ref - return res - res = asyncio.run(test()) - print(res) - time.sleep(1000) - """ if enable_agentic: + from coati.distributed.agent.math_tools import repl_tool + + # setup tool workers + tool_workers = [] + if agentic_config["agentic_producer"] == "Agentic": + # 10 tool workers can handle 50 queries simultaneously + # note that imported repl_tool will be serialized and deserialized in each tool worker, therefore all workers can run parallely + tool_workers = [ToolWorker.remote([repl_tool]) for _ in range(agentic_config.get("num_tool_workers", 10))] # when agentic is enabled, we use core_producer as inference engine and # AgenticProducer as the real producer _producer_procs = producer_procs @@ -194,7 +190,7 @@ async def test(): producer_procs = [ agentic_producer_cls.options(num_cpus=1).remote( producer_idx=producer_idx, - num_producers=num_producers * train_batch_size, + num_producers=num_producers * inference_batch_size, num_consumer_procs=num_consumer_procs, num_episodes=num_episodes, batch_size=1, # batch_size must be 1 for agentic producer @@ -202,6 +198,7 @@ async def test(): model_config=inference_model_config, generate_config=generate_config, async_producers=_producer_procs, + tool_workers=tool_workers, tokenizer_config=tokenizer_config, agentic_config=agentic_config, microbatch_size=1, # microbatch_size must be 1 for agentic producer diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 44a6214ff6b0..d5d7cba1d6f3 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -587,26 +587,6 @@ def __init__( self.condition = asyncio.Condition() self.data_ready_for_sending = [] - # @torch.no_grad() - # async def generate(self, input_ids, attention_mask, **kwargs): - # tasks = [] - # print("input_ids:", input_ids) - # for prompt_id in range(input_ids.size(0)): - # new_kwargs = copy.deepcopy(kwargs) - # if "gt_answer" in new_kwargs: - # new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id] - # if "test_cases" in new_kwargs: - # new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id] - # tasks.append( - # self.model.generate( - # input_ids[prompt_id].unsqueeze(0), - # attention_mask[prompt_id].unsqueeze(0), - # **new_kwargs, - # ) - # ) - # rollouts = await asyncio.gather(*tasks) - # return rollouts - @torch.no_grad() async def generate(self, input_ids, attention_mask, **kwargs): # naive rollout strategy @@ -647,7 +627,7 @@ async def get_producer_load(self): """ Get the load of each producer. """ - return len(self.model.queued_requests) + return len(self.model.running_requests) async def async_sync_model(self, episode, step, num_processes: int = 1) -> None: """ diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index bd92f64333b1..bab6f14b6a94 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -164,8 +164,8 @@ parser.add_argument( "--agentic-type", type=str, - default="QwenMathAgent", - choices=["QwenMathAgent", "LangGraphMathAgent"], + default="Agentic", + choices=["Agentic", "QwenMathAgent"], help="Agentic model type for agentic training.", ) parser.add_argument( @@ -418,9 +418,6 @@ if "agentic" in args.backend: assert "vllm" in args.backend, "Agentic backend only supports async-agentic-vllm backends." generate_config["n"] = 1 # agentic producer use AsyncProducer which processes one request a time - generate_config["max_tokens"] = ( - 2048 # max new tokens for each agentic step, usually smaller than max_new_tokens as agentic model will generate multiple steps - ) if args.agentic_type == "QwenMathAgent": agentic_config = { "agentic_producer": "QwenMathAgent", @@ -433,8 +430,15 @@ agentic_config["generate_cfg"].update( {k: v for k, v in generate_config.items() if k in ["top_k", "top_p", "temperature"]} ) - elif args.agentic_type == "LangGraphMathAgent": - agentic_config = {"configurable": {"thread_id": "math-1"}, "agentic_producer": "LangGraphMathAgent"} + elif args.agentic_type == "Agentic": + generate_config["stop"] = ["<|im_end|>"] + generate_config["prompt_logprobs"] = 0 + agentic_config = { + "agentic_producer": "Agentic", + "tool_call_budget": 5, + "llm_call_budget": 10, + "max_tokens": 2048, + } else: raise ValueError(f"Unsupported agentic model type: {args.agentic_type}") else: From b6391bd720953d126469c6ed93898238725f76d6 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Tue, 16 Sep 2025 16:29:13 +0800 Subject: [PATCH 05/11] remove qwen agent producer --- .../agent/qwen_math_agentic_producer.py | 90 --------- .../agent/qwen_math_agentic_utils.py | 176 ------------------ .../ColossalChat/coati/distributed/launch.py | 2 - applications/ColossalChat/rl_example.py | 16 +- 4 files changed, 2 insertions(+), 282 deletions(-) delete mode 100644 applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_producer.py delete mode 100644 applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py diff --git a/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_producer.py b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_producer.py deleted file mode 100644 index 8631da9c2c13..000000000000 --- a/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_producer.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Any, Dict - -import ray -from coati.distributed.agent.base import BaseAgenticProducer -from coati.distributed.agent.qwen_math_agentic_utils import TIR_SYSTEM, CustomTransformers -from qwen_agent.agents import TIRMathAgent - - -@ray.remote -class QwenMathAgenticProducer(BaseAgenticProducer): - """ - Asyncronous version of the producer that uses vLLM for generation. - This class is designed to generate agentic response - """ - - def __init__( - self, - producer_idx, - num_producers, - num_consumer_procs, - num_episodes, - batch_size, - train_dataset_config, - model_config, - generate_config, - async_producers, - tool_workers=[], - tokenizer_config=None, - agentic_config=None, - microbatch_size=1, - backend="transformers", - num_generations: int = 8, - consumer_plugin_config=None, - eval_dataset_config=None, - eval_interval=-1, # disable evaluation - grpo_config: Dict[str, Any] = None, - eval_save_dir: str = "./eval", - eval_generation_config={}, - project_name: str = None, - run_name: str = None, - wandb_group_name: str = None, - log_rollout_interval: int = 20, - rollout_log_file: str = "./rollout_log.jsonl", - enable_profiling: bool = False, - n_behind: int = 0, - ): - assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer - assert batch_size == 1 # batch_size must be 1 for agentic producer - super().__init__( - producer_idx, - num_producers, - num_consumer_procs, - num_episodes, - batch_size, - train_dataset_config, - model_config, - generate_config, - async_producers, - tokenizer_config, - microbatch_size, - backend, - num_generations, - consumer_plugin_config, - eval_dataset_config=eval_dataset_config, - eval_interval=eval_interval, - grpo_config=grpo_config, - eval_save_dir=eval_save_dir, - eval_generation_config=eval_generation_config, - project_name=project_name, - run_name=run_name, - wandb_group_name=wandb_group_name, - log_rollout_interval=log_rollout_interval, - rollout_log_file=rollout_log_file, - enable_profiling=enable_profiling, - n_behind=n_behind, - ) - self.agentic_config = model_config if not agentic_config else agentic_config - self.agentic_config.update({"model": model_config["path"]}) - self.llm = CustomTransformers(self.agentic_config, self.producer_idx, generation_workers=self.async_producers) - self.bot = TIRMathAgent(llm=self.llm, name=model_config["path"], system_message=TIR_SYSTEM) - - def _run_agentic_pipeline(self, messages): - """ - Run the agentic pipeline to generate responses based on the input messages using the TIRMathAgent. - """ - for response in self.bot.run(messages): - continue - messages.extend(response) - # breakpoint() - return messages diff --git a/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py b/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py deleted file mode 100644 index d8a97ccf619b..000000000000 --- a/applications/ColossalChat/coati/distributed/agent/qwen_math_agentic_utils.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A TIR(tool-integrated reasoning) math agent -```bash -python tir_math.py -``` -""" -import os -import random - -import ray -from qwen_agent.agents import TIRMathAgent -from qwen_agent.llm.base import register_llm -from qwen_agent.llm.function_calling import BaseFnCallModel -from qwen_agent.llm.transformers_llm import Transformers -from qwen_agent.log import logger -from transformers import AutoTokenizer - -ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource") - -# We use the following two systems to distinguish between COT mode and TIR mode -TIR_SYSTEM = """Please integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{}.""" -COT_SYSTEM = """Please reason step by step, and put your final answer within \\boxed{}.""" - -from transformers import StoppingCriteria - -tokenizer = AutoTokenizer.from_pretrained("/mnt/nfs/share/data/model/Qwen2.5-Math-7B-Instruct", trust_remote_code=True) - - -class StopOnTokens(StoppingCriteria): - def __init__(self, stop_token_ids): - self.stop_token_ids = stop_token_ids - - def __call__(self, input_ids, scores, **kwargs): - # Check if the last token is one of the stop tokens - if input_ids[0, -1].item() in self.stop_token_ids: - return True - return False - - -class LocalLLMFromGenerationWorkers: - """ - A class that wraps the Transformers model to support API-based text generation. - """ - - def __init__(self, generation_worker=None, tokenizer=None): - self.device = "cpu" - self.generation_worker = generation_worker - self.tokenizer = tokenizer - - def generate(self, **kwargs): - if "max_new_tokens" in kwargs: - # we use VLLM backend for generation, which uses `max_tokens` - kwargs["max_tokens"] = kwargs["max_new_tokens"] - del kwargs["max_new_tokens"] - rollouts = ray.get(self.generation_worker.generate.remote(**kwargs)) - # breakpoint() - return rollouts["input_ids"] - - -@register_llm("api_based_transformers") -class CustomTransformers(Transformers): - """ - Transformers class that supports API-based text generation. - """ - - def __init__(self, cfg: dict, producer_idx, generation_workers=None): - BaseFnCallModel.__init__(self, cfg) # skip the super() init of Transformers to avoid loading hf model - ############ Setup logic from Transformers.__init__ ############### - if "model" not in cfg: - raise ValueError("Please provide the model id or directory through `model` in cfg.") - - try: - from transformers import AutoConfig, AutoProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast - except ImportError as e: - raise ImportError( - "Could not import classes from transformers. " "Please install it with `pip install -U transformers`" - ) from e - - self.hf_config = AutoConfig.from_pretrained(cfg["model"]) - arch = self.hf_config.architectures[0] - if len(self.hf_config.architectures) > 1: - logger.warning( - f"The config for the transformers model type contains more than one architecture, choosing the first: {arch}" - ) - - # try loading a processor, if got a tokenizer, regarding the model as text-only - processor = AutoProcessor.from_pretrained(cfg["model"]) - if isinstance(processor, (PreTrainedTokenizer, PreTrainedTokenizerFast)): - logger.info(f"Regarding the transformers model as text-only since its processor is a tokenizer.") - self.tokenizer = processor - self._support_multimodal_input = False - else: - self.processor = processor - self.tokenizer = self.processor.tokenizer - self._support_multimodal_input = True - ################################################################ - self.generation_workers = generation_workers - self.hf_models = [ - LocalLLMFromGenerationWorkers(generation_worker=generation_worker, tokenizer=self.tokenizer) - for generation_worker in generation_workers - ] - self.producer_idx = producer_idx - self.load_balancer_idx = producer_idx % len(self.generation_workers) - - @property - def hf_model(self): - # Simple round-robin load balancing - model = self.hf_models[self.load_balancer_idx] - return model - - def _chat_stream( - self, - messages, - delta_stream: bool, - generate_cfg: dict, - ): - # overwrite streaming because streamer is not serializable - # determine load balancer idx based on producer load, refresh every generation - load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers] - min_load = min(load) - candidates = [i for i, l in enumerate(load) if l == min_load] - # random tie break - self.load_balancer_idx = random.choice(candidates) - # breakpoint() - response = self._chat_no_stream(messages=messages, generate_cfg=generate_cfg) - # breakpoint() - yield response - - -def init_agent_service(): - llm_cfg = { - # Use the OpenAI-compatible model service provided by DashScope: - "model": "/mnt/nfs/share/data/model/Qwen2.5-Math-7B-Instruct", - "model_type": "transformers", - "generate_cfg": { - # Using the API's native tool call interface - "top_k": 1, - }, - } - llm = CustomTransformers(llm_cfg) - bot = TIRMathAgent(llm=llm, name="Qwen2.5-Math", system_message=TIR_SYSTEM) - return bot - - -def app_tui(): - # Define the agent - bot = init_agent_service() - - # Chat - messages = [] - while True: - # Query example: 斐波那契数列前10个数字 - query = input("user question: ") - messages.append({"role": "user", "content": query}) - response = [] - for response in bot.run(messages): - print("bot response:", response) - messages.extend(response) - - -# if __name__ == '__main__': -# # Test the TIR math agent locally -# app_tui() diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index a22dd7856b49..24f16e7c018e 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -5,7 +5,6 @@ import ray from coati.distributed.agent.agentic_producer import AgenticProducer -from coati.distributed.agent.qwen_math_agentic_producer import QwenMathAgenticProducer from coati.distributed.agent.tool_worker import ToolWorker from .consumer import SimpleConsumer @@ -21,7 +20,6 @@ } Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncSimpleProducer} AGENTIC_PRODUCER_MAP = { - "QwenMathAgent": QwenMathAgenticProducer, "Agentic": AgenticProducer, } # supported agentic producers diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index bab6f14b6a94..706f2708cd2b 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -165,7 +165,7 @@ "--agentic-type", type=str, default="Agentic", - choices=["Agentic", "QwenMathAgent"], + choices=["Agentic"], help="Agentic model type for agentic training.", ) parser.add_argument( @@ -418,19 +418,7 @@ if "agentic" in args.backend: assert "vllm" in args.backend, "Agentic backend only supports async-agentic-vllm backends." generate_config["n"] = 1 # agentic producer use AsyncProducer which processes one request a time - if args.agentic_type == "QwenMathAgent": - agentic_config = { - "agentic_producer": "QwenMathAgent", - "model": args.model, - "model_type": "transformers", - "generate_cfg": { - "max_input_tokens": args.max_new_tokens + args.max_prompt_tokens, - }, - } - agentic_config["generate_cfg"].update( - {k: v for k, v in generate_config.items() if k in ["top_k", "top_p", "temperature"]} - ) - elif args.agentic_type == "Agentic": + if args.agentic_type == "Agentic": generate_config["stop"] = ["<|im_end|>"] generate_config["prompt_logprobs"] = 0 agentic_config = { From d47c56356bc2325db2d9209d4c56a845d7cd7ffc Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Thu, 18 Sep 2025 16:45:37 +0800 Subject: [PATCH 06/11] fix rollout, action mask, attention mask bugs --- .../distributed/agent/agentic_producer.py | 9 +-- .../coati/distributed/agent/base.py | 15 +++-- .../coati/distributed/agent/math_tools.py | 2 +- .../coati/distributed/consumer.py | 1 - .../coati/distributed/inference_backend.py | 2 +- .../ColossalChat/coati/distributed/launch.py | 2 +- .../coati/distributed/producer.py | 57 ++++++++++++------- .../coati/distributed/reward/reward_fn.py | 16 +++++- .../conversation_template/qwen3.json | 8 +++ applications/ColossalChat/rl_example.py | 13 ++++- 10 files changed, 80 insertions(+), 45 deletions(-) create mode 100644 applications/ColossalChat/conversation_template/qwen3.json diff --git a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py index 234734c41527..48b07ffadde3 100644 --- a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py +++ b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py @@ -6,7 +6,6 @@ import ray from coati.distributed.agent.base import BaseAgenticProducer -from transformers import AutoTokenizer DEFAULT_SYSTEM_MESSAGE = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here .""" @@ -88,13 +87,6 @@ def __init__( self.tool_workers = tool_workers self.agentic_config = model_config if not agentic_config else agentic_config self.agentic_config.update({"model": model_config["path"]}) - tokenizer_path = None - if tokenizer_config and "path" in tokenizer_config: - tokenizer_path = tokenizer_config["path"] - elif "path" in model_config: - tokenizer_path = model_config["path"] - assert tokenizer_path is not None, "Tokenizer path must be provided either in tokenizer_config or model_config." - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) self.tools_schema = [] self.tool_call_budget = self.agentic_config.get("tool_call_budget", 3) self.llm_call_budget = self.agentic_config.get("llm_call_budget", 10) @@ -258,6 +250,7 @@ def _run_agentic_pipeline(self, messages): ) ) llm_call_count += 1 + self.consumer_global_step = response.pop("consumer_global_step") response_input_ids = response["input_ids"] logprobs = response["action_log_probs"] response_text = self.tokenizer.decode( diff --git a/applications/ColossalChat/coati/distributed/agent/base.py b/applications/ColossalChat/coati/distributed/agent/base.py index b290e9e44d70..6ea5b17e10ba 100644 --- a/applications/ColossalChat/coati/distributed/agent/base.py +++ b/applications/ColossalChat/coati/distributed/agent/base.py @@ -135,15 +135,13 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: input_ids = torch.nn.functional.pad( response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id ) # [1, max_length] - attention_mask = torch.nn.functional.pad( - torch.ones_like(response_input_ids), (to_pad_left, to_pad_right), "constant", value=0 - ) # [1, max_length] - action_mask = torch.nn.functional.pad( - torch.ones(size=(1, response_length)), (0, to_pad_right), "constant", value=0 - ) # [1, max_length-prompt_length] + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).int() # [1, max_length] + action_mask = input_ids[:, max_prompt_length:].ne(self.tokenizer.pad_token_id).int() rollouts["attention_mask"].append(attention_mask) rollouts["action_mask"].append(action_mask) - truncated_logprobs = logprobs[:, :, prompt_length : prompt_length + self.generate_config["max_tokens"]] + truncated_logprobs = logprobs[ + :, :, prompt_length : prompt_length + self.generate_config["max_tokens"] + ] # truncate to max_new_tokens logprobs_padded = torch.nn.functional.pad( truncated_logprobs, (0, self.generate_config["max_tokens"] - truncated_logprobs.size(-1)), @@ -177,7 +175,8 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: "rollout": self.tokenizer.batch_decode( rollouts["input_ids"][:, 0], skip_special_tokens=True ), - } + }, + ensure_ascii=False, ) + "\n" ) diff --git a/applications/ColossalChat/coati/distributed/agent/math_tools.py b/applications/ColossalChat/coati/distributed/agent/math_tools.py index c1a5bfb436ad..dba8b93b6519 100644 --- a/applications/ColossalChat/coati/distributed/agent/math_tools.py +++ b/applications/ColossalChat/coati/distributed/agent/math_tools.py @@ -20,7 +20,7 @@ def run_python_code(code: str) -> str: code = code.replace("```python", "```", 1).strip() if code.startswith("```py"): # qwen3 uses ```py code = code.replace("```py", "```", 1).strip() - return python_repl.run(code, timeout=20) + return python_repl.run(code, timeout=30) repl_tool = Tool( diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 6a885f23b1e4..45aaead49e5e 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -325,7 +325,6 @@ def loop(self) -> None: ) # for setting start index when resuming training if self.rank == 0: print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}") - # breakpoint() if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and ( episode != 0 or step >= self.n_behind ): diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index d01edd042037..6fd3fa2ddb49 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -409,7 +409,7 @@ async def generate( log_probs[generation_id].extend(p) self.profiler.exit(f"vllm generate {request_id}") # pad them - max_len = self.sample_params.max_tokens + max_len = sample_params.max_tokens action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype) for i, new_token_ids in enumerate(out_tokens): diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 24f16e7c018e..ca77316267d1 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -68,7 +68,7 @@ def launch_distributed( eval_interval: int = 100, eval_save_dir: Optional[str] = None, eval_generation_config: Optional[Dict[str, Any]] = None, - log_rollout_interval: int = 20, + log_rollout_interval: int = 1, rollout_save_dir: str = "./rollout", enable_profiling: bool = False, n_behind: int = 0, diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index d5d7cba1d6f3..2964885ba867 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -93,7 +93,14 @@ def __init__( reward_model_kwargs = { k: v for k, v in grpo_config.items() - if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"] + if k + in [ + "soft_over_length_punishment", + "max_new_tokens", + "cache_length", + "code_verifier_api_url", + "forced_patterns", + ] } self.response_format_tags = grpo_config.get("response_format_tags", None) if producer_idx == 0 and rollout_log_file is not None: @@ -103,7 +110,7 @@ def __init__( ) else: os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True) - self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8") + self.rollout_log_file = open(rollout_log_file, "a", encoding="utf8") if self.producer_idx == 0: self.wandb_run = wandb.init( project=project_name, @@ -260,6 +267,9 @@ def sync_model(self, episode, step) -> None: state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" ) + print( + f"[P{self.producer_idx}] Sync model episode {episode} step {(step + 1) // self.num_microbatches - 1} done" + ) if "consumer_global_step" in state_dict: self.consumer_global_step = state_dict.pop("consumer_global_step").item() self.load_state_dict(state_dict) @@ -498,7 +508,8 @@ def rollout(self, input_ids, attention_mask, **kwargs): "rollout": self.tokenizer.batch_decode( rollouts["input_ids"][:, 0], skip_special_tokens=True ), - } + }, + ensure_ascii=False, ) + "\n" ) @@ -583,8 +594,10 @@ def __init__( self.eval_generation_config["n"] = 1 # use 1 generation for evaluation self.eval_generation_config.update(eval_generation_config) self.eval_sample_params = SamplingParams(**self.eval_generation_config) - self.ready_processes = 0 - self.condition = asyncio.Condition() + self.ready_processes_sync_model = 0 + self.ready_processes_sync_data = 0 + self.sync_model_condition = asyncio.Condition() + self.sync_data_condition = asyncio.Condition() self.data_ready_for_sending = [] @torch.no_grad() @@ -613,6 +626,7 @@ async def generate(self, input_ids, attention_mask, **kwargs): ).cpu() # CUDA tensor is not serializable by ray for k in rollouts[0].keys() } + rollouts["consumer_global_step"] = self.consumer_global_step return rollouts @torch.no_grad() @@ -634,33 +648,33 @@ async def async_sync_model(self, episode, step, num_processes: int = 1) -> None: Asyncronous version to sync model from consumer to producer. called by another producer, such as agentic producer. """ - async with self.condition: - self.ready_processes += 1 + async with self.sync_model_condition: + self.ready_processes_sync_model += 1 # Wait until all processes are ready - if self.ready_processes < num_processes: - await self.condition.wait() + if self.ready_processes_sync_model < num_processes: + await self.sync_model_condition.wait() - # Only one process should reset `ready_processes` and perform the sync - if self.ready_processes == num_processes: - self.ready_processes = 0 - self.condition.notify_all() # Notify all waiting processes + # Only one process should reset `ready_processes_sync_model` and perform the sync + if self.ready_processes_sync_model == num_processes: + self.ready_processes_sync_model = 0 + self.sync_model_condition.notify_all() # Notify all waiting processes self.sync_model(episode, step) async def async_sync_data(self, data: Dict[str, torch.Tensor], num_processes: int = 1) -> None: # merge data dict - async with self.condition: - self.ready_processes += 1 + async with self.sync_data_condition: + self.ready_processes_sync_data += 1 if data: self.data_ready_for_sending.append(data) # Wait until all processes are ready - if self.ready_processes < num_processes: - await self.condition.wait() + if self.ready_processes_sync_data < num_processes: + await self.sync_data_condition.wait() # Only one process should reset `ready_processes` and perform the sync - if self.ready_processes == num_processes: # wait for all producers to join - self.ready_processes = 0 - self.condition.notify_all() + if self.ready_processes_sync_data == num_processes: # wait for all producers to join + self.ready_processes_sync_data = 0 + self.sync_data_condition.notify_all() # merge data for sending if len(self.data_ready_for_sending) >= 1: batch_rollout_data = {} @@ -856,7 +870,8 @@ async def rollout(self, input_ids, attention_mask, **kwargs): "rollout": self.tokenizer.batch_decode( rollouts["input_ids"][:, 0], skip_special_tokens=True ), - } + }, + ensure_ascii=False, ) + "\n" ) diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 9aa39788f8b3..57217ab94d08 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -19,6 +19,7 @@ import json +import re import torch from latex2sympy2_extended import NormalizationConfig @@ -126,6 +127,12 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): format_valid = validate_response_structure(processed_str, kwargs["tags"]) + if "forced_patterns" in kwargs and kwargs["forced_patterns"]: + forced_patterns = kwargs["forced_patterns"] + format_valid = format_valid and all( + [re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns] + ) + # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid if final_answer is not None: if eval_mode or format_valid: @@ -161,7 +168,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] eval_mode = kwargs.get("eval_mode", False) soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) - acc_score = 10.0 + acc_score = 1.0 reward = torch.tensor(0.0) format_acc = torch.tensor(0.0) ans_acc = torch.tensor(0.0) @@ -182,7 +189,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): raise ValueError("no gt_answer is provided, please check your training dataset.") decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) - # print(f"decoded_final_answer: {decoded_final_answer[-100:]}", gt_answer) final_answer = extract_boxed_solution(decoded_final_answer) format_valid = final_answer is not None if "tags" in kwargs and kwargs["tags"]: @@ -190,7 +196,11 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): format_valid = format_valid and all( [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags] ) - + if "forced_patterns" in kwargs and kwargs["forced_patterns"]: + forced_patterns = kwargs["forced_patterns"] + format_valid = format_valid and all( + [re.search(pattern, decoded_final_answer) is not None for pattern in forced_patterns] + ) # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid if final_answer is not None: if eval_mode or format_valid: diff --git a/applications/ColossalChat/conversation_template/qwen3.json b/applications/ColossalChat/conversation_template/qwen3.json new file mode 100644 index 000000000000..7c713d2b173c --- /dev/null +++ b/applications/ColossalChat/conversation_template/qwen3.json @@ -0,0 +1,8 @@ +{ + "chat_template": "{%- if tools %}\\n {{- \'<|im_start|>system\\\\n\' }}\\n {%- if messages[0].role == \'system\' %}\\n {{- messages[0].content + \'\\\\n\\\\n\' }}\\n {%- endif %}\\n {{- \\"# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within XML tags:\\\\n\\" }}\\n {%- for tool in tools %}\\n {{- \\"\\\\n\\" }}\\n {{- tool | tojson }}\\n {%- endfor %}\\n {{- \\"\\\\n\\\\n\\\\nFor each function call, return a json object with function name and arguments within XML tags:\\\\n\\\\n{\\\\\\"name\\\\\\": , \\\\\\"arguments\\\\\\": }\\\\n<|im_end|>\\\\n\\" }}\\n{%- else %}\\n {%- if messages[0].role == \'system\' %}\\n {{- \'<|im_start|>system\\\\n\' + messages[0].content + \'<|im_end|>\\\\n\' }}\\n {%- endif %}\\n{%- endif %}\\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\\n{%- for message in messages[::-1] %}\\n {%- set index = (messages|length - 1) - loop.index0 %}\\n {%- if ns.multi_step_tool and message.role == \\"user\\" and message.content is string and not(message.content.startswith(\'\') and message.content.endswith(\'\')) %}\\n {%- set ns.multi_step_tool = false %}\\n {%- set ns.last_query_index = index %}\\n {%- endif %}\\n{%- endfor %}\\n{%- for message in messages %}\\n {%- if message.content is string %}\\n {%- set content = message.content %}\\n {%- else %}\\n {%- set content = \'\' %}\\n {%- endif %}\\n {%- if (message.role == \\"user\\") or (message.role == \\"system\\" and not loop.first) %}\\n {{- \'<|im_start|>\' + message.role + \'\\\\n\' + content + \'<|im_end|>\' + \'\\\\n\' }}\\n {%- elif message.role == \\"assistant\\" %}\\n {{- \'<|im_start|>\' + message.role + \'\\\\n\' + content }}\\n {%- if message.tool_calls %}\\n {%- for tool_call in message.tool_calls %}\\n {%- if (loop.first and content) or (not loop.first) %}\\n {{- \'\\\\n\' }}\\n {%- endif %}\\n {%- if tool_call.function %}\\n {%- set tool_call = tool_call.function %}\\n {%- endif %}\\n {{- \'\\\\n{\\"name\\": \\"\' }}\\n {{- tool_call.name }}\\n {{- \'\\", \\"arguments\\": \' }}\\n {%- if tool_call.arguments is string %}\\n {{- tool_call.arguments }}\\n {%- else %}\\n {{- tool_call.arguments | tojson }}\\n {%- endif %}\\n {{- \'}\\\\n\' }}\\n {%- endfor %}\\n {%- endif %}\\n {{- \'<|im_end|>\\\\n\' }}\\n {%- elif message.role == \\"tool\\" %}\\n {%- if loop.first or (messages[loop.index0 - 1].role != \\"tool\\") %}\\n {{- \'<|im_start|>user\' }}\\n {%- endif %}\\n {{- \'\\\\n\\\\n\' }}\\n {{- content }}\\n {{- \'\\\\n\' }}\\n {%- if loop.last or (messages[loop.index0 + 1].role != \\"tool\\") %}\\n {{- \'<|im_end|>\\\\n\' }}\\n {%- endif %}\\n {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n {{- \'<|im_start|>assistant\\\\n\' }}\\n {%- if enable_thinking is defined and enable_thinking is false %}\\n {{- \'\\\\n\\\\n\\\\n\\\\n\' }}\\n {%- endif %}\\n{%- endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 7 + ], + "end_of_assistant": "<|im_end|>" +} diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 706f2708cd2b..5af6fb79203d 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -131,6 +131,7 @@ default=1.0, help="Top p for sampling. Please check the generation arguments documentation for your backend.", ) + parser.add_argument("-ct", "--chat-template", type=str, default=None, help="Chat template to use for the model.") parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.") parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.") parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.") @@ -427,11 +428,20 @@ "llm_call_budget": 10, "max_tokens": 2048, } + grpo_config["forced_patterns"] = [ + r"\n.+\n" + ] # force at least one correct tool call else: raise ValueError(f"Unsupported agentic model type: {args.agentic_type}") else: agentic_config = None + tokenizer_config = { + "path": args.model, + "trust_remote_code": True, + "chat_template": args.chat_template, + } + launch_distributed( num_producers=args.num_inferencer, num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size) @@ -453,6 +463,7 @@ train_model_config=train_model_config, grpo_config=grpo_config, agentic_config=agentic_config, + tokenizer_config=tokenizer_config, plugin_config={ "tp_size": args.tensor_parallel_size, "pp_size": args.pipeline_parallel_size, @@ -480,7 +491,7 @@ eval_interval=args.eval_interval, eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")), eval_generation_config=eval_generation_config, - log_rollout_interval=20, + log_rollout_interval=1, rollout_save_dir=args.rollout_save_dir, enable_profiling=args.enable_profiling, n_behind=args.n_behind, From 2b46ab1401db3977529aa2f5cf2a39a922fc3d3b Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Thu, 18 Sep 2025 18:28:36 +0800 Subject: [PATCH 07/11] simplify _run_agentic_pipeline; fix old_log_probs --- .../coati/distributed/agent/agentic_producer.py | 12 ++++-------- .../ColossalChat/coati/distributed/agent/base.py | 16 +++++++--------- .../ColossalChat/coati/distributed/loss.py | 4 ++-- applications/ColossalChat/rl_example.py | 10 ++++------ 4 files changed, 17 insertions(+), 25 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py index 48b07ffadde3..4f7dc3c9f219 100644 --- a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py +++ b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py @@ -224,9 +224,7 @@ def _run_agentic_pipeline(self, messages): if llm_call_count > self.llm_call_budget: print(f"LLM call budget exceeded: {llm_call_count} > {self.llm_call_budget}. Stopping.") del self.async_llm_engine_map[request_id] - while messages[-1]["role"] == "tool": - messages.pop() - return messages, logprobs + return messages, response_input_ids, logprobs inputs = self._build_prompt(messages, return_dict=True, return_tensors="pt") if num_prompt_tokens == 0: num_prompt_tokens = inputs["input_ids"].size(-1) @@ -235,9 +233,7 @@ def _run_agentic_pipeline(self, messages): f"Max tokens exceeded: Current have generated {inputs['input_ids'].size(-1) - num_prompt_tokens} tokens > {self.generate_config.get('max_tokens', 512)}. Stopping." ) del self.async_llm_engine_map[request_id] - while messages[-1]["role"] == "tool": - messages.pop() - return messages, logprobs + return messages, response_input_ids, logprobs async_producer = self._select_async_producer(request_id=request_id) agentic_generate_config = copy.deepcopy(self.generate_config) agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048) @@ -262,7 +258,7 @@ def _run_agentic_pipeline(self, messages): if tool_call_count > self.tool_call_budget: print(f"Tool call budget exceeded: {tool_call_count} > {self.tool_call_budget}. Stopping.") del self.async_llm_engine_map[request_id] - return messages, logprobs + return messages, response_input_ids, logprobs tool_call_count += len(assistant_message["tool_calls"]) handlers = [] for tool_call in assistant_message["tool_calls"]: @@ -277,4 +273,4 @@ def _run_agentic_pipeline(self, messages): else: # no further tool call, return the messages del self.async_llm_engine_map[request_id] - return messages, logprobs + return messages, response_input_ids, logprobs diff --git a/applications/ColossalChat/coati/distributed/agent/base.py b/applications/ColossalChat/coati/distributed/agent/base.py index 6ea5b17e10ba..e5ff9ffc588b 100644 --- a/applications/ColossalChat/coati/distributed/agent/base.py +++ b/applications/ColossalChat/coati/distributed/agent/base.py @@ -123,24 +123,22 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: ) for i in range(self.num_generations): - _messages, logprobs = results[i] - response_input_ids = self._build_prompt( - _messages, return_dict=True, return_tensors="pt", add_generation_prompt=False - )["input_ids"] + # due to the multiround feature, action_mask and attention_mask need to be recomputed + _messages, response_input_ids, logprobs = results[i] # truncate if too long - response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left] + response_input_ids = response_input_ids[0, :, : self.grpo_config["max_length"] - to_pad_left] # add left right padding - to_pad_right = self.grpo_config["max_length"] - response_input_ids.shape[1] - to_pad_left - response_length = response_input_ids.shape[1] - prompt_length + to_pad_right = self.grpo_config["max_length"] - response_input_ids.size(-1) - to_pad_left input_ids = torch.nn.functional.pad( response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id ) # [1, max_length] attention_mask = input_ids.ne(self.tokenizer.pad_token_id).int() # [1, max_length] action_mask = input_ids[:, max_prompt_length:].ne(self.tokenizer.pad_token_id).int() + response_length = action_mask.sum().item() rollouts["attention_mask"].append(attention_mask) rollouts["action_mask"].append(action_mask) truncated_logprobs = logprobs[ - :, :, prompt_length : prompt_length + self.generate_config["max_tokens"] + 0, :, prompt_length : prompt_length + self.generate_config["max_tokens"] ] # truncate to max_new_tokens logprobs_padded = torch.nn.functional.pad( truncated_logprobs, @@ -148,7 +146,7 @@ def rollout(self, **kwargs) -> Dict[str, torch.Tensor]: "constant", value=0.0, ) # [1, max_new_tokens] - rollouts["action_log_probs"].append(logprobs_padded[0]) + rollouts["action_log_probs"].append(logprobs_padded) rollouts["response_idx"].append( torch.tensor( [ diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index ab38f987f65a..7fcfdba31f4d 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -37,9 +37,9 @@ def forward( total_effective_tokens_in_batch: torch.Tensor = None, ) -> torch.Tensor: if action_mask is None: - ratio = (log_probs - log_probs.detach()).exp() + ratio = (log_probs - old_log_probs.detach()).exp() else: - ratio = ((log_probs - log_probs.detach()) * action_mask).exp() + ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp() surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 5af6fb79203d..5770bfde69a0 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -429,18 +429,16 @@ "max_tokens": 2048, } grpo_config["forced_patterns"] = [ - r"\n.+\n" + r"\n.+\n" # please modify based on your tool response format ] # force at least one correct tool call else: raise ValueError(f"Unsupported agentic model type: {args.agentic_type}") else: agentic_config = None - tokenizer_config = { - "path": args.model, - "trust_remote_code": True, - "chat_template": args.chat_template, - } + tokenizer_config = {"path": args.model, "trust_remote_code": True} + if args.chat_template is not None: + tokenizer_config["chat_template"] = args.chat_template launch_distributed( num_producers=args.num_inferencer, From 8745e8f4d1ca62ee2ea59020d200ea983ffc8f20 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 19 Sep 2025 17:34:47 +0800 Subject: [PATCH 08/11] test asyncllm producer and other settings --- .../ColossalChat/coati/distributed/consumer.py | 2 +- .../coati/distributed/inference_backend.py | 11 ++++++++--- applications/ColossalChat/coati/distributed/launch.py | 2 +- .../ColossalChat/coati/distributed/producer.py | 9 +++++---- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 45aaead49e5e..fa0e331e8f92 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -181,7 +181,6 @@ def loop(self) -> None: for step in pbar: torch.cuda.reset_peak_memory_stats() i = 0 - self.profiler.enter(f"rollout_episode_{episode}_step_{step}") for _ in range(self.num_recv_per_update): if self.n_behind > 0: @@ -325,6 +324,7 @@ def loop(self) -> None: ) # for setting start index when resuming training if self.rank == 0: print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}") + if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and ( episode != 0 or step >= self.n_behind ): diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 6fd3fa2ddb49..e6eab16692de 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -251,7 +251,12 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar micro_batch_input_ids_no_padding = [ micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size) ] - sample_params = kwargs.get("sample_params", self.sample_params) + sample_params = self.sample_params + if len(kwargs) > 0: + sample_params = self.generate_config.copy() + sample_params.update({k: v for k, v in kwargs.items() if k not in ["gt_answer", "test_cases", "labels"]}) + sample_params.update(self.FORCE_GENERATE_CONFIG) + sample_params = SamplingParams(**sample_params) outputs = self.llm.generate( prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False ) @@ -358,7 +363,7 @@ async def generate( input_ids (torch.Tensor): shape [B, S], B=1 attention_mask (torch.Tensor): shape [B, S] """ - assert input_ids.size(0) == attention_mask.size(0) == 1 + assert input_ids.size(0) == attention_mask.size(0) == 1, "AsyncVLLMInferenceBackend only supports batch size 1" request_id = ( str(uuid4()) if not "request_id" in kwargs else kwargs.pop("request_id") ) # use fixed request_id to reuse kv cache @@ -368,7 +373,7 @@ async def generate( sample_params = self.sample_params if len(kwargs) > 0: sample_params = self.generate_config.copy() - sample_params.update(kwargs) + sample_params.update({k: v for k, v in kwargs.items() if k not in ["gt_answer", "test_cases", "labels"]}) sample_params.update(self.FORCE_GENERATE_CONFIG) sample_params = SamplingParams(**sample_params) out_tokens = [] diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index ca77316267d1..25f2e03d065c 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -143,7 +143,7 @@ def launch_distributed( tokenizer_config=tokenizer_config, microbatch_size=( inference_microbatch_size * num_generations - if "async" in inference_backend + if "async-agentic" in inference_backend else inference_microbatch_size ), backend=inference_backend, diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 2964885ba867..80e5e560c9ff 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -284,7 +284,6 @@ def sync_data(self, data: Dict[str, torch.Tensor]) -> None: ray_broadcast_tensor_dict(data, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}") def loop(self) -> None: - # breakpoint() self.sync_model(0, 0) num_update_per_episode = len(self.train_dataloader) // self.num_microbatches num_valid_microbatches = num_update_per_episode * self.num_microbatches @@ -620,10 +619,10 @@ async def generate(self, input_ids, attention_mask, **kwargs): rollouts = await asyncio.gather(*tasks) rollouts = { k: ( - torch.cat([r[k] for r in rollouts], dim=0) + torch.cat([r[k] for r in rollouts], dim=0).cpu() if k not in ["gt_answer", "test_cases"] else [r[k] for r in rollouts] - ).cpu() # CUDA tensor is not serializable by ray + ) # CUDA tensor is not serializable by ray for k in rollouts[0].keys() } rollouts["consumer_global_step"] = self.consumer_global_step @@ -758,8 +757,8 @@ async def loop(self) -> None: self.eval_mode = False self.latest_eval_step = self.consumer_global_step self.profiler.enter("rollout") - # breakpoint() outputs = await self.rollout(**batch) + outputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()} self.profiler.exit("rollout") outputs["temperature"] = torch.tensor( [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) @@ -803,6 +802,8 @@ async def loop(self) -> None: outputs.pop("gt_answer") if "test_cases" in outputs: outputs.pop("test_cases") + if "consumer_global_step" in outputs: + outputs.pop("consumer_global_step") self.profiler.exit("calculate_reward") print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") From c095ec35da4dc3379da63cfa6e0f9b663c8c43dc Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 22 Sep 2025 09:36:47 +0800 Subject: [PATCH 09/11] tested anb fix style issue --- applications/ColossalChat/coati/distributed/consumer.py | 1 - applications/ColossalChat/coati/distributed/inference_backend.py | 1 - 2 files changed, 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index fa0e331e8f92..45840e7e488a 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -165,7 +165,6 @@ def loop(self) -> None: state_dict, src=self.num_producers, device=self.device, group_name="sync_model" ) del state_dict - print(f"[C{self.rank}]: Sync model before training done") torch.cuda.empty_cache() self.profiler.exit("sync_model") diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index e6eab16692de..9f9d8e36ba68 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -381,7 +381,6 @@ async def generate( log_probs = [] response_idx = [] while len(self.running_requests) >= self.microbatch_size: - # print(f"Current running {len(self.running_requests)}/{self.microbatch_size} requests, waiting...") await asyncio.sleep(0.1) self.running_requests.append(request_id) # enqueue # pop the first input_ids and attention_mask From 8ca76fe935a81c88b31654f6ebe1ad35106c93bd Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Tue, 23 Sep 2025 10:47:44 +0800 Subject: [PATCH 10/11] fix vllm configuration and load balancing --- .../distributed/agent/agentic_producer.py | 34 ++++++++++--------- .../coati/distributed/agent/tool_worker.py | 13 ------- .../ColossalChat/coati/distributed/launch.py | 12 +++---- .../coati/distributed/producer.py | 7 ---- .../ColossalChat/coati/distributed/utils.py | 28 +++++++++++++-- applications/ColossalChat/rl_example.py | 6 ++-- 6 files changed, 53 insertions(+), 47 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py index 4f7dc3c9f219..c80a1f319a97 100644 --- a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py +++ b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py @@ -52,6 +52,7 @@ def __init__( log_rollout_interval: int = 20, rollout_log_file: str = "./rollout_log.jsonl", enable_profiling: bool = False, + load_balancer=None, n_behind: int = 0, ): assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer @@ -84,6 +85,7 @@ def __init__( enable_profiling=enable_profiling, n_behind=n_behind, ) + self.load_balancer = load_balancer self.tool_workers = tool_workers self.agentic_config = model_config if not agentic_config else agentic_config self.agentic_config.update({"model": model_config["path"]}) @@ -183,32 +185,26 @@ def _parse_response(self, response: str) -> Dict[str, Any]: assistant_message["tool_calls"] = tool_calls return assistant_message - def _select_tool_worker(self) -> ray.actor.ActorHandle: + def _select_tool_worker(self) -> int: """ Select a tool worker based on the current load. """ - loads = ray.get([worker.get_load.remote() for worker in self.tool_workers]) - min_load = min(loads) - candidates = [i for i, l in enumerate(loads) if l == min_load] - selected_idx = random.choice(candidates) # random tie break - ray.get(self.tool_workers[selected_idx].increase_load.remote()) - return self.tool_workers[selected_idx] + selected_idx, current_loads = ray.get(self.load_balancer.get_next_worker.remote("tool", amount=1)) + return selected_idx - def _select_async_producer(self, request_id) -> ray.actor.ActorHandle: + def _select_async_producer(self, request_id) -> int: """ Select an async producer based on the current load. """ # use the last used async producer if exists to reuse kv cache (as vllm use paged kv cache, # it will reuse most of the kv cache pages without recomputation) if request_id in self.async_llm_engine_map: - return self.async_producers[self.async_llm_engine_map[request_id]] + ray.get(self.load_balancer.increase_load.remote("async-llm", self.async_llm_engine_map[request_id], 1)) + return self.async_llm_engine_map[request_id] # otherwise select the least loaded async producer - loads = ray.get([proc.get_producer_load.remote() for proc in self.async_producers]) - min_load = min(loads) - candidates = [i for i, l in enumerate(loads) if l == min_load] - selected_idx = random.choice(candidates) # random tie break + selected_idx, current_loads = ray.get(self.load_balancer.get_next_worker.remote("async-llm", amount=1)) self.async_llm_engine_map[request_id] = selected_idx - return self.async_producers[selected_idx] + return selected_idx def _run_agentic_pipeline(self, messages): """ @@ -234,7 +230,7 @@ def _run_agentic_pipeline(self, messages): ) del self.async_llm_engine_map[request_id] return messages, response_input_ids, logprobs - async_producer = self._select_async_producer(request_id=request_id) + async_producer = self.async_producers[self._select_async_producer(request_id=request_id)] agentic_generate_config = copy.deepcopy(self.generate_config) agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048) response = ray.get( @@ -246,6 +242,7 @@ def _run_agentic_pipeline(self, messages): ) ) llm_call_count += 1 + ray.get(self.load_balancer.decrease_load.remote("async-llm", self.async_llm_engine_map[request_id], 1)) self.consumer_global_step = response.pop("consumer_global_step") response_input_ids = response["input_ids"] logprobs = response["action_log_probs"] @@ -261,12 +258,17 @@ def _run_agentic_pipeline(self, messages): return messages, response_input_ids, logprobs tool_call_count += len(assistant_message["tool_calls"]) handlers = [] + tool_workers_called = [] for tool_call in assistant_message["tool_calls"]: # select a tool worker to execute the tool call - tool_worker = self._select_tool_worker() + tool_worker_idx = self._select_tool_worker() + tool_workers_called.append(tool_worker_idx) + tool_worker = self.tool_workers[tool_worker_idx] handler = tool_worker.call.remote(tool_call["function"]["name"], tool_call["function"]["arguments"]) handlers.append(handler) tool_results = ray.get(handlers) + for idx in tool_workers_called: + ray.get(self.load_balancer.decrease_load.remote("tool", idx, 1)) for tool_call, tool_result in zip(assistant_message["tool_calls"], tool_results): tool_message = {"role": "tool", "content": str(tool_result)} messages.append(tool_message) diff --git a/applications/ColossalChat/coati/distributed/agent/tool_worker.py b/applications/ColossalChat/coati/distributed/agent/tool_worker.py index ae148af864c0..454d2adba1d0 100644 --- a/applications/ColossalChat/coati/distributed/agent/tool_worker.py +++ b/applications/ColossalChat/coati/distributed/agent/tool_worker.py @@ -19,17 +19,6 @@ def __init__(self, tools: List[BaseTool]): tools (List[BaseTool]): List of LangChain tools to register. """ self._tool_registry: Dict[str, BaseTool] = {tool.name: tool for tool in tools} - self.pending = 0 - - @ray.method(concurrency_group="io") - def get_load(self) -> int: - """Return the current load of the worker.""" - return self.pending - - @ray.method(concurrency_group="io") - def increase_load(self): - """Increase the load counter.""" - self.pending += 1 @ray.method(concurrency_group="io") def list_tools(self) -> List[str]: @@ -64,7 +53,6 @@ def call(self, tool_name: str, input_data: Union[str, Dict[str, Any]], **kwargs) Any: The tool's output. """ if tool_name == "return_parsing_error": - self.pending -= 1 return "Error: Tool call parsing error. Please use the correct JSON format." if tool_name not in self._tool_registry: return f"Error: Tool {tool_name} not found. Available tools: {self.list_tools()}" @@ -73,5 +61,4 @@ def call(self, tool_name: str, input_data: Union[str, Dict[str, Any]], **kwargs) ret = tool.run(input_data, **kwargs) except Exception as e: ret = f"Error: Tool {tool_name} execution failed with error: {str(e)}" - self.pending -= 1 return ret diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 25f2e03d065c..535cf25dd545 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -10,6 +10,7 @@ from .consumer import SimpleConsumer from .grpo_consumer import GRPOConsumer from .producer import AsyncSimpleProducer, SimpleProducer +from .utils import LoadBalancer ALGO_MAP = { "Simple": SimpleConsumer, @@ -86,7 +87,7 @@ def launch_distributed( num_samples = get_jsonl_size_fast(dataset_path) global_inference_batch_size = inference_batch_size * num_producers num_update_per_episode = num_samples // global_inference_batch_size - num_recv_per_update = inference_batch_size // inference_microbatch_size if "async" not in inference_backend else 1 + num_recv_per_update = inference_batch_size // inference_microbatch_size if "async-agentic" not in inference_backend else 1 run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" wandb_group_name = str(uuid.uuid4()) @@ -124,6 +125,7 @@ def launch_distributed( enable_agentic = "agentic" in inference_backend if enable_agentic: inference_backend = inference_backend.replace("agentic-", "") + inference_microbatch_size = inference_microbatch_size * num_generations for i in range(num_producers): node_id = gpu_to_node_id[0] producer_ip_address = gpu_to_ip_address[0] @@ -141,11 +143,7 @@ def launch_distributed( model_config=inference_model_config, generate_config=generate_config, tokenizer_config=tokenizer_config, - microbatch_size=( - inference_microbatch_size * num_generations - if "async-agentic" in inference_backend - else inference_microbatch_size - ), + microbatch_size=inference_microbatch_size, backend=inference_backend, num_generations=num_generations, consumer_plugin_config=plugin_config, @@ -183,6 +181,7 @@ def launch_distributed( assert ( agentic_config["agentic_producer"] in AGENTIC_PRODUCER_MAP ), f"Only {list(AGENTIC_PRODUCER_MAP.keys())} are supported as agentic producer so far." + load_balancer = LoadBalancer.remote({"tool": len(tool_workers), "async-llm": num_producers}) agentic_producer_cls = AGENTIC_PRODUCER_MAP[agentic_config["agentic_producer"]] agentic_config.pop("agentic_producer") producer_procs = [ @@ -214,6 +213,7 @@ def launch_distributed( log_rollout_interval=log_rollout_interval, rollout_log_file=rollout_log_file, enable_profiling=enable_profiling, + load_balancer=load_balancer, n_behind=n_behind, ) for producer_idx in range(num_producers * inference_batch_size) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 80e5e560c9ff..d60be515a859 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -636,12 +636,6 @@ async def rollout(self, input_ids, attention_mask, **kwargs): """ raise NotImplementedError("rollout must be implemented in subclasses") - async def get_producer_load(self): - """ - Get the load of each producer. - """ - return len(self.model.running_requests) - async def async_sync_model(self, episode, step, num_processes: int = 1) -> None: """ Asyncronous version to sync model from consumer to producer. @@ -853,7 +847,6 @@ class AsyncSimpleProducer(BaseAsyncProducer): Asyncronous version of the producer that uses vLLM for generation. This class is designed to handle multiple producer actors and distribute tasks among them. """ - @torch.no_grad() async def rollout(self, input_ids, attention_mask, **kwargs): # naive rollout strategy without load balancing diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 466914cc0d4d..629872921fc5 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -1,12 +1,12 @@ import json import os from typing import Any, Dict, List - +import asyncio import torch from filelock import FileLock - +import random from colossalai.shardformer.layer.loss import dist_log_prob - +import ray def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]: batches = [] @@ -165,3 +165,25 @@ def safe_append_to_jsonl_file(file_path, data): for entry in data: json_line = json.dumps(entry, ensure_ascii=False) f.write(json_line + "\n") + +@ray.remote +class LoadBalancer: + def __init__(self, worker_counts): + self.load = {} + for type in worker_counts: + self.load[type] = {k: 0 for k in range(worker_counts[type])} + + def get_next_worker(self, worker_type, amount=1): + loads = [(k, v) for k, v in self.load[worker_type].items()] + min_load = min(loads, key=lambda x: x[1]) + candidates = [k for k, v in loads if v == min_load[1]] + chosen = random.choice(candidates) + self.load[worker_type][chosen] += amount + return chosen, self.load[worker_type] + + def increase_load(self, worker_type, worker_id, amount=1): + self.load[worker_type][worker_id] += amount + + def decrease_load(self, worker_type, worker_id, amount=1): + self.load[worker_type][worker_id] -= amount + \ No newline at end of file diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 5770bfde69a0..5c798bdc2b0c 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -281,8 +281,10 @@ # os.environ["VLLM_DP_SIZE"] = str(args.producer_data_parallel_size) inference_model_config.update( dict( - gpu_memory_utilization=0.7, - enforce_eager=True, + gpu_memory_utilization=0.8, + max_num_batched_tokens=4096, + max_num_seqs=1024, + enforce_eager=False, enable_chunked_prefill=True, max_model_len=args.max_new_tokens + args.max_prompt_tokens, tensor_parallel_size=args.producer_tensor_parallel_size, From f10f707c58921a34a948885bd6ada3cfb368c8b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Sep 2025 02:49:42 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../coati/distributed/agent/agentic_producer.py | 1 - .../ColossalChat/coati/distributed/launch.py | 4 +++- .../ColossalChat/coati/distributed/producer.py | 1 + .../ColossalChat/coati/distributed/utils.py | 14 ++++++++------ 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py index c80a1f319a97..bd2f6e56d91f 100644 --- a/applications/ColossalChat/coati/distributed/agent/agentic_producer.py +++ b/applications/ColossalChat/coati/distributed/agent/agentic_producer.py @@ -1,5 +1,4 @@ import copy -import random import re from typing import Any, Dict from uuid import uuid4 diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 535cf25dd545..f060104db0c1 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -87,7 +87,9 @@ def launch_distributed( num_samples = get_jsonl_size_fast(dataset_path) global_inference_batch_size = inference_batch_size * num_producers num_update_per_episode = num_samples // global_inference_batch_size - num_recv_per_update = inference_batch_size // inference_microbatch_size if "async-agentic" not in inference_backend else 1 + num_recv_per_update = ( + inference_batch_size // inference_microbatch_size if "async-agentic" not in inference_backend else 1 + ) run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" wandb_group_name = str(uuid.uuid4()) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index d60be515a859..ed0faa9fdae6 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -847,6 +847,7 @@ class AsyncSimpleProducer(BaseAsyncProducer): Asyncronous version of the producer that uses vLLM for generation. This class is designed to handle multiple producer actors and distribute tasks among them. """ + @torch.no_grad() async def rollout(self, input_ids, attention_mask, **kwargs): # naive rollout strategy without load balancing diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 629872921fc5..48b823cb0937 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -1,12 +1,14 @@ import json import os +import random from typing import Any, Dict, List -import asyncio + +import ray import torch from filelock import FileLock -import random + from colossalai.shardformer.layer.loss import dist_log_prob -import ray + def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]: batches = [] @@ -166,6 +168,7 @@ def safe_append_to_jsonl_file(file_path, data): json_line = json.dumps(entry, ensure_ascii=False) f.write(json_line + "\n") + @ray.remote class LoadBalancer: def __init__(self, worker_counts): @@ -180,10 +183,9 @@ def get_next_worker(self, worker_type, amount=1): chosen = random.choice(candidates) self.load[worker_type][chosen] += amount return chosen, self.load[worker_type] - + def increase_load(self, worker_type, worker_id, amount=1): self.load[worker_type][worker_id] += amount - + def decrease_load(self, worker_type, worker_id, amount=1): self.load[worker_type][worker_id] -= amount - \ No newline at end of file