
Commit 7ee4452

fix vllm
1 parent 7795d4c commit 7ee4452

5 files changed: +172 additions, −24 deletions

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 145 additions & 3 deletions
@@ -1,8 +1,11 @@
+import json
+import os
 from contextlib import nullcontext
 from typing import Optional
 
 import ray
 import torch
+import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
@@ -33,6 +36,8 @@ def __init__(
         microbatch_size=1,
         num_generations=4,
         use_wandb=True,
+        generator_config=None,
+        filter_range=None,
     ):
         super().__init__(
             num_producers,
@@ -69,6 +74,9 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
+        self.filter_range = filter_range
+        if self.filter_range is not None:
+            assert len(self.filter_range) == 2, "Filter range should have 2 values."
 
         # Initialize verifiable reward.
         response_format_tags = {
@@ -84,7 +92,11 @@ def __init__(
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
+            if "repetition_penalty" in generator_config:
+                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}"
+            else:
+                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}"
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
 
     def setup(self):
         super().setup()
@@ -121,7 +133,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
             )
 
             with torch.no_grad():
@@ -130,7 +142,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                     attention_mask=data["attention_mask"],
                 )["logits"]
                 reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                    reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
                 )
 
             per_token_kl = (
@@ -149,7 +161,14 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
 
             # [batch_size, num_generations]
+            # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
+            loss_mask = (
+                None
+                if self.filter_range is None
+                else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1])
+            )
             group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)
 
             # [batch_size x num_generations]
             reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
@@ -164,6 +183,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
                 per_token_kl,
                 action_mask,
+                loss_mask=loss_mask,
             )
 
             if not skip_update:
@@ -232,3 +252,125 @@ def state_dict(self):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=4,
+        use_wandb=True,
+        log_dir="./results",
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.accum_count = torch.zeros(1, device=self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.log_dir = log_dir
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        else:
+            os.system(f"rm -rf {self.log_dir}/*")
+
+    def setup(self):
+        super().setup()
+        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        rank = dist.get_rank()
+        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+        kwargs["input_ids"].size(0)
+        reward_group = self.reward_model(
+            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+        )
+        reward = [value[0].item() for value in reward_group]
+        format_reward = [value[1].item() for value in reward_group]
+        acc_reward = [value[2].item() for value in reward_group]
+        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+            for i in range(len(response)):
+                f.write(
+                    json.dumps(
+                        {
+                            "response": response[i],
+                            "reward": reward[i],
+                            "format_reward": format_reward[i],
+                            "acc_reward": acc_reward[i],
+                            "response_length": response_length[i],
+                        },
+                        ensure_ascii=False,
+                    )
+                    + "\n"
+                )

+        self.accum_reward += sum(reward)
+        self.accum_format_reward += sum(format_reward)
+        self.accum_acc_reward += sum(acc_reward)
+        self.accum_response_length += sum(response_length)
+        self.accum_count += len(reward)
+
+        # print results
+        total_count = all_reduce_mean(self.accum_count, self.plugin)
+        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+        mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+        mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+        if rank == 0:
+            print(
+                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+            )
+        return None
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict
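
Note on the logits change above: dividing the policy and reference logits by the sampling temperature keeps the trainer's per-token log-probabilities consistent with the distribution the inference backend actually sampled from. Below is a minimal standalone sketch of that computation; it is not the repository's calc_action_log_probs, and the function name and tensor shapes are assumptions for illustration only.

import torch
import torch.nn.functional as F

def temperature_consistent_log_probs(logits: torch.Tensor, actions: torch.Tensor, temperature: float) -> torch.Tensor:
    # logits: [batch, seq_len, vocab]; actions: [batch, seq_len] sampled token ids.
    # Dividing by the sampling temperature before log_softmax reproduces the
    # distribution the generator sampled from (temperature=1.0 leaves it unchanged).
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    return log_probs.gather(dim=-1, index=actions.unsqueeze(-1)).squeeze(-1)

# Tiny usage example with random tensors:
logits = torch.randn(2, 4, 8)
actions = torch.randint(0, 8, (2, 4))
print(temperature_consistent_log_probs(logits, actions, temperature=0.7).shape)  # torch.Size([2, 4])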

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 9 additions & 2 deletions
@@ -183,7 +183,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(path, **model_config)
+        self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.generate_config = SamplingParams(**generate_config)
@@ -194,8 +194,15 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        micro_batch_input_ids = input_ids.tolist()
+        micro_batch_input_ids_no_padding = []
+        for i in range(micro_batch_size):
+            for j in range(input_ids.size(1)):
+                if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id:
+                    micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:])
+                    break
         outputs = self.llm.generate(
-            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
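
The loop added above strips left padding before handing prompts to vLLM, since llm.generate with prompt_token_ids expects raw, unpadded token id lists. A rough standalone sketch of the same idea follows; the helper name is an assumption, and fully padded rows are skipped, matching the loop in the diff.

from typing import List
import torch

def strip_left_padding(input_ids: torch.Tensor, pad_token_id: int) -> List[List[int]]:
    # input_ids: [batch, seq_len] left-padded prompt ids -> list of unpadded prompts.
    prompts = []
    for row in input_ids.tolist():
        for j, token in enumerate(row):
            if token != pad_token_id:
                prompts.append(row[j:])  # keep everything from the first real token onward
                break
    return prompts

batch = torch.tensor([[0, 0, 11, 12, 13], [0, 21, 22, 23, 24]])
print(strip_left_padding(batch, pad_token_id=0))  # [[11, 12, 13], [21, 22, 23, 24]]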

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 11 additions & 5 deletions
@@ -1,15 +1,13 @@
+import copy
 from typing import Any, Dict, Optional
 
 import ray
 
 from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer
+from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
@@ -80,6 +78,12 @@
            backend=inference_backend,
        )
        procs.append(producer)
+    generate_config_consumer = copy.deepcopy(generate_config)
+    generate_config_consumer.update(
+        dict(
+            backend=inference_backend,
+        )
+    )
    for i in range(num_consumer_procs):
        consumer = core_consumer.options(num_gpus=1).remote(
            num_producers=num_producers,
@@ -94,6 +98,8 @@
            model_config=train_model_config,
            plugin_config=plugin_config,
            microbatch_size=train_microbatch_size,
+            generate_config=generate_config_consumer,
+            filter_range=[0.05, 9.0],
        )
        procs.append(consumer)
    ray.get([p.setup.remote() for p in procs])
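
The deepcopy above gives the consumer its own copy of the generation config before the backend tag is added, so the config passed to the producers is left untouched. A small illustration of that behaviour; the values are only examples.

import copy

generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(dict(backend="vllm"))

print(generate_config)           # {'top_k': 50, 'top_p': 0.9, 'temperature': 0.7}
print(generate_config_consumer)  # same keys plus 'backend': 'vllm'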

applications/ColossalChat/coati/distributed/loss.py

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,7 @@ def forward(
         advantages: torch.Tensor,
         per_token_kl: torch.Tensor,
         action_mask: Optional[torch.Tensor] = None,
+        loss_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         skip = False
         if action_mask is None:
@@ -38,5 +39,7 @@ def forward(
             loss = masked_mean(loss, action_mask)
         else:
             loss = loss.mean(dim=1)
+        if loss_mask is not None:
+            loss = loss * loss_mask
         loss = loss.mean()
         return loss, skip, ratio.max()
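
loss_mask is applied per sample before the final mean, so sequences whose reward falls outside the configured filter range (e.g. the [0.05, 9.0] passed in launch.py) contribute zero loss. Below is a standalone sketch of that masking step under the assumption that per-sample losses and rewards are 1-D tensors of equal length; the names are illustrative, not the repository's API.

import torch

def masked_mean_loss(per_sample_loss: torch.Tensor, reward: torch.Tensor, filter_range=None) -> torch.Tensor:
    # Zero out samples whose reward is outside (low, high); the mean is still
    # taken over the full batch, mirroring how the loss above averages after masking.
    if filter_range is not None:
        low, high = filter_range
        loss_mask = torch.logical_and(reward > low, reward < high)
        per_sample_loss = per_sample_loss * loss_mask
    return per_sample_loss.mean()

loss = torch.tensor([0.5, 0.8, 0.3, 0.9])
reward = torch.tensor([0.0, 5.0, 10.0, 4.0])        # first and third fall outside [0.05, 9.0]
print(masked_mean_loss(loss, reward, [0.05, 9.0]))  # tensor(0.4250) = (0.8 + 0.9) / 4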

applications/ColossalChat/rl_example.py

Lines changed: 4 additions & 14 deletions
@@ -15,18 +15,14 @@
 parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
 parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
 parser.add_argument("-b", "--backend", type=str, default="transformers")
-parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
+parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
 args = parser.parse_args()
 
 ray.init(address="local", namespace="ray-example")
 
 inference_model_config = dict(path=args.model)
 train_model_config = dict(path=args.model)
-generate_config = dict(
-    top_k=50,
-    top_p=0.9,
-    temperature=1.0,
-)
+generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
 
 if args.backend == "transformers":
     inference_model_config.update(
@@ -52,19 +48,13 @@
         )
     )
 elif args.backend == "vllm":
-    inference_model_config.update(
-        dict(
-            gpu_memory_utilization=0.7,
-        )
-    )
+    inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
     generate_config.update(
         dict(
             max_tokens=2048,
             ignore_eos=True,
             include_stop_str_in_output=True,
             stop=["</answer>"],
-            temperature=0.7,
-            top_p=0.95,
         )
     )
 else:
@@ -97,6 +87,6 @@
     plugin_config={},
     inference_backend=args.backend,
     master_addr="localhost",
-    master_port=29504,
+    master_port=29503,
     core_algo=args.algo,
 )