
Commit 7f814e7

support agentic with asyncllm
1 parent d49a28d commit 7f814e7

12 files changed: +946 additions, -438 deletions


applications/ColossalChat/coati/dataset/loader.py

Lines changed: 43 additions & 20 deletions
@@ -4,6 +4,7 @@
 Dataloader for sft, dpo, ppo
 """
 
+import copy
 import os
 from dataclasses import dataclass
 from typing import Dict, Iterator, List, Optional, Sequence, Union
@@ -423,7 +424,9 @@ class RawConversationDataset(Dataset):
     Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
     """
 
-    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
+    def __init__(
+        self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str, tokenize=True
+    ) -> None:
         self.tokenizer = tokenizer
         self.raw_texts = []
         with jsonlines.open(input_file) as f:
@@ -432,30 +435,50 @@ def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length:
         self.tokenized_texts = [None] * len(self.raw_texts)
         self.max_length = max_length
         self.system_prompt = system_prompt
+        self.tokenize = tokenize
 
     def __len__(self) -> int:
         return len(self.raw_texts)
 
     def __getitem__(self, index: int):
-        if self.tokenized_texts[index] is None:
-            message = self.raw_texts[index]
-            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
-            self.tokenized_texts[index] = dict(tokens)
-        return self.tokenized_texts[index]
+        if self.tokenize:
+            if self.tokenized_texts[index] is None:
+                message = self.raw_texts[index]
+                tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
+                self.tokenized_texts[index] = dict(tokens)
+            return self.tokenized_texts[index]
+        else:
+            chat = copy.deepcopy(self.raw_texts[index])
+            chat["messages"] = [{"role": "system", "content": self.system_prompt}, chat["messages"]]
+            return chat
 
 
 def collate_fn_grpo(batch):
-    input_ids = [item["input_ids"] for item in batch]
-    attention_mask = [item["attention_mask"] for item in batch]
-    labels = [item["labels"] for item in batch]
-    # Assume input_ids, attention_mask, labels are already of the same length,
-    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
-    input_ids = torch.stack(input_ids)
-    attention_mask = torch.stack(attention_mask)
-    labels = torch.stack(labels)
-    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
-    if "test_cases" in batch[0]:
-        ret["test_cases"] = [item["test_cases"] for item in batch]
-    if "gt_answer" in batch[0]:
-        ret["gt_answer"] = [item["gt_answer"] for item in batch]
-    return ret
+    if "input_ids" in batch[0]:
+        # tokenized format
+        input_ids = [item["input_ids"] for item in batch]
+        attention_mask = [item["attention_mask"] for item in batch]
+        labels = [item["labels"] for item in batch]
+        # Assume input_ids, attention_mask, labels are already of the same length,
+        # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+        input_ids = torch.stack(input_ids)
+        attention_mask = torch.stack(attention_mask)
+        labels = torch.stack(labels)
+        ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+        if "test_cases" in batch[0]:
+            ret["test_cases"] = [item["test_cases"] for item in batch]
+        if "gt_answer" in batch[0]:
+            ret["gt_answer"] = [item["gt_answer"] for item in batch]
+        return ret
+    elif "messages" in batch[0]:
+        # vllm format
+        ret = {
+            "messages": [item["messages"] for item in batch],
+        }
+        if "test_cases" in batch[0]:
+            ret["test_cases"] = [item["test_cases"] for item in batch]
+        if "gt_answer" in batch[0]:
+            ret["gt_answer"] = [item["gt_answer"] for item in batch]
+        return ret
+    else:
+        raise ValueError("Unsupported batch format")
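
For reference, a minimal standalone sketch (not part of the commit; the helper name and sample contents are made up) of how the two batch layouts flow through the new collate path: pre-tokenized samples keep the old tensor-stacking behaviour, while raw-message samples produced with tokenize=False are passed through as chat messages for vLLM / agentic rollout.

import torch

# Mirrors the dispatch logic of the updated collate_fn_grpo; illustrative only.
def collate_like_grpo(batch):
    if "input_ids" in batch[0]:
        # pre-tokenized samples (tokenize=True, the previous behaviour)
        return {
            "input_ids": torch.stack([item["input_ids"] for item in batch]),
            "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
            "labels": torch.stack([item["labels"] for item in batch]),
        }
    elif "messages" in batch[0]:
        # raw chat samples (tokenize=False), forwarded untouched
        return {"messages": [item["messages"] for item in batch]}
    raise ValueError("Unsupported batch format")

raw_sample = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
    ]
}
print(collate_like_grpo([raw_sample]))  # {'messages': [[...system..., ...user...]]}
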

applications/ColossalChat/coati/distributed/agent/=0.3,

Whitespace-only changes.
Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
+import copy
+import json
+from typing import Any, Dict
+
+import ray
+import torch
+from coati.distributed.agent.agentic_math_utils import TIR_SYSTEM, CustomTransformers
+from coati.distributed.producer import BaseProducer
+from qwen_agent.agents import TIRMathAgent
+from vllm import SamplingParams
+
+
+@ray.remote
+class AgenticProducer(BaseProducer):
+    """
+    Asyncronous version of the producer that uses vLLM for generation.
+    This class is designed to generate agentic response
+    """
+
+    def __init__(
+        self,
+        producer_idx,
+        num_producers,
+        num_consumer_procs,
+        num_episodes,
+        batch_size,
+        train_dataset_config,
+        model_config,
+        generate_config,
+        async_producers,
+        tokenizer_config=None,
+        agentic_config=None,
+        microbatch_size=1,
+        backend="transformers",
+        num_generations: int = 8,
+        consumer_plugin_config=None,
+        eval_dataset_config=None,
+        eval_interval=-1,  # disable evaluation
+        grpo_config: Dict[str, Any] = None,
+        eval_save_dir: str = "./eval",
+        eval_generation_config={},
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
+    ):
+        assert microbatch_size == 1  # microbatch_size must be 1 for agentic producer
+        assert batch_size == 1  # batch_size must be 1 for agentic producer
+        super().__init__(
+            producer_idx,
+            num_producers,
+            num_consumer_procs,
+            num_episodes,
+            batch_size,
+            train_dataset_config,
+            model_config,
+            generate_config,
+            tokenizer_config,
+            microbatch_size,
+            backend,
+            consumer_plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            grpo_config=grpo_config,
+            eval_save_dir=eval_save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
+            enable_agentic=True,
+        )
+        self.eval_generation_config = copy.deepcopy(generate_config)
+        self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
+        self.eval_generation_config.update(eval_generation_config)
+        self.eval_sample_params = SamplingParams(**self.eval_generation_config)
+        self.async_producers = async_producers
+        self.num_generations = num_generations
+        self.generate_config = generate_config
+        self.agentic_config = model_config if not agentic_config else agentic_config
+        self.agentic_config.update({"model": model_config["path"]})
+        self.llm = CustomTransformers(self.agentic_config, self.producer_idx, generation_workers=self.async_producers)
+        self.bot = TIRMathAgent(llm=self.llm, name=model_config["path"], system_message=TIR_SYSTEM)
+
+    def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
+        """
+        Rollout function to generate responses for the input, for example, using LLM or agentic pipeline.
+        This function should be implemented in subclasses.
+        """
+        assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer"
+        messages = kwargs["messages"][0]
+        prompt_input_ids = self.tokenizer.apply_chat_template(
+            messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
+        )
+        # add left padding
+        prompt_length = prompt_input_ids.shape[1]
+        max_prompt_length = self.train_dataset_config["max_length"]
+        to_pad_left = max_prompt_length - prompt_length
+        rollouts = {
+            "input_ids": [],
+            "attention_mask": [],
+            "action_mask": [],
+            "action_log_probs": [],
+            "response_idx": [],
+        }
+        for i in range(self.num_generations):
+            _messages = copy.deepcopy(messages)
+            for response in self.bot.run(messages):
+                continue
+            _messages.extend(response)
+            response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True)
+            # truncate if too long
+            response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left]
+            # add left right padding
+            to_pad_right = self.grpo_config["max_length"] - response_input_ids.shape[1] - to_pad_left
+            response_length = response_input_ids.shape[1] - prompt_length
+            input_ids = torch.nn.functional.pad(
+                response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id
+            )  # [1, max_length]
+            attention_mask = torch.nn.functional.pad(
+                torch.ones_like(response_input_ids), (to_pad_left, to_pad_right), "constant", value=0
+            )  # [1, max_length]
+            action_mask = torch.nn.functional.pad(
+                torch.ones(size=(1, response_length)), (0, to_pad_right), "constant", value=0
+            )  # [1, max_length-prompt_length]
+            rollouts["attention_mask"].append(attention_mask)
+            rollouts["action_mask"].append(action_mask)
+            rollouts["action_log_probs"].append(
+                torch.ones(size=(1, self.grpo_config["max_length"] - max_prompt_length))
+            )  # dummy log probs
+            rollouts["response_idx"].append(
+                torch.tensor(
+                    [
+                        [
+                            self.train_dataset_config["max_length"],
+                            self.train_dataset_config["max_length"] + response_length,
+                        ]
+                    ]
+                )
+            )  # [1, 2]
+            rollouts["input_ids"].append(input_ids)
+        # breakpoint()
+        rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()}  # [num_generations, ...]
+        rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)])
+        if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
+            # for agentic producer, AsyncSimpleProducer is not the main producer, so we don't log rollouts
+            if (
+                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
+                or self.latest_rollout_log_step == -1
+            ):
+                new_record = (
+                    json.dumps(
+                        {
+                            "train_step": self.consumer_global_step,
+                            "rollout": self.tokenizer.batch_decode(
+                                rollouts["input_ids"][:, 0], skip_special_tokens=True
+                            ),
+                        }
+                    )
+                    + "\n"
+                )
+                self.rollout_log_file.write(new_record)
+                self.rollout_log_file.flush()
+                self.latest_rollout_log_step = self.consumer_global_step
+
+        if "gt_answer" in kwargs:
+            rollouts["gt_answer"] = kwargs["gt_answer"]
+        if "test_cases" in kwargs:
+            rollouts["test_cases"] = kwargs["test_cases"]
+        return rollouts
+
+    def sync_model(self, episode, step) -> None:
+        """
+        sync model from consumer to self.async_producers
+        AgenticProducer does not hold any model weights, so no need to sync model to self.async_producers
+        """
+        tasks = []
+        for proc in self.async_producers:
+            tasks.append(proc.async_sync_model.remote(episode, step, self.num_producers))
+        ray.get(tasks)
+        return
+
+    def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
+        """
+        sync data from self to consumer
+        """
+        tasks = []
+        for idx, proc in enumerate(self.async_producers):
+            if idx == self.producer_idx % len(self.async_producers):
+                tasks.append(proc.async_sync_data.remote(data, self.num_producers))
+            else:
+                tasks.append(proc.async_sync_data.remote({}, self.num_producers))
+        ray.get(tasks)
+        return
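
To make the left/right padding arithmetic in AgenticProducer.rollout easier to follow, here is a small self-contained sketch with made-up lengths; the real code reads these values from train_dataset_config["max_length"], grpo_config["max_length"], and the tokenizer, and none of the toy numbers below come from the commit.

import torch
import torch.nn.functional as F

# Toy stand-ins: prompt budget, total sequence budget, pad token id.
max_prompt_length = 8
max_length = 16
pad_token_id = 0

prompt_length = 5   # tokens in the chat prompt
full_length = 11    # prompt + agentic response tokens
response_input_ids = torch.arange(1, full_length + 1).unsqueeze(0)

to_pad_left = max_prompt_length - prompt_length                         # left-pad prompt to its budget
response_input_ids = response_input_ids[:, : max_length - to_pad_left]  # truncate if too long
to_pad_right = max_length - response_input_ids.shape[1] - to_pad_left   # right-pad to total budget
response_length = response_input_ids.shape[1] - prompt_length

input_ids = F.pad(response_input_ids, (to_pad_left, to_pad_right), value=pad_token_id)
attention_mask = F.pad(torch.ones_like(response_input_ids), (to_pad_left, to_pad_right), value=0)
action_mask = F.pad(torch.ones(1, response_length), (0, to_pad_right), value=0)

print(input_ids.shape, attention_mask.shape, action_mask.shape)
# torch.Size([1, 16]) torch.Size([1, 16]) torch.Size([1, 8])

Every generation thus ends up with the same total length, and action_mask covers only the response tokens, which is what lets the rollouts be stacked across num_generations.
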
