Commit ff6696a

support n_behind, add profiling
1 parent e3d56cb commit ff6696a

8 files changed (+233 / -29 lines)

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 21 additions & 1 deletion
@@ -6,6 +6,7 @@
 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.profiling_utils import CustomProfiler
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 
@@ -36,6 +37,8 @@ def __init__(
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -49,6 +52,7 @@ def __init__(
         self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
+        self.enable_profiling = enable_profiling
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
         self.num_microbatches = batch_size // minibatch_size
 
@@ -57,6 +61,7 @@ def __init__(
 
         self.device = get_current_device()
         self.lr_scheduler = None
+        self.n_behind = n_behind
 
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,6 +99,7 @@ def setup(self) -> None:
 
         self.buffer = []
         self.recv_cnt = 0
+        self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -112,14 +118,17 @@ def loop(self) -> None:
             disable=self.rank != 0,
         ) as pbar:
             for step in pbar:
+                torch.cuda.reset_peak_memory_stats()
                 i = 0
                 for _ in range(self.num_recv_per_update):
                     # receive data from producers
                     for r in range(self.num_producers):
                         print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                        self.profiler.enter(f"recv_broadcast_data_P{r}")
                         raw_batch = ray_broadcast_tensor_dict(
                             None, src=0, device=self.device, group_name=f"sync_data_{r}"
                         )
+                        self.profiler.exit(f"recv_broadcast_data_P{r}")
                         # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                         # we need to calculate the metrics before filtering here for logging
                         # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -192,7 +201,9 @@ def loop(self) -> None:
                         }
                         batch = bind_batch([t[0] for t in batches])
                         batch = post_recv(batch)
+                        self.profiler.enter("step")
                         loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                        self.profiler.exit("step")
                         self.buffer = self.buffer[
                             effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
                         ]
@@ -221,13 +232,16 @@ def loop(self) -> None:
                     if self.rank == 0:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
-                if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
+                    episode != 0 or step >= self.n_behind
+                ):
                     if self.pp_size > 1:
                         print(
                             f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                         )
                     else:
                         print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    self.profiler.enter("sync_model")
                     torch.cuda.empty_cache()
                     state_dict = self.state_dict()
                     if self.pp_size > 1:
@@ -245,6 +259,12 @@ def loop(self) -> None:
                         )
                     del state_dict
                     torch.cuda.empty_cache()
+                    self.profiler.exit("sync_model")
+                self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+
+    def __del__(self):
+        if hasattr(self, "profiler"):
+            self.profiler.close()
 
 
 @ray.remote
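The substantive change here, beyond the profiler instrumentation, is the extra clause on the model-sync condition: during episode 0 the consumer now skips weight sync for the first `n_behind` steps, letting the producer stay `n_behind` rollout batches ahead of the trainer instead of idling while waiting for fresh weights. A minimal standalone sketch of which steps end up syncing — the sizes are hypothetical, only the condition comes from this diff:

```python
# Hypothetical sizes for illustration; only the gate logic is from the commit.
num_episodes, num_update_per_episode, n_behind = 2, 4, 1

for episode in range(num_episodes):
    for step in range(num_update_per_episode):
        not_last = episode != num_episodes - 1 or step != num_update_per_episode - 1
        past_warmup = episode != 0 or step >= n_behind
        if not_last and past_warmup:
            print(f"sync after episode={episode} step={step}")
# episode=0 step=0 is skipped (the producer is still building its n_behind
# head start), and the very last step is skipped as before.
```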

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 4 additions & 0 deletions
@@ -38,6 +38,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -63,6 +65,8 @@ def __init__(
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 6 additions & 0 deletions
@@ -57,6 +57,8 @@ def launch_distributed(
     eval_generation_config: Optional[Dict[str, Any]] = None,
     log_rollout_interval: int = 20,
     rollout_save_dir: str = "./rollout",
+    enable_profiling: bool = False,
+    n_behind: int = 0,
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -132,6 +134,8 @@ def launch_distributed(
             wandb_group_name=wandb_group_name,
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         producer_procs.append(producer)
     ray.get([p.setup.remote() for p in producer_procs])
@@ -171,6 +175,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         consumer_procs.append(consumer)
     ray.get([p.setup.remote() for p in consumer_procs])

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 24 additions & 2 deletions
@@ -9,6 +9,7 @@
 import tqdm
 import wandb
 from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
+from coati.distributed.profiling_utils import CustomProfiler
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from ray.util.collective import allreduce
@@ -52,6 +53,8 @@ def __init__(
         wandb_group_name: str = None,
         log_rollout_interval: int = 20,
         rollout_log_file: str = "./rollout_log.jsonl",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -62,6 +65,7 @@ def __init__(
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
         self.latest_eval_step = -1
+        self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
 
         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -75,6 +79,7 @@ def __init__(
         self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
         self.grpo_config = grpo_config
+        self.n_behind = n_behind
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
@@ -268,11 +273,14 @@ def loop(self) -> None:
                         self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                     self.eval_mode = False
                     self.latest_eval_step = self.consumer_global_step
+                self.profiler.enter("rollout")
                 outputs = self.rollout(**batch)
+                self.profiler.exit("rollout")
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
+                self.profiler.enter("calculate_reward")
                 if self.grpo_config["reward_fn_type"] == "code":
                     test_cases = []
                     for prompt_id in range(bs):
@@ -310,20 +318,26 @@ def loop(self) -> None:
                     outputs.pop("gt_answer")
                 if "test_cases" in outputs:
                     outputs.pop("test_cases")
+                self.profiler.exit("calculate_reward")
 
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs = pre_send(outputs)
+                self.profiler.enter("send_broadcast_data")
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )
-                if (i + 1) % self.num_microbatches == 0 and (
-                    episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
+                self.profiler.exit("send_broadcast_data")
+                if (
+                    (i + 1) % self.num_microbatches == 0
+                    and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
+                    and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
                 ):
                     if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                         "enable_sleep_mode", False
                     ):
                         self.model.llm.sleep()  # evict KV cache to avoid OOM
                     # don't sync model for last iteration
+                    self.profiler.enter("sync_model")
                     torch.cuda.empty_cache()
 
                     if self.consumer_pp_size > 1:
@@ -349,6 +363,7 @@ def loop(self) -> None:
                         self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
+                    self.profiler.exit("sync_model")
                     if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                         "enable_sleep_mode", False
                     ):
@@ -364,6 +379,9 @@ def loop(self) -> None:
                             "temperature"
                         ] + ratio * 0.9
 
+    def __del__(self):
+        self.profiler.close()
+
 
 @ray.remote
 class SimpleProducer(BaseProducer):
@@ -392,6 +410,8 @@ def __init__(
         wandb_group_name: str = None,
         log_rollout_interval: int = 20,
         rollout_log_file: str = "./rollout_log.jsonl",
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         super().__init__(
             producer_idx,
@@ -415,6 +435,8 @@ def __init__(
             wandb_group_name=wandb_group_name,
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
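The producer applies the mirrored gate: in episode 0 it pauses to pull fresh weights only once it has already sent `n_behind` full batches, i.e. when `(i + 1) > self.n_behind * self.num_microbatches`. A standalone sketch with hypothetical sizes (only the condition itself comes from the diff):

```python
# Which producer microbatch iterations trigger a weight sync in episode 0.
num_microbatches = 4          # batch_size // microbatch_size
num_valid_microbatches = 12   # hypothetical
n_behind, episode, num_episodes = 1, 0, 2

for i in range(num_valid_microbatches):
    batch_done = (i + 1) % num_microbatches == 0
    not_last = episode != num_episodes - 1 or i != num_valid_microbatches - 1
    past_head_start = episode != 0 or (i + 1) > n_behind * num_microbatches
    if batch_done and not_last and past_head_start:
        print(f"sync model after microbatch {i}")
# Prints i=7 and i=11 but not i=3: the first batch is generated with the
# initial weights, so generation overlaps with the consumer's training step
# and the trainer stays one batch behind the generator.
```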
applications/ColossalChat/coati/distributed/profiling_utils.py

Lines changed: 33 additions & 0 deletions

+import os
+import time
+
+
+class CustomProfiler:
+    def __init__(self, name, disabled=True):
+        self.disabled = disabled
+        if not disabled:
+            self.name = name
+            self.pid = os.getpid()
+            self.file = open(f"{name}.prof", "w")
+
+    def _log(self, message):
+        if self.disabled:
+            return
+        current_time = time.time()
+        self.file.write(f"{current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
+    def log(self, message):
+        if self.disabled:
+            return
+        current_time = time.time()
+        self.file.write(f"[Log]: {current_time} {self.name} {self.pid}:: {message}\n")
+        self.file.flush()
+
+    def enter(self, event_name):
+        self._log(f"Enter {event_name}")
+
+    def exit(self, event_name):
+        self._log(f"Exit {event_name}")
+
+    def close(self):
+        if self.disabled:
+            return
+        self.file.close()
+        print(f"Profiler data written to {self.name}.prof")
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+
+# 8K context length
+rm -rf *.prof
+MAX_NEW_TOKENS=$((8192-512))
+python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1 | tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
+python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
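`profile_grpo.py` itself is not part of this diff, so the following is only a guess at what the visualization step consumes: a hypothetical minimal parser that folds the `.prof` files into per-actor event spans, based on the line format `CustomProfiler` writes (`<unix_time> <actor> <pid>:: Enter|Exit <event>`).

```python
# Hypothetical sketch; the real profile_grpo.py is not shown in this commit.
import glob
from collections import defaultdict

def parse_prof(path):
    spans, open_events = [], {}
    with open(path) as f:
        for line in f:
            if line.startswith("[Log]:"):
                continue  # free-form log lines (e.g. peak memory), no duration
            head, sep, msg = line.partition("::")
            if not sep:
                continue
            ts, actor, _pid = head.split()
            action, _, event = msg.strip().partition(" ")
            if action == "Enter":
                open_events[event] = float(ts)
            elif action == "Exit" and event in open_events:
                spans.append((actor, event, open_events.pop(event), float(ts)))
    return spans

timelines = defaultdict(list)  # actor -> [(event, start, duration), ...]
for path in glob.glob("*.prof"):
    for actor, event, start, end in parse_prof(path):
        timelines[actor].append((event, start, end - start))
```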

applications/ColossalChat/rl_example.py

Lines changed: 38 additions & 26 deletions
@@ -67,6 +67,27 @@
         default=2,
         help="Effective batch size per dp group for forwarding and backwarding. Please select based on the available memory.",
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-pp",
+        "--pipeline-parallel-size",
+        type=int,
+        default=1,
+        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-zero",
+        "--zero-stage",
+        type=int,
+        default=0,
+        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
     parser.add_argument(
         "--ray_dir", type=str, default=None, help="Custom temporary directory for storing ray cluster data. Optional."
     )
@@ -97,6 +118,13 @@
     parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
     parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
     parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
+    parser.add_argument(
+        "-ptp",
+        "--producer-tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+    )
 
     # GRPO parameters
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
@@ -117,6 +145,13 @@
         default=100,
         help="Interval for evaluation. Evaluate every ei training steps.",
     )
+    parser.add_argument(
+        "-nb",
+        "--n-behind",
+        type=int,
+        default=0,
+        help="Number of batches the producer rolls out ahead of the trainer, to fill the data buffer and reduce bubble time.",
+    )
 
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -128,32 +163,7 @@
         "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
     )
     parser.add_argument(
-        "-tp",
-        "--tensor-parallel-size",
-        type=int,
-        default=1,
-        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-pp",
-        "--pipeline-parallel-size",
-        type=int,
-        default=1,
-        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-zero",
-        "--zero-stage",
-        type=int,
-        default=0,
-        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
-    )
-    parser.add_argument(
-        "-ptp",
-        "--producer-tensor-parallel-size",
-        type=int,
-        default=1,
-        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+        "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
     )
     args = parser.parse_args()
 
@@ -353,4 +363,6 @@
         eval_generation_config=eval_generation_config,
         log_rollout_interval=20,
         rollout_save_dir=args.rollout_save_dir,
+        enable_profiling=args.enable_profiling,
+        n_behind=args.n_behind,
     )
