Commit 522f664

support gloo as communication backend

1 parent 9ee4621 commit 522f664

8 files changed: +402 -257 lines changed
Lines changed: 28 additions & 138 deletions
@@ -1,16 +1,12 @@
 import copy
-import time
 from typing import Any, Dict

 import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed.distributed_c10d as c10d
-from coati.distributed.profiling_utils import CustomProfiler
 from packaging.version import Version

-from colossalai.utils import get_current_device
-

 def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
     rank = cc.get_rank(group_name)
@@ -42,6 +38,7 @@ def ray_broadcast_tensor_dict(
     src: int = 0,
     device=None,
     group_name: str = "default",
+    backend: str = "nccl",
     offload_to_cpu: bool = False,
 ) -> Dict[str, torch.Tensor]:
     rank = cc.get_rank(group_name)
@@ -62,7 +59,13 @@ def ray_broadcast_tensor_dict(
             tensor = tensor_dict[k]
         else:
             tensor = torch.empty(shape, dtype=dtype, device=device)
+        if backend == "gloo" and dtype == torch.bfloat16:
+            # Gloo does not support bfloat16, convert to float16
+            tensor = tensor.view(torch.float16)
         cc.broadcast(tensor, src, group_name)
+        if backend == "gloo" and dtype == torch.bfloat16:
+            # Convert back to bfloat16 if it was converted to float16
+            tensor = tensor.view(torch.bfloat16)
         if rank != src:
             if offload_to_cpu:
                 out_dict[k] = tensor.cpu()
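
A note on the Gloo branch added above: Gloo cannot broadcast bfloat16 tensors, so the patch bit-reinterprets the tensor as float16 (both dtypes are 16 bits per element) for the collective call and views it back afterwards. Below is a minimal, self-contained sketch of that round trip in plain PyTorch, without the Ray collective group:

import torch

original = torch.randn(4, dtype=torch.bfloat16)

# Reinterpret the raw bits as float16: no copy, no numeric conversion.
as_fp16 = original.view(torch.float16)
# ... a Gloo-backed broadcast of `as_fp16` would happen here ...
# Reinterpret back to bfloat16 on every rank after the collective.
restored = as_fp16.view(torch.bfloat16)

assert torch.equal(original, restored)  # the bit pattern survives unchanged

Because only the dtype label changes, the values arrive intact on every rank and no precision is lost in the conversion.
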
@@ -77,155 +80,42 @@ def ray_broadcast_tensor_dict(
 class SharedVariableActor:
     def __init__(self, number_of_readers: int = 1):
         self.data_queue = []
-        self.model_weights = None
-        self.data_access_count = 0
-        self.ready_process_count = {}
+        self.data_uid = 0
         self.number_of_readers = number_of_readers
-        self.consumer_buffer_size = 0
         self.signals = {}
+        self.signal_procs_meet_count = {}

     def get_queued_data_size(self):
-        queued_data_size = sum([data["input_ids"].size(0) for data in self.data_queue])
+        queued_data_size = sum([data[1]["input_ids"].size(0) for data in self.data_queue])
         return queued_data_size

     def append_data(self, data):
-        self.data_queue.append(data)
+        self.data_queue.append([self.data_uid, data, 0])  # [data_uid, data, access_count]
+        self.data_uid += 1
         return True

-    def get_data(self):
+    def get_data(self, data_uid: int):
+        # for multi-process data reading
         if not self.data_queue:
             # no data in the queue, return None
             return None
-        data = copy.deepcopy(self.data_queue[0])
-        self.data_access_count += 1
-        if self.data_access_count == self.number_of_readers:
-            # first data in data_queue has been accessed by all consumers
-            # remove it from the queue
-            self.data_queue.pop(0)
-            self.data_access_count = 0
-        return data
+        to_pop_index = None
+        ret = None
+        for i, (uid, data, access_count) in enumerate(self.data_queue):
+            if uid == data_uid:
+                # found the data with the given uid
+                self.data_queue[i][2] += 1
+                ret = copy.deepcopy(data)
+                if self.data_queue[i][2] == self.number_of_readers:
+                    to_pop_index = i
+                break
+        if to_pop_index is not None:
+            # remove the data from the queue if it has been accessed by all readers
+            self.data_queue.pop(to_pop_index)
+        return ret

     def set_signal(self, key: str, signal: str):
         self.signals[key] = signal

     def get_signal(self):
         return self.signals
-
-
-@ray.remote
-class SharedVariableActorNCCL:
-    def __init__(
-        self, consumer_pp_size, num_producers, shared_signal_actor: SharedVariableActor, enable_profiling: bool = True
-    ):
-        self.consumer_pp_size = consumer_pp_size
-        self.state_dict_cpu = {i: {"not_ready_sync_model": torch.ones((1)).cpu()} for i in range(self.consumer_pp_size)}
-        self.num_producers = num_producers
-        self.shared_signal_actor = shared_signal_actor
-        self.device = get_current_device()
-        self.profiler = CustomProfiler(f"D", disabled=not enable_profiling)
-        self.weight_version = {i: 0 for i in range(self.consumer_pp_size)}
-        self.producer_weight_version = {
-            j: {f"producer_{i}": 0 for i in range(self.num_producers)} for j in range(self.consumer_pp_size)
-        }
-
-    def setup(self):
-        if self.consumer_pp_size == 1:
-            cc.init_collective_group(2, 1, group_name="sync_model_consumer")
-            for i in range(self.num_producers):
-                cc.init_collective_group(2, 1, group_name=f"sync_model_producer_{i}")
-        else:
-            for i in range(self.consumer_pp_size):
-                cc.init_collective_group(2, 1, group_name=f"sync_model_consumer_pp_{i}")
-            for i in range(self.num_producers):
-                for j in range(self.consumer_pp_size):
-                    cc.init_collective_group(2, 1, group_name=f"sync_model_producer_{i}_pp_{j}")
-
-    def loop(self):
-        while True:
-            time.sleep(1)
-            signal = ray.get(self.shared_signal_actor.get_signal.remote())
-            if self.consumer_pp_size > 1:
-                for i in range(self.consumer_pp_size):
-                    if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
-                        self.profiler.enter(f"sync_model_consumer_pp_{i}")
-                        ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
-                        # Broadcast the model state dict from consumer to shared variable actor
-                        self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
-                            None,
-                            0,
-                            device=self.device,
-                            group_name=f"sync_model_consumer_pp_{i}",
-                            offload_to_cpu=True,
-                        )
-                        self.profiler.exit(f"sync_model_consumer_pp_{i}")
-                        self.weight_version[i] += 1
-                for j in range(self.num_producers):
-                    for i in range(self.consumer_pp_size):
-                        if signal.get(f"producer_{j}_pp_{i}", None) == "ready_sync_model":
-                            self.profiler.enter(f"sync_model_producer_{j}_pp_{i}")
-                            # Broadcast the model state dict to all producers
-                            ray.get(
-                                self.shared_signal_actor.set_signal.remote(
-                                    f"producer_{j}_pp_{i}", "not_ready_sync_model"
-                                )
-                            )
-                            if self.producer_weight_version[i][f"producer_{j}"] < self.weight_version[i]:
-                                self.producer_weight_version[i][f"producer_{j}"] = self.weight_version[i]
-                                ray_broadcast_tensor_dict(
-                                    self.state_dict_cpu[i],
-                                    1,
-                                    device=self.device,
-                                    group_name=f"sync_model_producer_{j}_pp_{i}",
-                                    offload_to_cpu=True,
-                                )
-                            else:
-                                # broadcast a dummy tensor to save the communication cost
-                                ray_broadcast_tensor_dict(
-                                    {"not_ready_sync_model": torch.ones((1)).cpu()},
-                                    1,
-                                    device=self.device,
-                                    group_name=f"sync_model_producer_{j}_pp_{i}",
-                                    offload_to_cpu=True,
-                                )
-                            self.profiler.exit(f"sync_model_producer_{j}_pp_{i}")
-            else:
-                if signal.get("consumer", None) == "ready_sync_model":
-                    self.profiler.enter("sync_model_consumer")
-                    ray.get(self.shared_signal_actor.set_signal.remote("consumer", "not_ready_sync_model"))
-                    # Broadcast the model state dict from consumer to shared variable actor
-                    self.state_dict_cpu = ray_broadcast_tensor_dict(
-                        None,
-                        0,
-                        device=self.device,
-                        group_name="sync_model_consumer",
-                        offload_to_cpu=True,
-                    )
-                    self.profiler.exit("sync_model_consumer")
-                    self.weight_version[0] += 1
-                for i in range(self.num_producers):
-                    if signal.get(f"producer_{i}", None) == "ready_sync_model":
-                        self.profiler.enter(f"sync_model_producer_{i}")
-                        # Broadcast the model state dict to all producers
-                        ray.get(self.shared_signal_actor.set_signal.remote(f"producer_{i}", "not_ready_sync_model"))
-                        if self.producer_weight_version[0][f"producer_{i}"] < self.weight_version[0]:
-                            self.producer_weight_version[0][f"producer_{i}"] = self.weight_version[0]
-                            ray_broadcast_tensor_dict(
-                                self.state_dict_cpu,
-                                1,
-                                device=self.device,
-                                group_name=f"sync_model_producer_{i}",
-                                offload_to_cpu=True,
-                            )
-                        else:
-                            # broadcast a dummy tensor to save the communication cost
-                            ray_broadcast_tensor_dict(
-                                {"not_ready_sync_model": torch.ones((1)).cpu()},
-                                1,
-                                device=self.device,
-                                group_name=f"sync_model_producer_{i}",
-                                offload_to_cpu=True,
-                            )
-                        self.profiler.exit(f"sync_model_producer_{i}")
-            if signal.get("consumer", None) == "terminate":
-                self.profiler.log("terminate sync model worker")
-                break
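
For reference, a short usage sketch of the reworked SharedVariableActor queue, exercised here as a plain local object (in the repo it is driven as a Ray actor; the payload below is a hypothetical placeholder batch). Each appended batch gets a fresh uid and an access counter, and an entry is only evicted after number_of_readers distinct get_data(data_uid) calls, so several consumers can read the same batch by uid without racing on pop order:

import torch

buffer = SharedVariableActor(number_of_readers=2)

# Producer side: each appended batch is stored with a uid and a zeroed access counter.
batch = {"input_ids": torch.zeros(4, 8, dtype=torch.long)}   # placeholder payload
buffer.append_data(batch)                    # stored as [uid=0, batch, access_count=0]
print(buffer.get_queued_data_size())         # 4 rows of "input_ids" still queued

# Reader side: both readers request the same uid; the entry is popped on the last read.
first = buffer.get_data(data_uid=0)          # access_count -> 1, entry stays queued
second = buffer.get_data(data_uid=0)         # access_count -> 2, entry is evicted

assert first is not None and second is not None
assert buffer.get_data(data_uid=0) is None   # queue is empty again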

0 commit comments