Skip to content

Commit 509274c

Browse files
committed
add code for zero-bubble implementation
1 parent b1f646c commit 509274c

File tree

8 files changed

+2267
-11
lines changed

8 files changed

+2267
-11
lines changed

applications/ColossalChat/coati/distributed/comm.py

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict
2-
2+
import copy
3+
import ray
34
import ray.util.collective as cc
45
import torch
56
import torch.distributed.distributed_c10d as c10d
@@ -30,28 +31,122 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
3031
obj = c10d._tensor_to_object(obj, size_tensor.item())
3132
return obj
3233

33-
3434
def ray_broadcast_tensor_dict(
35-
tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
35+
tensor_dict: Dict[str, torch.Tensor],
36+
src: int = 0,
37+
device=None,
38+
group_name: str = "default",
39+
backend: str = "nccl",
40+
offload_to_cpu: bool = False,
41+
pin_memory: bool = False,
3642
) -> Dict[str, torch.Tensor]:
3743
rank = cc.get_rank(group_name)
44+
if tensor_dict is None:
45+
tensor_dict = {}
3846
if rank == src:
3947
metadata = []
4048
for k, v in tensor_dict.items():
4149
metadata.append((k, v.shape, v.dtype))
4250
else:
4351
metadata = None
4452
metadata = ray_broadcast_object(metadata, src, device, group_name)
45-
if rank != src:
46-
out_dict = {}
4753
for k, shape, dtype in metadata:
4854
if rank == src:
49-
tensor = tensor_dict[k]
55+
if offload_to_cpu:
56+
tensor = tensor_dict[k].to(device)
57+
else:
58+
tensor = tensor_dict[k]
5059
else:
51-
tensor = torch.empty(shape, dtype=dtype, device=device)
60+
tensor = tensor_dict.get(k, torch.zeros(shape, dtype=dtype, device=device, pin_memory=pin_memory))
61+
if backend == "gloo" and dtype == torch.bfloat16:
62+
# Gloo does not support bfloat16, convert to float16
63+
tensor = tensor.view(torch.float16)
5264
cc.broadcast(tensor, src, group_name)
65+
if backend == "gloo" and dtype == torch.bfloat16:
66+
# Convert back to bfloat16 if it was converted to float16
67+
tensor = tensor.view(torch.bfloat16)
5368
if rank != src:
54-
out_dict[k] = tensor
55-
if rank == src:
56-
out_dict = tensor_dict
57-
return out_dict
69+
if offload_to_cpu:
70+
tensor_dict[k] = tensor.cpu()
71+
else:
72+
tensor_dict[k] = tensor
73+
return tensor_dict
74+
75+
76+
@ray.remote
77+
class SharedVariableActor:
78+
def __init__(self, number_of_readers: int = 0, buffer_size_limit: int = 1000):
79+
self.data_queue = []
80+
self.data_uid = 0
81+
self.number_of_readers = number_of_readers
82+
self.queue_size = 0
83+
self.signals = {}
84+
self.process_locks = {}
85+
self.signal_procs_meet_count = {}
86+
self.buffer_size_limit = buffer_size_limit
87+
88+
def pickup_rollout_task(self, num_tasks: int):
89+
"""
90+
use queue size to control whether producers should generating new rollouts or wait
91+
for consumer to consumer more data. if queue size is less than threshold,
92+
it means consumer is consuming data fast enough, so producers can generate new rollouts.
93+
if queue size is greater than threshold, it means consumer is consuming data slowly,
94+
so producers should wait for consumer to consume more data.
95+
96+
Any free producer can pick up the task to generate rollout then increase the queued_data_size
97+
to prevent other producer to pick up the task redundantly, Note it is not the real
98+
queue length as data may still be generating
99+
"""
100+
ret = False
101+
if self.queue_size < self.buffer_size_limit:
102+
ret = True
103+
self.queue_size += num_tasks
104+
return ret
105+
106+
def append_data(self, data):
107+
self.data_queue.append([self.data_uid, data, 0]) # [data_uid, data, access_count]
108+
self.data_uid += 1
109+
return True
110+
111+
def get_data(self, data_uid: int):
112+
# for multi-process data reading
113+
if not self.data_queue:
114+
# no data in the queue, return None
115+
return None
116+
to_pop_index = None
117+
ret = None
118+
for i, (uid, data, access_count) in enumerate(self.data_queue):
119+
if uid == data_uid:
120+
# found the data with the given uid
121+
self.data_queue[i][2] += 1
122+
ret = copy.deepcopy(data)
123+
if self.data_queue[i][2] == self.number_of_readers:
124+
to_pop_index = i
125+
break
126+
if to_pop_index is not None:
127+
# remove the data from the queue if it has been accessed by all readers
128+
self.data_queue.pop(to_pop_index)
129+
self.queue_size -= data["input_ids"].size(0)
130+
return ret
131+
132+
def acquire_process_lock(self, key: str):
133+
# atomic lock for process
134+
if key not in self.process_locks:
135+
self.process_locks[key] = 1 # locked
136+
return 0
137+
if self.process_locks[key] == 0:
138+
self.process_locks[key] = 1 # lock the process
139+
return 0
140+
else:
141+
return 1
142+
143+
def release_process_lock(self, key: str):
144+
# atomic unlock for process
145+
assert self.process_locks.get(key, 0) == 1, f"Releasing a process lock {key} that is not locked."
146+
self.process_locks[key] = 0
147+
148+
def set_signal(self, key: str, signal: str):
149+
self.signals[key] = signal
150+
151+
def get_signal(self):
152+
return self.signals

0 commit comments

Comments
 (0)