
Commit 150ca20

offload sync model to threads
1 parent 522f664 commit 150ca20

8 files changed, +250 -179 lines changed

applications/ColossalChat/coati/distributed/comm.py

Lines changed: 45 additions & 12 deletions
@@ -40,25 +40,26 @@ def ray_broadcast_tensor_dict(
     group_name: str = "default",
     backend: str = "nccl",
     offload_to_cpu: bool = False,
+    pin_memory: bool = False,
 ) -> Dict[str, torch.Tensor]:
     rank = cc.get_rank(group_name)
+    if tensor_dict is None:
+        tensor_dict = {}
     if rank == src:
         metadata = []
         for k, v in tensor_dict.items():
             metadata.append((k, v.shape, v.dtype))
     else:
         metadata = None
     metadata = ray_broadcast_object(metadata, src, device, group_name)
-    if rank != src:
-        out_dict = {}
     for k, shape, dtype in metadata:
         if rank == src:
             if offload_to_cpu:
                 tensor = tensor_dict[k].to(device)
             else:
                 tensor = tensor_dict[k]
         else:
-            tensor = torch.empty(shape, dtype=dtype, device=device)
+            tensor = tensor_dict.get(k, torch.zeros(shape, dtype=dtype, device=device, pin_memory=pin_memory))
         if backend == "gloo" and dtype == torch.bfloat16:
             # Gloo does not support bfloat16, convert to float16
             tensor = tensor.view(torch.float16)
@@ -68,26 +69,41 @@ def ray_broadcast_tensor_dict(
             tensor = tensor.view(torch.bfloat16)
         if rank != src:
             if offload_to_cpu:
-                out_dict[k] = tensor.cpu()
+                tensor_dict[k] = tensor.cpu()
             else:
-                out_dict[k] = tensor
-    if rank == src:
-        out_dict = tensor_dict
-    return out_dict
+                tensor_dict[k] = tensor
+    return tensor_dict


 @ray.remote
 class SharedVariableActor:
-    def __init__(self, number_of_readers: int = 1):
+    def __init__(self, number_of_readers: int = 0, buffer_size_limit: int = 1000):
         self.data_queue = []
         self.data_uid = 0
         self.number_of_readers = number_of_readers
+        self.queue_size = 0
         self.signals = {}
+        self.process_locks = {}
         self.signal_procs_meet_count = {}
+        self.buffer_size_limit = buffer_size_limit

-    def get_queued_data_size(self):
-        queued_data_size = sum([data[1]["input_ids"].size(0) for data in self.data_queue])
-        return queued_data_size
+    def pickup_rollout_task(self, num_tasks: int):
+        """
+        Use the queue size to decide whether producers should generate new rollouts or wait
+        for the consumer to catch up. If the queue size is below the threshold, the consumer
+        is consuming data fast enough, so producers can generate new rollouts; if it is above
+        the threshold, the consumer is falling behind, so producers should wait.
+
+        Any free producer can pick up the task to generate a rollout and then increase the
+        reserved queue size, preventing other producers from picking up the same task
+        redundantly. Note that this is not the real queue length, as data may still be
+        being generated.
+        """
+        ret = False
+        if self.queue_size < self.buffer_size_limit:
+            ret = True
+            self.queue_size += num_tasks
+        return ret

     def append_data(self, data):
         self.data_queue.append([self.data_uid, data, 0])  # [data_uid, data, access_count]
@@ -112,8 +128,25 @@ def get_data(self, data_uid: int):
         if to_pop_index is not None:
             # remove the data from the queue if it has been accessed by all readers
             self.data_queue.pop(to_pop_index)
+            self.queue_size -= data["input_ids"].size(0)
         return ret

+    def acquire_process_lock(self, key: str):
+        # atomic lock for process
+        if key not in self.process_locks:
+            self.process_locks[key] = 1  # locked
+            return 0
+        if self.process_locks[key] == 0:
+            self.process_locks[key] = 1  # lock the process
+            return 0
+        else:
+            return 1
+
+    def release_process_lock(self, key: str):
+        # atomic unlock for process
+        assert self.process_locks.get(key, 0) == 1, f"Releasing a process lock {key} that is not locked."
+        self.process_locks[key] = 0
+
     def set_signal(self, key: str, signal: str):
         self.signals[key] = signal

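For context, a minimal usage sketch (not part of this commit) of the new SharedVariableActor methods: pickup_rollout_task reserves queue capacity before a rollout is generated, and acquire_process_lock/release_process_lock serialize a critical section across processes. The try_produce_one_batch and with_process_lock helpers, and the placeholder rollout data, are illustrative assumptions rather than repository code.

import time

import ray
import torch


def try_produce_one_batch(shared_actor, batch_size: int = 8, seq_len: int = 128) -> bool:
    # Reserve queue capacity first; the actor returns False once the reserved
    # queue size reaches buffer_size_limit, so the producer backs off instead
    # of over-filling the buffer (get_data later shrinks queue_size again).
    if not ray.get(shared_actor.pickup_rollout_task.remote(batch_size)):
        return False
    # Placeholder rollout; a real producer would run generation here.
    data = {"input_ids": torch.zeros(batch_size, seq_len, dtype=torch.long)}
    ray.get(shared_actor.append_data.remote(data))
    return True


def with_process_lock(shared_actor, key, fn):
    # Spin until the actor grants the lock (0 = acquired, 1 = already held elsewhere).
    while ray.get(shared_actor.acquire_process_lock.remote(key)) != 0:
        time.sleep(0.1)
    try:
        return fn()
    finally:
        ray.get(shared_actor.release_process_lock.remote(key))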

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 68 additions & 40 deletions
@@ -1,4 +1,5 @@
 import os
+import threading
 import time
 from typing import Any, Dict, Optional

@@ -54,6 +55,7 @@ def __init__(
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
         self.data_uid = 0
+        self.sync_model_thread_started = False

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -64,7 +66,6 @@ def __init__(
         self.shared_sync_data_actor = shared_sync_data_actor
         self.shared_signal_actor = shared_signal_actor
         self.state_dict_cpu = {}
-        self.next_data_source = 0  # used to track which producer to get data from next

     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -183,7 +184,6 @@ def loop(self) -> None:
                     raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
                     continue
                 self.data_uid += 1
-                self.next_data_source = (self.next_data_source + 1) % self.num_producers
                 raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
                 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                 # we need to calculate the metrics before filtering here for logging
@@ -253,6 +253,7 @@ def loop(self) -> None:
                 if loss is not None:
                     pbar.set_postfix({"loss": loss})
                     need_sync_model = True
+                    ray.get(self.shared_signal_actor.set_signal.remote("global_step", self.global_step + 1))
                 if need_sync_model and (
                     (self.global_step + 1) % self.save_interval == 0
                     or self.received_prompts >= self.train_dataset_size
@@ -269,49 +270,76 @@ def loop(self) -> None:
                 if need_sync_model and (
                     episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size
                 ):
-                    # sync model weights to all producers, if no model update or it is the last training step, skip syncing
-                    if self.pp_size > 1:
-                        print(
-                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
-                        )
-                    else:
-                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
-                    torch.cuda.empty_cache()
-                    self.state_dict_cpu = {k: v.cpu() for k, v in self.state_dict().items()}
-                    cc.barrier(group_name="consumer_pg")
-                    if self.pp_size > 1:
-                        if self.tp_rank == 0 and self.dp_rank == 0:
-                            self.profiler.enter("sync_model")
-                            ray.get(
-                                self.shared_signal_actor.set_signal.remote(
-                                    f"consumer_pp_{self.pp_rank}", "ready_sync_model"
-                                )
-                            )
+
+                    def sync_model_thread():
+                        # sync model weights to all producers, if no model update or it is the last training step, skip syncing
+                        if self.pp_size > 1:
                             print(
                                 f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
                             )
-                            ray_broadcast_tensor_dict(
-                                self.state_dict_cpu,
-                                src=0,
-                                device=torch.device("cpu"),
-                                group_name=f"sync_model_consumer_pp_{self.pp_rank}",
-                                backend="gloo",
-                            )
-                            self.profiler.exit("sync_model")
-                    else:
-                        if self.rank == 0:
-                            self.profiler.enter("sync_model")
-                            ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
+                        else:
                             print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
-                            ray_broadcast_tensor_dict(
-                                self.state_dict_cpu,
-                                src=0,
-                                device=torch.device("cpu"),
-                                group_name="sync_model_consumer",
-                                backend="gloo",
-                            )
-                            self.profiler.exit("sync_model")
+                        torch.cuda.empty_cache()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                self.profiler.enter("sync_model")
+                                ray.get(
+                                    self.shared_signal_actor.set_signal.remote(
+                                        f"consumer_pp_{self.pp_rank}", "ready_sync_model"
+                                    )
+                                )
+                                print(
+                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
+                                )
+                                ray_broadcast_tensor_dict(
+                                    self.state_dict_cpu,
+                                    src=0,
+                                    device=torch.device("cpu"),
+                                    group_name=f"sync_model_consumer_pp_{self.pp_rank}",
+                                    backend="gloo",
+                                )
+                                self.profiler.exit("sync_model")
+                        else:
+                            if self.rank == 0:
+                                self.profiler.enter("sync_model")
+                                ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
+                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
+                                ray_broadcast_tensor_dict(
+                                    self.state_dict_cpu,
+                                    src=0,
+                                    device=torch.device("cpu"),
+                                    group_name="sync_model_consumer",
+                                    backend="gloo",
+                                )
+                                self.profiler.exit("sync_model")
+
+                    if not self.sync_model_thread_started:
+                        # only sync model when the thread is not started and no other thread is broadcasting
+                        self.sync_model_thread_started = True
+                        state_dict_ = self.state_dict()
+                        if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or (
+                            self.pp_size == 1 and self.rank == 0
+                        ):
+                            if len(self.state_dict_cpu) == 0:
+                                # use pinned memory to speed up the transfer
+                                self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()}
+                            torch.cuda.synchronize()
+                            for k, v in state_dict_.items():
+                                self.state_dict_cpu[k].copy_(v, non_blocking=True)
+                            torch.cuda.synchronize()
+                        cc.barrier(
+                            group_name="consumer_pg"
+                        )  # to make sure all ranks have state dict offloaded to CPU before starting the thread
+                        time_before_starting_thread = time.time()
+                        threading.Thread(target=sync_model_thread).start()
+                        # sync_model_thread()
+                        self.profiler.log(
+                            f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds"
+                        )
+                        self.sync_model_thread_started = False
+                        # ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock"))
             self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+            self.received_prompts = 0
         ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate"))

     def __del__(self):
applications/ColossalChat/coati/distributed/distributor.py

Lines changed: 15 additions & 43 deletions
@@ -22,15 +22,11 @@ def __init__(
     ):
         self.distributor_id = distributor_id
         self.consumer_pp_size = consumer_pp_size
-        self.state_dict_cpu = {i: {"not_ready_sync_model": torch.ones((1)).cpu()} for i in range(self.consumer_pp_size)}
+        self.state_dict_cpu = {}
         self.num_producers = num_producers
         self.shared_signal_actor = shared_signal_actor
         self.device = get_current_device()
         self.profiler = CustomProfiler(f"D{self.distributor_id}", disabled=not enable_profiling)
-        self.weight_version = {i: 0 for i in range(self.consumer_pp_size)}
-        self.producer_weight_version = {
-            j: {f"producer_{i}": 0 for i in range(self.num_producers)} for j in range(self.consumer_pp_size)
-        }

     def init_collective_group(
         self,
@@ -64,7 +60,6 @@ def loop(self):
                             backend="gloo",
                         )
                         self.profiler.exit(f"sync_model_consumer_pp_{i}")
-                        self.weight_version[i] += 1
                 for i in range(self.consumer_pp_size):
                     if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
                         self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
@@ -74,24 +69,13 @@ def loop(self):
                                 f"producer_{self.distributor_id}_pp_{i}", "not_ready_sync_model"
                             )
                         )
-                        if self.producer_weight_version[i][f"producer_{self.distributor_id}"] < self.weight_version[i]:
-                            self.producer_weight_version[i][f"producer_{self.distributor_id}"] = self.weight_version[i]
-                            ray_broadcast_tensor_dict(
-                                self.state_dict_cpu[i],
-                                1,
-                                device=torch.device("cpu"),
-                                group_name=f"sync_model_producer_{self.distributor_id}_pp_{i}",
-                                backend="gloo",
-                            )
-                        else:
-                            # broadcast a dummy tensor to save the communication cost
-                            ray_broadcast_tensor_dict(
-                                {"not_ready_sync_model": torch.ones((1)).cpu()},
-                                1,
-                                device=torch.device("cpu"),
-                                group_name=f"sync_model_producer_{self.distributor_id}_pp_{i}",
-                                backend="gloo",
-                            )
+                        ray_broadcast_tensor_dict(
+                            self.state_dict_cpu[i],
+                            1,
+                            device=torch.device("cpu"),
+                            group_name=f"sync_model_producer_{self.distributor_id}_pp_{i}",
+                            backend="gloo",
+                        )
                         self.profiler.exit(f"sync_model_producer_{self.distributor_id}_pp_{i}")
             else:
                 if signal.get("consumer", None) == "ready_sync_model":
@@ -103,7 +87,6 @@ def loop(self):
                         None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
                     )
                     self.profiler.exit("sync_model_consumer")
-                    self.weight_version[0] += 1
                 if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
                     self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
                     # Broadcast the model state dict to all producers
@@ -112,24 +95,13 @@ def loop(self):
                             f"producer_{self.distributor_id}", "not_ready_sync_model"
                         )
                     )
-                    if self.producer_weight_version[0][f"producer_{self.distributor_id}"] < self.weight_version[0]:
-                        self.producer_weight_version[0][f"producer_{self.distributor_id}"] = self.weight_version[0]
-                        ray_broadcast_tensor_dict(
-                            self.state_dict_cpu,
-                            1,
-                            device=torch.device("cpu"),
-                            group_name=f"sync_model_producer_{self.distributor_id}",
-                            backend="gloo",
-                        )
-                    else:
-                        # broadcast a dummy tensor to save the communication cost
-                        ray_broadcast_tensor_dict(
-                            {"not_ready_sync_model": torch.ones((1)).cpu()},
-                            1,
-                            device=torch.device("cpu"),
-                            group_name=f"sync_model_producer_{self.distributor_id}",
-                            backend="gloo",
-                        )
+                    ray_broadcast_tensor_dict(
+                        self.state_dict_cpu,
+                        1,
+                        device=torch.device("cpu"),
+                        group_name=f"sync_model_producer_{self.distributor_id}",
+                        backend="gloo",
+                    )
                     self.profiler.exit(f"sync_model_producer_{self.distributor_id}")
             if signal.get("consumer", None) == "terminate":
                 self.profiler.log("terminate sync model worker")
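
A simplified sketch (hypothetical helper, assumed import path) of the distributor's receive-then-forward pattern after this change: because ray_broadcast_tensor_dict now fills and returns the dict it is given, passing the previously received dict back in reuses the CPU buffers across syncs, and the removed weight-version/dummy-broadcast bookkeeping is no longer needed since producers only signal for a sync when they actually want one.

import torch

from coati.distributed.comm import ray_broadcast_tensor_dict  # assumed import path


def relay_weights(state_dict_cpu, recv_group: str, send_group: str):
    # Receive the latest weights from the consumer (rank 0 of recv_group); on the
    # first call state_dict_cpu may be None and the buffers are allocated in place.
    state_dict_cpu = ray_broadcast_tensor_dict(
        state_dict_cpu, 0, device=torch.device("cpu"), group_name=recv_group, backend="gloo"
    )
    # Forward the same buffers to the producer group, where this process is the
    # source (src=1, matching the groups set up in the commit).
    ray_broadcast_tensor_dict(
        state_dict_cpu, 1, device=torch.device("cpu"), group_name=send_group, backend="gloo"
    )
    return state_dict_cpu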

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 0 additions & 1 deletion
@@ -495,5 +495,4 @@ def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
-        state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
         return state_dict
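
With consumer_global_step no longer embedded in the broadcast state dict, the training step travels through the shared signal actor instead, which the consumer now updates via set_signal("global_step", ...) (see the consumer diff above). The reader below is a hypothetical producer-side sketch; it assumes the actor exposes a get_signal-style accessor, which does not appear in this diff.

import ray


def read_global_step(shared_signal_actor) -> int:
    # Assumption: the actor provides a getter for its signals dict; only
    # set_signal appears in this commit's diff.
    signals = ray.get(shared_signal_actor.get_signal.remote())
    return int(signals.get("global_step", 0))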
