
Commit 594c2c6

YeAnbang and Tong Li committed
[feat] Support one-behind to reduce bubble time. Add profiling code (#6353)
* support n_behind, add profiling
* fix bugs
* fix visualization
* fix behind
* fix loop issue
* add profiling
* fix update
* update assert
* remove assert

Co-authored-by: Tong Li <[email protected]>
1 parent 685e0bd commit 594c2c6
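
Note on the mechanism: the consumer change below tags every buffered rollout group with the step at which it was received, and when n_behind > 0 only groups that are at least n_behind steps old are used for training, so the optimizer can work on slightly stale data while the next rollout is still in flight. A minimal standalone sketch of that staleness gate, mirroring calculate_effective_group_to_raw_group_mapping in consumer.py (the toy buffer entry layout is an illustrative assumption, not the exact production format):

# Each toy buffer entry is (batch, reward, format_acc, ans_acc, response_len, recv_step);
# recv_step is the step at which the group was received (the field this commit appends).
def eligible_groups(buffer, current_step, n_behind=0):
    """Map effective-group index -> raw buffer index, mirroring
    calculate_effective_group_to_raw_group_mapping in the diff below."""
    mapping = {}
    for buffer_idx, entry in enumerate(buffer):
        if entry[0] is None:  # group filtered out, kept only for metrics
            continue
        # n_behind == 0: train on every unfiltered group as soon as it arrives.
        # n_behind > 0: only train on groups at least n_behind steps old.
        if n_behind == 0 or entry[-1] <= current_step - n_behind:
            mapping[len(mapping)] = buffer_idx
    return mapping

# At step 3 with n_behind=1, only groups received at step <= 2 are eligible.
toy_buffer = [("g0", 0.1, 1.0, 1.0, 128, 2), (None, 0.0, 0.0, 0.0, 64, 2), ("g2", 0.3, 1.0, 0.5, 96, 3)]
assert eligible_groups(toy_buffer, current_step=3, n_behind=1) == {0: 0}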

File tree

8 files changed (+365, -82 lines)


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 125 additions & 46 deletions
@@ -6,6 +6,7 @@
 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.profiling_utils import CustomProfiler
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 
@@ -36,6 +37,8 @@ def __init__(
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -49,6 +52,7 @@ def __init__(
         self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
+        self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
 
@@ -57,6 +61,7 @@ def __init__(
 
         self.device = get_current_device()
         self.lr_scheduler = None
+        self.n_behind = n_behind
 
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,13 +99,49 @@ def setup(self) -> None:
 
         self.buffer = []
         self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
 
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         raise NotImplementedError
 
+    def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]:
+        """
+        Prepare a mini-batch from the effective group to raw group mapping.
+        This method is used to create a mini-batch for training.
+        """
+        batches = [
+            self.buffer[effective_group_to_raw_group_mapping[i]]
+            for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
+        ]
+        # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+        # each mini-batch use the first self.dp_size * minibatch_size effective samples
+        raw_mini_batches = self.buffer[
+            : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+        ]  # include the last effective sample
+        raw_mini_batches_metric_dict = {
+            "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+            "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+            "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+            "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+        }
+        batch = bind_batch([t[0] for t in batches])
+        batch = post_recv(batch)
+        return batch, raw_mini_batches_metric_dict
+
+    def calculate_effective_group_to_raw_group_mapping(self, step):
+        effective_group_to_raw_group_mapping = {}
+        for buffer_idx in range(len(self.buffer)):
+            if self.buffer[buffer_idx][0] is not None:
+                if self.n_behind == 0:
+                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+                else:
+                    if self.buffer[buffer_idx][-1] <= step - self.n_behind:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+        return effective_group_to_raw_group_mapping
+
     def loop(self) -> None:
         print(
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
@@ -112,14 +153,53 @@ def loop(self) -> None:
                 disable=self.rank != 0,
             ) as pbar:
                 for step in pbar:
+                    torch.cuda.reset_peak_memory_stats()
                     i = 0
+
+                    self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
                     for _ in range(self.num_recv_per_update):
+                        if self.n_behind > 0:
+                            # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                                step=step
+                            )
+                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                                self.profiler.log(
+                                    f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                                )
+                                batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                self.profiler.enter("step")
+                                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                                self.profiler.exit("step")
+                                self.buffer = self.buffer[
+                                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                                ]
+                                # recalculate the effective group to raw group mapping
+                                effective_group_to_raw_group_mapping_size_before = len(
+                                    effective_group_to_raw_group_mapping
+                                )
+                                effective_group_to_raw_group_mapping = (
+                                    self.calculate_effective_group_to_raw_group_mapping(step=step)
+                                )
+                                assert (
+                                    len(effective_group_to_raw_group_mapping)
+                                    == effective_group_to_raw_group_mapping_size_before
+                                    - self.dp_size * self.minibatch_size
+                                )
+                                if loss is not None:
+                                    pbar.set_postfix({"loss": loss})
+                                i += 1
+
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                            self.profiler.enter(f"recv_broadcast_data_P{r}")
                             raw_batch = ray_broadcast_tensor_dict(
                                 None, src=0, device=self.device, group_name=f"sync_data_{r}"
                             )
+                            self.profiler.exit(f"recv_broadcast_data_P{r}")
                             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                             # we need to calculate the metrics before filtering here for logging
                             # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -153,63 +233,52 @@ def loop(self) -> None:
                                         format_acc[group_idx],
                                         ans_acc[group_idx],
                                         response_len[group_idx],
+                                        step,
                                     ]
                                 )
                             if effective_group_mask is not None:
                                 print(
                                     f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                                 )
                             # mapping the effective group to the raw group for indexing
-                            effective_group_to_raw_group_mapping = {}
-                            for buffer_idx in range(len(self.buffer)):
-                                if self.buffer[buffer_idx][0] is not None:
-                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                        buffer_idx
-                                    )
+                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                                step=step
+                            )
                             print(
                                 f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
                             )
 
-                            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
-                                # on each dp_rank, we use minibatch_size effective samples to form a batch
-                                batches = [
-                                    self.buffer[effective_group_to_raw_group_mapping[i]]
-                                    for i in range(
-                                        self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                            if self.n_behind == 0:
+                                # If n_behind is 0, we start training after receiving data from producers.
+                                while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                                    self.profiler.log(
+                                        f"Collect {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
                                     )
-                                ]
-                                # every dp_rank will receive a complete mini-batch, no need to sync within step() later
-                                # each mini-batch use the first self.dp_size * minibatch_size effective samples
-                                raw_mini_batches = self.buffer[
-                                    : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
-                                ]  # include the last effective sample
-                                raw_mini_batches_metric_dict = {
-                                    "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
-                                    "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
-                                    "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
-                                    "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
-                                }
-                                batch = bind_batch([t[0] for t in batches])
-                                batch = post_recv(batch)
-                                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
-                                self.buffer = self.buffer[
-                                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
-                                ]
-                                # recalculate the effective group to raw group mapping
-                                effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                                effective_group_to_raw_group_mapping = {}
-                                for buffer_idx in range(len(self.buffer)):
-                                    if self.buffer[buffer_idx][0] is not None:
-                                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
-                                            buffer_idx
-                                        )
-                                assert (
-                                    len(effective_group_to_raw_group_mapping)
-                                    == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
-                                )
-                                if loss is not None:
-                                    pbar.set_postfix({"loss": loss})
-                                i += 1
+                                    batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                        effective_group_to_raw_group_mapping
+                                    )
+                                    self.profiler.enter("step")
+                                    loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                                    self.profiler.exit("step")
+                                    self.buffer = self.buffer[
+                                        effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                                    ]
+                                    # recalculate the effective group to raw group mapping
+                                    effective_group_to_raw_group_mapping_size_before = len(
+                                        effective_group_to_raw_group_mapping
+                                    )
+                                    effective_group_to_raw_group_mapping = (
+                                        self.calculate_effective_group_to_raw_group_mapping(step=step)
+                                    )
+                                    assert (
+                                        len(effective_group_to_raw_group_mapping)
+                                        == effective_group_to_raw_group_mapping_size_before
+                                        - self.dp_size * self.minibatch_size
+                                    )
+                                    if loss is not None:
+                                        pbar.set_postfix({"loss": loss})
+                                    i += 1
+
                     if self.lr_scheduler is not None:
                         self.lr_scheduler.step()
                     if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
@@ -220,13 +289,16 @@ def loop(self) -> None:
                         if self.rank == 0:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
-                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                    if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
+                        episode != 0 or step >= self.n_behind
+                    ):
                         if self.pp_size > 1:
                             print(
                                 f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                             )
                         else:
                             print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        self.profiler.enter("sync_model")
                         torch.cuda.empty_cache()
                         state_dict = self.state_dict()
                         if self.pp_size > 1:
@@ -244,6 +316,13 @@ def loop(self) -> None:
                                 )
                         del state_dict
                        torch.cuda.empty_cache()
+                        self.profiler.exit("sync_model")
+                        self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+                    self.profiler.exit(f"rollout_episode_{episode}_step_{step}")
+
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()
 
 
 @ray.remote
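
The enter/exit/log calls above, and the close() in __del__, are the only parts of coati.distributed.profiling_utils.CustomProfiler this file touches; the profiler itself is added elsewhere in the commit and is not shown in this view. As a rough mental model only, a stand-in with the same interface might look like the sketch below (this is an assumption for illustration, not the actual CustomProfiler implementation; the log file name is made up):

import time

class StandInProfiler:
    """Hypothetical stand-in matching the interface the consumer uses:
    enter(name), exit(name), log(msg), close(). Not the real CustomProfiler."""

    def __init__(self, name, disabled=True):
        self.name = name
        self.disabled = disabled
        self.file = None if disabled else open(f"{name}.prof.log", "w")
        self.start_times = {}

    def log(self, msg):
        # Append a timestamped line to the per-rank log file when profiling is on.
        if not self.disabled:
            self.file.write(f"[{time.time():.3f}] {self.name}: {msg}\n")

    def enter(self, span):
        # Record the start time of a named span (e.g. "step", "sync_model").
        if not self.disabled:
            self.start_times[span] = time.time()
            self.log(f"enter {span}")

    def exit(self, span):
        # Log the elapsed time for a span opened with enter().
        if not self.disabled and span in self.start_times:
            elapsed = time.time() - self.start_times.pop(span)
            self.log(f"exit {span} ({elapsed:.3f}s)")

    def close(self):
        if self.file is not None:
            self.file.close()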

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 4 additions & 0 deletions
@@ -38,6 +38,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -63,6 +65,8 @@ def __init__(
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 6 additions & 0 deletions
@@ -57,6 +57,8 @@ def launch_distributed(
     eval_generation_config: Optional[Dict[str, Any]] = None,
     log_rollout_interval: int = 20,
     rollout_save_dir: str = "./rollout",
+    enable_profiling: bool = False,
+    n_behind: int = 0,
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -132,6 +134,8 @@ def launch_distributed(
             wandb_group_name=wandb_group_name,
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         producer_procs.append(producer)
     ray.get([p.setup.remote() for p in producer_procs])
@@ -171,6 +175,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         consumer_procs.append(consumer)
     ray.get([p.setup.remote() for p in consumer_procs])
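
One behavioral detail worth calling out from the consumer.py hunk above: when n_behind > 0, model weights are not synced to producers during the first n_behind steps of episode 0, since no buffered group is old enough to train on yet and the weights are still the initial ones. A small sketch of that gate as a pure function (the function name is illustrative, not from the codebase):

def should_sync_model(episode, step, num_episodes, num_update_per_episode, n_behind):
    # Mirrors the condition in consumer.py: skip the very last step overall,
    # and skip the warm-up steps of episode 0 when running n_behind steps behind.
    not_last_step = episode != num_episodes - 1 or step != num_update_per_episode - 1
    past_warmup = episode != 0 or step >= n_behind
    return not_last_step and past_warmup

# With n_behind=1: no sync at (episode 0, step 0); syncing resumes from step 1 onward.
assert not should_sync_model(0, 0, num_episodes=2, num_update_per_episode=10, n_behind=1)
assert should_sync_model(0, 1, num_episodes=2, num_update_per_episode=10, n_behind=1)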
