
Commit 880d886

make sync model async

1 parent e1a38e7 · commit 880d886

File tree: 6 files changed (+215, -74 lines)

applications/ColossalChat/coati/distributed/comm.py

Lines changed: 62 additions & 3 deletions
@@ -1,5 +1,7 @@
+import copy
 from typing import Any, Dict

+import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed.distributed_c10d as c10d
@@ -32,7 +34,11 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =


 def ray_broadcast_tensor_dict(
-    tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
+    tensor_dict: Dict[str, torch.Tensor],
+    src: int = 0,
+    device=None,
+    group_name: str = "default",
+    offload_to_cpu: bool = False,
 ) -> Dict[str, torch.Tensor]:
     rank = cc.get_rank(group_name)
     if rank == src:
@@ -46,12 +52,65 @@ def ray_broadcast_tensor_dict(
     out_dict = {}
     for k, shape, dtype in metadata:
         if rank == src:
-            tensor = tensor_dict[k]
+            if offload_to_cpu:
+                tensor = tensor_dict[k].to(device)
+            else:
+                tensor = tensor_dict[k]
         else:
             tensor = torch.empty(shape, dtype=dtype, device=device)
         cc.broadcast(tensor, src, group_name)
         if rank != src:
-            out_dict[k] = tensor
+            if offload_to_cpu:
+                out_dict[k] = tensor.cpu()
+            else:
+                out_dict[k] = tensor
     if rank == src:
         out_dict = tensor_dict
     return out_dict
+
+
+@ray.remote
+class SharedVariableActor:
+    def __init__(self):
+        # double queues
+        self.data_queue = None
+        self.data_queue_buffered = None
+        self.model_weights = None
+        self.data_access_count = 0
+        self.ready_process_count = {}
+
+    def increase_ready_process_count(self, name):
+        self.ready_process_count = {k: v for k, v in self.ready_process_count.items() if k > name - 5}
+        if name not in self.ready_process_count:
+            self.ready_process_count[name] = 0
+        self.ready_process_count[name] += 1
+
+    def get_ready_process_count(self, name):
+        return self.ready_process_count[name]
+
+    def extend_data(self, data):
+        if self.data_access_count > 0:
+            # update the buffered data if data is not being accessed by all consumers
+            # if producer are too fast, will not overwrite the data but extend the data
+            if self.data_queue_buffered is None:
+                self.data_queue_buffered = []
+            self.data_queue_buffered.extend(data)
+            return True
+        if self.data_queue is None:
+            self.data_queue = []
+        self.data_queue.extend(data)
+        self.data_access_count = 0
+        return True
+
+    def get_data(self):
+        if self.data_queue is None:
+            return None
+        data = copy.deepcopy(self.data_queue)
+        self.data_access_count += 1
+        if self.data_access_count == 4:
+            # data in data_queue has been accessed by all consumers
+            # swap the data queue with the buffered data, erase the old data
+            if self.data_queue_buffered is not None:
+                self.data_queue = self.data_queue_buffered
+                self.data_queue_buffered = None
+        return data
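
The SharedVariableActor added above is a small coordination actor: producers append rollout data with extend_data, consumers pull copies with get_data, and the double queue keeps a fresh batch from overwriting one that is still being read. Below is a minimal usage sketch, not taken from this commit: it assumes a local Ray runtime, toy dictionary payloads, and four readers to match the hard-coded `data_access_count == 4` check in get_data.

    import ray

    from coati.distributed.comm import SharedVariableActor

    ray.init()
    actor = SharedVariableActor.remote()

    # A producer pushes a list of samples. If the current queue is still being
    # read (data_access_count > 0), extend_data parks new samples in the
    # buffered queue instead of touching the one consumers are reading.
    ray.get(actor.extend_data.remote([{"prompt_id": 0}, {"prompt_id": 1}]))

    # Each consumer takes a deep copy of the whole queue; after the fourth read
    # get_data swaps in the buffered queue, if one has accumulated.
    for consumer_rank in range(4):
        batch = ray.get(actor.get_data.remote())
        print(f"consumer {consumer_rank} got {len(batch)} samples")

    # Ready-count handshake used before a model broadcast: every participant
    # bumps the counter for the current sync step, and the broadcaster waits
    # until it reaches num_producers + 1.
    ray.get(actor.increase_ready_process_count.remote(name=0))
    print(ray.get(actor.get_ready_process_count.remote(name=0)))  # 1

    ray.shutdown()

The offload_to_cpu flag on ray_broadcast_tensor_dict is independent of the actor: on the sender it moves each (CPU-resident) tensor onto `device` before cc.broadcast, and on receivers it moves the result back with .cpu(), so large state dicts do not have to stay resident on the GPU between syncs.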

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 64 additions & 50 deletions
@@ -1,4 +1,6 @@
 import os
+import threading
+import time
 from contextlib import nullcontext
 from typing import Any, Dict, Optional

@@ -16,13 +18,15 @@
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device

-from .comm import ray_broadcast_tensor_dict
+from .comm import SharedVariableActor, ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch


 class BaseConsumer:
     def __init__(
         self,
+        shared_sync_data_actor: SharedVariableActor,
+        shared_sync_model_actor: SharedVariableActor,
         num_producers: int,
         num_episodes: int,
         rank: int,
@@ -63,6 +67,13 @@ def __init__(
         self.lr_scheduler = None
         self.n_behind = n_behind

+        # for running sync data and model in separate actors/threads
+        self.shared_sync_data_actor = shared_sync_data_actor
+        self.shared_sync_model_actor = shared_sync_model_actor
+        self.thread_started = False
+        self.model_sync_step = 0
+        self.state_dict_cpu = {}
+
     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

@@ -85,6 +96,7 @@ def setup(self) -> None:
         self.pp_size = dist.get_world_size(self.plugin.pp_group)

         # Init Hybrid ray process group
+        cc.init_collective_group(self.world_size, self.rank, group_name="consumer_pg")
         for i in range(self.num_producers):
             cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
         if self.pp_size > 1:
@@ -152,44 +164,12 @@ def loop(self) -> None:
                    torch.cuda.reset_peak_memory_stats()
                    i = 0
                    for _ in range(self.num_recv_per_update):
-                        # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
-                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
-                        while len(effective_group_to_raw_group_mapping) > max(
-                            self.dp_size * self.batch_size
-                            - self.dp_size
-                            * self.minibatch_size
-                            * self.grpo_config.get("num_minibatch_during_rollout", 1),
-                            self.dp_size * self.minibatch_size,
-                        ):
-                            self.profiler.log(
-                                f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
-                            )
-                            batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
-                                effective_group_to_raw_group_mapping
-                            )
-                            self.profiler.enter("step")
-                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
-                            self.profiler.exit("step")
-                            self.buffer = self.buffer[
-                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
-                            ]
-                            # recalculate the effective group to raw group mapping
-                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
-                            assert (
-                                len(effective_group_to_raw_group_mapping)
-                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
-                            )
-                            if loss is not None:
-                                pbar.set_postfix({"loss": loss})
-                            i += 1
-
                        # receive data from producers
                        for r in range(self.num_producers):
                            print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                            self.profiler.enter(f"recv_broadcast_data_P{r}")
                            raw_batch = ray_broadcast_tensor_dict(
-                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
+                                None, src=0, device=self.device, group_name=f"sync_data_{r}", offload_to_cpu=False
                            )
                            self.profiler.exit(f"recv_broadcast_data_P{r}")
                            # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
@@ -238,10 +218,7 @@ def loop(self) -> None:
                            f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
                        )

-                        while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
-                            self.profiler.log(
-                                f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
-                            )
+                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
                            # always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
                            # on each dp_rank, we use minibatch_size effective samples to form a batch
                            batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
@@ -273,34 +250,67 @@ def loop(self) -> None:
                        if self.rank == 0:
                            print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

-                    if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
-                        episode != 0 or step >= self.n_behind
-                    ):
+                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
                        if self.pp_size > 1:
                            print(
                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                            )
                        else:
                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                        self.profiler.enter("sync_model")
                        torch.cuda.empty_cache()
-                        state_dict = self.state_dict()
+                        self.state_dict_cpu = {k: v.cpu() for k, v in self.state_dict().items()}
+                        cc.barrier(group_name="consumer_pg")
                        if self.pp_size > 1:
                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                self.profiler.enter("sync_model")
                                ray_broadcast_tensor_dict(
-                                    state_dict,
+                                    self.state_dict_cpu,
                                    src=self.num_producers,
                                    device=self.device,
                                    group_name=f"sync_model_{self.pp_rank}",
+                                    offload_to_cpu=True,
                                )
+                                self.profiler.exit("sync_model")
                        else:
                            if self.rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                )
-                        del state_dict
-                        torch.cuda.empty_cache()
-                        self.profiler.exit("sync_model")
+                                # ray_broadcast_tensor_dict(
+                                #     self.state_dict_cpu, src=self.num_producers, device=self.device, group_name="sync_model", offload_to_cpu=True
+                                # )
+                                if not self.thread_started:
+
+                                    def broadcast_state_dict():
+                                        self.thread_started = True
+                                        self.profiler.enter("sync_model")
+                                        # lazy broadcast state_dict if and only if both consumer and all producers are idle (not broadcasting the last state_dict)
+                                        ray.get(
+                                            self.shared_sync_model_actor.increase_ready_process_count.remote(
+                                                name=self.model_sync_step
+                                            )
+                                        )
+                                        ready_process_count = ray.get(
+                                            self.shared_sync_model_actor.get_ready_process_count.remote(
+                                                name=self.model_sync_step
+                                            )
+                                        )
+                                        while ready_process_count != self.num_producers + 1:
+                                            time.sleep(0.5)
+                                            ready_process_count = ray.get(
+                                                self.shared_sync_model_actor.get_ready_process_count.remote(
+                                                    name=self.model_sync_step
+                                                )
+                                            )
+                                        ray_broadcast_tensor_dict(
+                                            self.state_dict_cpu,
+                                            src=self.num_producers,
+                                            device=self.device,
+                                            group_name="sync_model",
+                                            offload_to_cpu=True,
+                                        )
+                                        self.model_sync_step += 1
+                                        self.thread_started = False
+                                        self.profiler.exit("sync_model")
+
+                                    threading.Thread(target=broadcast_state_dict, daemon=True).start()
                    self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")

     def __del__(self):
@@ -312,6 +322,8 @@ def __del__(self):
 class SimpleConsumer(BaseConsumer):
     def __init__(
         self,
+        shared_sync_data_actor: SharedVariableActor,
+        shared_sync_model_actor: SharedVariableActor,
         num_producers,
         num_episodes,
         rank,
@@ -328,6 +340,8 @@ def __init__(
         save_dir="./model",
     ):
         super().__init__(
+            shared_sync_data_actor,
+            shared_sync_model_actor,
             num_producers,
             num_episodes,
             rank,
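
The heart of this change is that the consumer no longer blocks its training loop on the weight broadcast: it snapshots the state dict to CPU, synchronizes on the new consumer_pg barrier, and then a daemon thread performs the ready-count handshake and the actual broadcast. The standalone sketch below illustrates the pattern only; threading.Event and send_fn are stand-ins for the shared-actor ready count and for ray_broadcast_tensor_dict(..., offload_to_cpu=True), and none of these names come from the consumer itself.

    import threading
    import time

    import torch


    def async_model_sync(state_dict, ready_event, send_fn):
        # 1) snapshot the weights to CPU on the main thread, as loop() now does,
        #    so the next training step can keep mutating the GPU copy
        state_dict_cpu = {k: v.detach().cpu() for k, v in state_dict.items()}

        def _broadcast():
            # 2) wait until every participant has reported ready for this sync step
            ready_event.wait()
            send_fn(state_dict_cpu)

        # 3) ship the weights from a daemon thread while training continues
        threading.Thread(target=_broadcast, daemon=True).start()


    if __name__ == "__main__":
        model = torch.nn.Linear(4, 4)
        ready = threading.Event()
        async_model_sync(model.state_dict(), ready, lambda sd: print("sent", sorted(sd)))
        time.sleep(0.1)  # a training step would run here
        ready.set()      # all producers plus the consumer are ready
        time.sleep(0.1)  # give the daemon thread time to "send"

In the committed code the thread_started flag serves a similar purpose to the single Event here: a new broadcast thread is only launched once the previous one has finished.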

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,7 @@
 import ray
 import torch
 import wandb
+from coati.distributed.comm import SharedVariableActor
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.utils import memory_efficient_logprob
@@ -18,6 +19,8 @@
 class GRPOConsumer(BaseConsumer):
     def __init__(
         self,
+        shared_sync_data_actor: SharedVariableActor,
+        shared_sync_model_actor: SharedVariableActor,
         num_producers,
         num_episodes,
         rank,
@@ -51,6 +54,8 @@ def __init__(
             1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
         )
         super().__init__(
+            shared_sync_data_actor,
+            shared_sync_model_actor,
             num_producers,
             num_episodes,
             rank,

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 9 additions & 0 deletions
@@ -5,6 +5,7 @@

 import ray

+from .comm import SharedVariableActor
 from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer
@@ -87,6 +88,10 @@ def launch_distributed(
     # allocating the producer to nodes with lower node id and the consumer to the resouces from nodes with higher
     # node id. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
     nodes = ray.nodes()
+
+    shared_sync_data_actor = SharedVariableActor.remote()
+    shared_sync_model_actor = SharedVariableActor.remote()
+
     node_info = {
         node["NodeID"]: {
             "num_gpus": node["Resources"].get("GPU", 0),
@@ -111,6 +116,8 @@ def launch_distributed(
         gpu_to_ip_address.pop(0)
         print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
         producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+            shared_sync_data_actor=shared_sync_data_actor,
+            shared_sync_model_actor=shared_sync_model_actor,
             producer_idx=i,
             num_producers=num_producers,
             num_consumer_procs=num_consumer_procs,
@@ -155,6 +162,8 @@ def launch_distributed(
         gpu_to_ip_address.pop(0)
         print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
         consumer = core_consumer.options(num_gpus=1).remote(
+            shared_sync_data_actor=shared_sync_data_actor,
+            shared_sync_model_actor=shared_sync_model_actor,
             num_producers=num_producers,
             num_episodes=num_episodes,
             rank=i,
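
launch_distributed now creates the two SharedVariableActor handles once on the driver and passes the same handles into every producer and consumer, so all processes coordinate through a single pair of actors. A minimal sketch of that wiring, with a placeholder Worker actor standing in for SimpleProducer/GRPOConsumer (assumes Ray and an importable coati package; not code from this commit):

    import ray

    from coati.distributed.comm import SharedVariableActor


    @ray.remote
    class Worker:
        def __init__(self, shared_sync_data_actor, shared_sync_model_actor, rank):
            # every worker holds handles to the same two coordination actors
            self.data_actor = shared_sync_data_actor
            self.model_actor = shared_sync_model_actor
            self.rank = rank

        def report_ready(self, sync_step):
            ray.get(self.model_actor.increase_ready_process_count.remote(name=sync_step))


    if __name__ == "__main__":
        ray.init()
        # created once on the driver, then passed to every remote actor
        shared_sync_data_actor = SharedVariableActor.remote()
        shared_sync_model_actor = SharedVariableActor.remote()

        workers = [Worker.remote(shared_sync_data_actor, shared_sync_model_actor, r) for r in range(3)]
        ray.get([w.report_ready.remote(0) for w in workers])
        # all three reports landed on the same coordination actor
        print(ray.get(shared_sync_model_actor.get_ready_process_count.remote(name=0)))  # 3
        ray.shutdown()

Because Ray actor handles are serializable, passing them as constructor arguments is enough: every worker's .remote() calls are routed to the same actor process.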
