import copy
import json
from typing import Any, Dict

import ray
import torch
from coati.distributed.agent.agentic_math_utils import TIR_SYSTEM, CustomTransformers
from coati.distributed.producer import BaseProducer
from qwen_agent.agents import TIRMathAgent
from vllm import SamplingParams


@ray.remote
class AgenticProducer(BaseProducer):
    """
    Asynchronous version of the producer that uses vLLM for generation.
    This class is designed to generate agentic responses.
    """

    def __init__(
        self,
        producer_idx,
        num_producers,
        num_consumer_procs,
        num_episodes,
        batch_size,
        train_dataset_config,
        model_config,
        generate_config,
        async_producers,
        tokenizer_config=None,
        agentic_config=None,
        microbatch_size=1,
        backend="transformers",
        num_generations: int = 8,
        consumer_plugin_config=None,
        eval_dataset_config=None,
        eval_interval=-1,  # disable evaluation
        grpo_config: Dict[str, Any] = None,
        eval_save_dir: str = "./eval",
        eval_generation_config: Dict[str, Any] = None,  # avoid a shared mutable default
        project_name: str = None,
        run_name: str = None,
        wandb_group_name: str = None,
        log_rollout_interval: int = 20,
        rollout_log_file: str = "./rollout_log.jsonl",
        enable_profiling: bool = False,
        n_behind: int = 0,
    ):
        assert microbatch_size == 1, "microbatch_size must be 1 for the agentic producer"
        assert batch_size == 1, "batch_size must be 1 for the agentic producer"
        super().__init__(
            producer_idx,
            num_producers,
            num_consumer_procs,
            num_episodes,
            batch_size,
            train_dataset_config,
            model_config,
            generate_config,
            tokenizer_config,
            microbatch_size,
            backend,
            consumer_plugin_config,
            eval_dataset_config=eval_dataset_config,
            eval_interval=eval_interval,
            grpo_config=grpo_config,
            eval_save_dir=eval_save_dir,
            project_name=project_name,
            run_name=run_name,
            wandb_group_name=wandb_group_name,
            log_rollout_interval=log_rollout_interval,
            rollout_log_file=rollout_log_file,
            enable_profiling=enable_profiling,
            n_behind=n_behind,
            enable_agentic=True,
        )
        self.eval_generation_config = copy.deepcopy(generate_config)
        self.eval_generation_config["n"] = 1  # use a single generation for evaluation
        self.eval_generation_config.update(eval_generation_config or {})
        self.eval_sample_params = SamplingParams(**self.eval_generation_config)
        self.async_producers = async_producers
        self.num_generations = num_generations
        self.generate_config = generate_config
        # copy so that updating the agentic config does not mutate the caller's model_config
        self.agentic_config = copy.deepcopy(agentic_config if agentic_config else model_config)
        self.agentic_config.update({"model": model_config["path"]})
        self.llm = CustomTransformers(self.agentic_config, self.producer_idx, generation_workers=self.async_producers)
        self.bot = TIRMathAgent(llm=self.llm, name=model_config["path"], system_message=TIR_SYSTEM)
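        # Note: CustomTransformers (from coati's agentic_math_utils) is assumed here to
        # proxy qwen_agent generation calls to the pool of async vLLM workers, so this
        # actor holds no model weights of its own; TIRMathAgent then drives
        # tool-integrated reasoning (interleaved model output and code execution) on top.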

    def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Generate ``num_generations`` agentic rollouts for a single prompt and pack them
        into fixed-length tensors (left-padded prompt, right-padded response).
        """
        assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer"
        messages = kwargs["messages"][0]
        prompt_input_ids = self.tokenizer.apply_chat_template(
            messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
        )
        # add left padding
        prompt_length = prompt_input_ids.shape[1]
        max_prompt_length = self.train_dataset_config["max_length"]
        to_pad_left = max_prompt_length - prompt_length
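        # Fixed-length layout: every sample ends up with exactly grpo_config["max_length"]
        # tokens; prompts are left-padded to max_prompt_length so prompt/response
        # boundaries line up across samples, and responses are right-padded below.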
        rollouts = {
            "input_ids": [],
            "attention_mask": [],
            "action_mask": [],
            "action_log_probs": [],
            "response_idx": [],
        }
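        # Run the agent num_generations times per prompt; multiple rollouts per prompt
        # are presumably what the GRPO consumer uses to compute group-relative advantages.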
        for _ in range(self.num_generations):
            _messages = copy.deepcopy(messages)
            # bot.run streams partial responses; drain the generator and keep the final one
            for response in self.bot.run(_messages):
                continue
            _messages.extend(response)
            response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True)
            # truncate overly long responses so the padded sample fits max_length
            response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left]
            # right-pad the remainder up to max_length
            to_pad_right = self.grpo_config["max_length"] - response_input_ids.shape[1] - to_pad_left
            response_length = response_input_ids.shape[1] - prompt_length
            input_ids = torch.nn.functional.pad(
                response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id
            )  # [1, max_length]
            attention_mask = torch.nn.functional.pad(
                torch.ones_like(response_input_ids), (to_pad_left, to_pad_right), "constant", value=0
            )  # [1, max_length]
            action_mask = torch.nn.functional.pad(
                torch.ones(size=(1, response_length)), (0, to_pad_right), "constant", value=0
            )  # [1, max_length - max_prompt_length]
            rollouts["attention_mask"].append(attention_mask)
            rollouts["action_mask"].append(action_mask)
            rollouts["action_log_probs"].append(
                torch.ones(size=(1, self.grpo_config["max_length"] - max_prompt_length))
            )  # dummy log probs
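            # The log probs above are shape-only placeholders; the consumer is presumably
            # expected to recompute real log probs under the current policy. response_idx
            # below records where the response starts and ends inside the padded sequence
            # (the prompt occupies the first max_prompt_length positions).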
            rollouts["response_idx"].append(
                torch.tensor(
                    [
                        [
                            self.train_dataset_config["max_length"],
                            self.train_dataset_config["max_length"] + response_length,
                        ]
                    ]
                )
            )  # [1, 2]
            rollouts["input_ids"].append(input_ids)
        rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()}  # [num_generations, ...]
        rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)])
        if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
            # in the agentic pipeline the async generation workers (AsyncSimpleProducer)
            # are not the main producers, so rollout logging happens here instead
            if (
                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
                or self.latest_rollout_log_step == -1
            ):
                new_record = (
                    json.dumps(
                        {
                            "train_step": self.consumer_global_step,
                            "rollout": self.tokenizer.batch_decode(
                                rollouts["input_ids"][:, 0], skip_special_tokens=True
                            ),
                        }
                    )
                    + "\n"
                )
                self.rollout_log_file.write(new_record)
                self.rollout_log_file.flush()
                self.latest_rollout_log_step = self.consumer_global_step

        if "gt_answer" in kwargs:
            rollouts["gt_answer"] = kwargs["gt_answer"]
        if "test_cases" in kwargs:
            rollouts["test_cases"] = kwargs["test_cases"]
        return rollouts

    def sync_model(self, episode, step) -> None:
        """
        Trigger a weight sync on every async generation worker. The AgenticProducer
        holds no model weights itself, so it only forwards the sync request from the
        consumer to self.async_producers.
        """
        tasks = []
        for proc in self.async_producers:
            tasks.append(proc.async_sync_model.remote(episode, step, self.num_producers))
        ray.get(tasks)
        return

    def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
        """
        Send rollout data to the consumer via the async producers. Only the worker this
        producer maps to receives the real payload; the rest receive an empty dict so
        that every worker still participates in the collective sync.
        """
        tasks = []
        for idx, proc in enumerate(self.async_producers):
            if idx == self.producer_idx % len(self.async_producers):
                tasks.append(proc.async_sync_data.remote(data, self.num_producers))
            else:
                tasks.append(proc.async_sync_data.remote({}, self.num_producers))
        ray.get(tasks)
        return
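

# A minimal wiring sketch (hypothetical; the actual launcher lives elsewhere in coati,
# and every config value below is a placeholder, not a verified default):
#
#   async_producers = [AsyncSimpleProducer.remote(...) for _ in range(4)]
#   producer = AgenticProducer.remote(
#       producer_idx=0,
#       num_producers=1,
#       num_consumer_procs=1,
#       num_episodes=1,
#       batch_size=1,  # must be 1
#       train_dataset_config={"path": "math.jsonl", "max_length": 1024},
#       model_config={"path": "Qwen/Qwen2.5-Math-7B-Instruct"},
#       generate_config={"temperature": 1.0},
#       async_producers=async_producers,
#       microbatch_size=1,  # must be 1
#       grpo_config={"max_length": 4096},
#   )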