Skip to content

Commit e1a38e7

Browse files
committed
Merge branch 'grpo_optimization' of https://github.com/hpcaitech/ColossalAI into grpo_optimization
2 parents 8880b83 + db8baee commit e1a38e7

File tree

8 files changed

+317
-61
lines changed

8 files changed

+317
-61
lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 94 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import ray.util.collective as cc
77
import torch
88
import torch.distributed as dist
9+
from coati.distributed.profiling_utils import CustomProfiler
910
from tqdm import tqdm
1011
from transformers import AutoModelForCausalLM
1112

@@ -36,6 +37,8 @@ def __init__(
3637
minibatch_size: int = 1,
3738
save_interval: int = 100,
3839
save_dir: str = "./model",
40+
enable_profiling: bool = False,
41+
n_behind: int = 0,
3942
):
4043
self.num_producers = num_producers
4144
self.num_episodes = num_episodes
@@ -49,6 +52,7 @@ def __init__(
4952
self.minibatch_size = minibatch_size
5053
self.save_interval = save_interval
5154
self.save_dir = save_dir
55+
self.enable_profiling = enable_profiling
5256
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
5357
self.num_microbatches = batch_size // minibatch_size
5458

@@ -57,6 +61,7 @@ def __init__(
5761

5862
self.device = get_current_device()
5963
self.lr_scheduler = None
64+
self.n_behind = n_behind
6065

6166
def setup(self) -> None:
6267
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,13 +99,45 @@ def setup(self) -> None:
9499

95100
self.buffer = []
96101
self.recv_cnt = 0
102+
self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
97103

98104
def state_dict(self) -> Dict[str, torch.Tensor]:
99105
raise NotImplementedError
100106

101107
def step(self, step_idx: int, **kwargs) -> Optional[float]:
102108
raise NotImplementedError
103109

110+
def prepare_mini_batch(
    self, effective_group_to_raw_group_mapping: Dict[int, int]
) -> "Tuple[Dict[str, torch.Tensor], Dict[str, list]]":
    """
    Build the training mini-batch for this dp rank from the buffered groups.

    The first ``dp_size * minibatch_size`` effective groups in ``self.buffer`` form
    one global mini-batch; this rank takes its own ``minibatch_size`` slice of them.
    ``self.buffer`` is only read here; trimming the consumed entries is left to the caller.

    Args:
        effective_group_to_raw_group_mapping: maps the i-th effective group to its
            index in ``self.buffer``. It must contain at least
            ``dp_size * minibatch_size`` entries, otherwise a ``KeyError`` is raised.

    Returns:
        A ``(batch, raw_mini_batches_metric_dict)`` pair: ``batch`` is the bound and
        post-processed data of this rank's groups; the metric dict lists fields 1-4
        of every raw buffer entry up to and including the last effective group used.
    """
    # The return annotation is quoted so that it is never evaluated and therefore
    # does not require importing ``Tuple`` into this module.
    rank_start = self.dp_rank * self.minibatch_size
    rank_end = rank_start + self.minibatch_size
    batches = [self.buffer[effective_group_to_raw_group_mapping[i]] for i in range(rank_start, rank_end)]
    # every dp_rank will receive a complete mini-batch, no need to sync within step() later
    # each mini-batch uses the first self.dp_size * minibatch_size effective samples
    last_raw_idx = effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1]
    raw_mini_batches = self.buffer[: last_raw_idx + 1]  # include the last effective sample
    # metrics cover every raw entry in that prefix, not only the effective ones
    raw_mini_batches_metric_dict = {
        "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
        "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
        "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
        "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
    }
    batch = bind_batch([t[0] for t in batches])
    batch = post_recv(batch)
    return batch, raw_mini_batches_metric_dict
133+
134+
def calculate_effective_group_to_raw_group_mapping(self):
    """
    Map each effective group to its position in ``self.buffer``.

    A buffer entry counts as effective when its first field is not ``None``.
    Key ``k`` of the returned dict is the buffer index of the k-th effective entry.
    """
    effective_positions = (raw_idx for raw_idx, entry in enumerate(self.buffer) if entry[0] is not None)
    return dict(enumerate(effective_positions))
140+
104141
def loop(self) -> None:
105142
print(
106143
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
@@ -112,14 +149,49 @@ def loop(self) -> None:
112149
disable=self.rank != 0,
113150
) as pbar:
114151
for step in pbar:
152+
torch.cuda.reset_peak_memory_stats()
115153
i = 0
116154
for _ in range(self.num_recv_per_update):
155+
# after syncing the model, do not wait for more data to arrive (rollout takes time); train on the data already buffered
156+
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
157+
while len(effective_group_to_raw_group_mapping) > max(
158+
self.dp_size * self.batch_size
159+
- self.dp_size
160+
* self.minibatch_size
161+
* self.grpo_config.get("num_minibatch_during_rollout", 1),
162+
self.dp_size * self.minibatch_size,
163+
):
164+
self.profiler.log(
165+
f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
166+
)
167+
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
168+
effective_group_to_raw_group_mapping
169+
)
170+
self.profiler.enter("step")
171+
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
172+
self.profiler.exit("step")
173+
self.buffer = self.buffer[
174+
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
175+
]
176+
# recalculate the effective group to raw group mapping
177+
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
178+
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
179+
assert (
180+
len(effective_group_to_raw_group_mapping)
181+
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
182+
)
183+
if loss is not None:
184+
pbar.set_postfix({"loss": loss})
185+
i += 1
186+
117187
# receive data from producers
118188
for r in range(self.num_producers):
119189
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
190+
self.profiler.enter(f"recv_broadcast_data_P{r}")
120191
raw_batch = ray_broadcast_tensor_dict(
121192
None, src=0, device=self.device, group_name=f"sync_data_{r}"
122193
)
194+
self.profiler.exit(f"recv_broadcast_data_P{r}")
123195
# calculate group rewards and related metrics, then filter. As only the filtered groups will be used for training (which is incomplete),
124196
# we need to calculate the metrics before filtering here for logging
125197
# [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -161,49 +233,29 @@ def loop(self) -> None:
161233
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
162234
)
163235
# mapping the effective group to the raw group for indexing
164-
effective_group_to_raw_group_mapping = {}
165-
for buffer_idx in range(len(self.buffer)):
166-
if self.buffer[buffer_idx][0] is not None:
167-
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
168-
buffer_idx
169-
)
236+
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
170237
print(
171238
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
172239
)
173240

174-
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
241+
while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
242+
self.profiler.log(
243+
f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
244+
)
245+
# train only on the surplus above dp_size * batch_size effective samples; the rest stay buffered so training can continue during rollout after each model sync
175246
# on each dp_rank, we use minibatch_size effective samples to form a batch
176-
batches = [
177-
self.buffer[effective_group_to_raw_group_mapping[i]]
178-
for i in range(
179-
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
180-
)
181-
]
182-
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
183-
# each mini-batch use the first self.dp_size * minibatch_size effective samples
184-
raw_mini_batches = self.buffer[
185-
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
186-
] # include the last effective sample
187-
raw_mini_batches_metric_dict = {
188-
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
189-
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
190-
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
191-
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
192-
}
193-
batch = bind_batch([t[0] for t in batches])
194-
batch = post_recv(batch)
247+
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
248+
effective_group_to_raw_group_mapping
249+
)
250+
self.profiler.enter("step")
195251
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
252+
self.profiler.exit("step")
196253
self.buffer = self.buffer[
197254
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
198255
]
199256
# recalculate the effective group to raw group mapping
200257
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
201-
effective_group_to_raw_group_mapping = {}
202-
for buffer_idx in range(len(self.buffer)):
203-
if self.buffer[buffer_idx][0] is not None:
204-
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
205-
buffer_idx
206-
)
258+
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
207259
assert (
208260
len(effective_group_to_raw_group_mapping)
209261
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
@@ -221,13 +273,16 @@ def loop(self) -> None:
221273
if self.rank == 0:
222274
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
223275

224-
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
276+
if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
277+
episode != 0 or step >= self.n_behind
278+
):
225279
if self.pp_size > 1:
226280
print(
227281
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
228282
)
229283
else:
230284
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
285+
self.profiler.enter("sync_model")
231286
torch.cuda.empty_cache()
232287
state_dict = self.state_dict()
233288
if self.pp_size > 1:
@@ -245,6 +300,12 @@ def loop(self) -> None:
245300
)
246301
del state_dict
247302
torch.cuda.empty_cache()
303+
self.profiler.exit("sync_model")
304+
self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
305+
306+
def __del__(self):
    """Close the profiler on teardown, if one was ever attached to this instance."""
    # The profiler is created outside __init__, so it may not exist yet.
    if not hasattr(self, "profiler"):
        return
    self.profiler.close()
248309

249310

250311
@ray.remote

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def __init__(
3838
project_name: str = None,
3939
run_name: str = None,
4040
wandb_group_name: str = None,
41+
enable_profiling: bool = False,
42+
n_behind: int = 0,
4143
):
4244
print(f"Using GRPO config: {grpo_config}")
4345
if (
@@ -63,6 +65,8 @@ def __init__(
6365
minibatch_size,
6466
save_interval=save_interval,
6567
save_dir=save_dir,
68+
enable_profiling=enable_profiling,
69+
n_behind=n_behind,
6670
)
6771
path = model_config.pop("path")
6872
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def launch_distributed(
5757
eval_generation_config: Optional[Dict[str, Any]] = None,
5858
log_rollout_interval: int = 20,
5959
rollout_save_dir: str = "./rollout",
60+
enable_profiling: bool = False,
61+
n_behind: int = 0,
6062
):
6163
if core_algo not in ALGO_MAP:
6264
raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -132,6 +134,8 @@ def launch_distributed(
132134
wandb_group_name=wandb_group_name,
133135
log_rollout_interval=log_rollout_interval,
134136
rollout_log_file=rollout_log_file,
137+
enable_profiling=enable_profiling,
138+
n_behind=n_behind,
135139
)
136140
producer_procs.append(producer)
137141
ray.get([p.setup.remote() for p in producer_procs])
@@ -171,6 +175,8 @@ def launch_distributed(
171175
project_name=project_name,
172176
run_name=run_name,
173177
wandb_group_name=wandb_group_name,
178+
enable_profiling=enable_profiling,
179+
n_behind=n_behind,
174180
)
175181
consumer_procs.append(consumer)
176182
ray.get([p.setup.remote() for p in consumer_procs])

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import tqdm
1010
import wandb
1111
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
12+
from coati.distributed.profiling_utils import CustomProfiler
1213
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
1314
from coati.distributed.reward.verifiable_reward import VerifiableReward
1415
from ray.util.collective import allreduce
@@ -52,6 +53,8 @@ def __init__(
5253
wandb_group_name: str = None,
5354
log_rollout_interval: int = 20,
5455
rollout_log_file: str = "./rollout_log.jsonl",
56+
enable_profiling: bool = False,
57+
n_behind: int = 0,
5558
):
5659
self.producer_idx = producer_idx
5760
self.num_producers = num_producers
@@ -62,6 +65,7 @@ def __init__(
6265
assert batch_size % microbatch_size == 0
6366
self.num_microbatches = batch_size // microbatch_size
6467
self.latest_eval_step = -1
68+
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
6569

6670
self.train_dataset_config = train_dataset_config
6771
self.model_config = model_config
@@ -75,6 +79,7 @@ def __init__(
7579
self.log_rollout_interval = log_rollout_interval
7680
self.latest_rollout_log_step = -1
7781
self.grpo_config = grpo_config
82+
self.n_behind = n_behind
7883
reward_model_kwargs = {
7984
k: v
8085
for k, v in grpo_config.items()
@@ -268,11 +273,14 @@ def loop(self) -> None:
268273
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
269274
self.eval_mode = False
270275
self.latest_eval_step = self.consumer_global_step
276+
self.profiler.enter("rollout")
271277
outputs = self.rollout(**batch)
278+
self.profiler.exit("rollout")
272279
outputs["temperature"] = torch.tensor(
273280
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
274281
).to(outputs["input_ids"].device)
275282
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
283+
self.profiler.enter("calculate_reward")
276284
if self.grpo_config["reward_fn_type"] == "code":
277285
test_cases = []
278286
for prompt_id in range(bs):
@@ -310,20 +318,26 @@ def loop(self) -> None:
310318
outputs.pop("gt_answer")
311319
if "test_cases" in outputs:
312320
outputs.pop("test_cases")
321+
self.profiler.exit("calculate_reward")
313322

314323
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
315324
outputs = pre_send(outputs)
325+
self.profiler.enter("send_broadcast_data")
316326
ray_broadcast_tensor_dict(
317327
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
318328
)
319-
if (i + 1) % self.num_microbatches == 0 and (
320-
episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
329+
self.profiler.exit("send_broadcast_data")
330+
if (
331+
(i + 1) % self.num_microbatches == 0
332+
and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
333+
and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
321334
):
322335
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
323336
"enable_sleep_mode", False
324337
):
325338
self.model.llm.sleep()  # evict KV cache to avoid OOM
326339
# don't sync model for last iteration
340+
self.profiler.enter("sync_model")
327341
torch.cuda.empty_cache()
328342

329343
if self.consumer_pp_size > 1:
@@ -349,6 +363,7 @@ def loop(self) -> None:
349363
self.load_state_dict(state_dict)
350364
del state_dict
351365
torch.cuda.empty_cache()
366+
self.profiler.exit("sync_model")
352367
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
353368
"enable_sleep_mode", False
354369
):
@@ -364,6 +379,9 @@ def loop(self) -> None:
364379
"temperature"
365380
] + ratio * 0.9
366381

382+
def __del__(self):
    """Close the profiler on teardown, if one was created."""
    # __init__ can fail (e.g. on its batch-size assert) before the profiler is
    # assigned; the guard keeps teardown from raising a second AttributeError.
    if hasattr(self, "profiler"):
        self.profiler.close()
384+
367385

368386
@ray.remote
369387
class SimpleProducer(BaseProducer):
@@ -392,6 +410,8 @@ def __init__(
392410
wandb_group_name: str = None,
393411
log_rollout_interval: int = 20,
394412
rollout_log_file: str = "./rollout_log.jsonl",
413+
enable_profiling: bool = False,
414+
n_behind: int = 0,
395415
):
396416
super().__init__(
397417
producer_idx,
@@ -415,6 +435,8 @@ def __init__(
415435
wandb_group_name=wandb_group_name,
416436
log_rollout_interval=log_rollout_interval,
417437
rollout_log_file=rollout_log_file,
438+
enable_profiling=enable_profiling,
439+
n_behind=n_behind,
418440
)
419441
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
420442
self.eval_generation_config = copy.deepcopy(self.model.generate_config)

0 commit comments

Comments
 (0)