
Commit 9ee4621

fix sync model, support tp+pp
1 parent 880d886 commit 9ee4621

8 files changed: +372 −345 lines changed


applications/ColossalChat/coati/distributed/comm.py

Lines changed: 148 additions & 33 deletions
@@ -1,12 +1,16 @@
 import copy
+import time
 from typing import Any, Dict
 
 import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed.distributed_c10d as c10d
+from coati.distributed.profiling_utils import CustomProfiler
 from packaging.version import Version
 
+from colossalai.utils import get_current_device
+
 
 def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
     rank = cc.get_rank(group_name)
@@ -71,46 +75,157 @@ def ray_broadcast_tensor_dict(
 
 @ray.remote
 class SharedVariableActor:
-    def __init__(self):
-        # double queues
-        self.data_queue = None
-        self.data_queue_buffered = None
+    def __init__(self, number_of_readers: int = 1):
+        self.data_queue = []
         self.model_weights = None
         self.data_access_count = 0
         self.ready_process_count = {}
+        self.number_of_readers = number_of_readers
+        self.consumer_buffer_size = 0
+        self.signals = {}
 
-    def increase_ready_process_count(self, name):
-        self.ready_process_count = {k: v for k, v in self.ready_process_count.items() if k > name - 5}
-        if name not in self.ready_process_count:
-            self.ready_process_count[name] = 0
-        self.ready_process_count[name] += 1
-
-    def get_ready_process_count(self, name):
-        return self.ready_process_count[name]
-
-    def extend_data(self, data):
-        if self.data_access_count > 0:
-            # update the buffered data if data is not being accessed by all consumers
-            # if producer are too fast, will not overwrite the data but extend the data
-            if self.data_queue_buffered is None:
-                self.data_queue_buffered = []
-            self.data_queue_buffered.extend(data)
-            return True
-        if self.data_queue is None:
-            self.data_queue = []
-        self.data_queue.extend(data)
-        self.data_access_count = 0
+    def get_queued_data_size(self):
+        queued_data_size = sum([data["input_ids"].size(0) for data in self.data_queue])
+        return queued_data_size
+
+    def append_data(self, data):
+        self.data_queue.append(data)
         return True
 
     def get_data(self):
-        if self.data_queue is None:
+        if not self.data_queue:
+            # no data in the queue, return None
             return None
-        data = copy.deepcopy(self.data_queue)
+        data = copy.deepcopy(self.data_queue[0])
         self.data_access_count += 1
-        if self.data_access_count == 4:
-            # data in data_queue has been accessed by all consumers
-            # swap the data queue with the buffered data, erase the old data
-            if self.data_queue_buffered is not None:
-                self.data_queue = self.data_queue_buffered
-                self.data_queue_buffered = None
+        if self.data_access_count == self.number_of_readers:
+            # first data in data_queue has been accessed by all consumers
+            # remove it from the queue
+            self.data_queue.pop(0)
+            self.data_access_count = 0
         return data
+
+    def set_signal(self, key: str, signal: str):
+        self.signals[key] = signal
+
+    def get_signal(self):
+        return self.signals
+
+
+@ray.remote
+class SharedVariableActorNCCL:
+    def __init__(
+        self, consumer_pp_size, num_producers, shared_signal_actor: SharedVariableActor, enable_profiling: bool = True
+    ):
+        self.consumer_pp_size = consumer_pp_size
+        self.state_dict_cpu = {i: {"not_ready_sync_model": torch.ones((1)).cpu()} for i in range(self.consumer_pp_size)}
+        self.num_producers = num_producers
+        self.shared_signal_actor = shared_signal_actor
+        self.device = get_current_device()
+        self.profiler = CustomProfiler(f"D", disabled=not enable_profiling)
+        self.weight_version = {i: 0 for i in range(self.consumer_pp_size)}
+        self.producer_weight_version = {
+            j: {f"producer_{i}": 0 for i in range(self.num_producers)} for j in range(self.consumer_pp_size)
+        }
+
+    def setup(self):
+        if self.consumer_pp_size == 1:
+            cc.init_collective_group(2, 1, group_name="sync_model_consumer")
+            for i in range(self.num_producers):
+                cc.init_collective_group(2, 1, group_name=f"sync_model_producer_{i}")
+        else:
+            for i in range(self.consumer_pp_size):
+                cc.init_collective_group(2, 1, group_name=f"sync_model_consumer_pp_{i}")
+            for i in range(self.num_producers):
+                for j in range(self.consumer_pp_size):
+                    cc.init_collective_group(2, 1, group_name=f"sync_model_producer_{i}_pp_{j}")
+
+    def loop(self):
+        while True:
+            time.sleep(1)
+            signal = ray.get(self.shared_signal_actor.get_signal.remote())
+            if self.consumer_pp_size > 1:
+                for i in range(self.consumer_pp_size):
+                    if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
+                        self.profiler.enter(f"sync_model_consumer_pp_{i}")
+                        ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
+                        # Broadcast the model state dict from consumer to shared variable actor
+                        self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
+                            None,
+                            0,
+                            device=self.device,
+                            group_name=f"sync_model_consumer_pp_{i}",
+                            offload_to_cpu=True,
+                        )
+                        self.profiler.exit(f"sync_model_consumer_pp_{i}")
+                        self.weight_version[i] += 1
+                for j in range(self.num_producers):
+                    for i in range(self.consumer_pp_size):
+                        if signal.get(f"producer_{j}_pp_{i}", None) == "ready_sync_model":
+                            self.profiler.enter(f"sync_model_producer_{j}_pp_{i}")
+                            # Broadcast the model state dict to all producers
+                            ray.get(
+                                self.shared_signal_actor.set_signal.remote(
+                                    f"producer_{j}_pp_{i}", "not_ready_sync_model"
+                                )
+                            )
+                            if self.producer_weight_version[i][f"producer_{j}"] < self.weight_version[i]:
+                                self.producer_weight_version[i][f"producer_{j}"] = self.weight_version[i]
+                                ray_broadcast_tensor_dict(
+                                    self.state_dict_cpu[i],
+                                    1,
+                                    device=self.device,
+                                    group_name=f"sync_model_producer_{j}_pp_{i}",
+                                    offload_to_cpu=True,
+                                )
+                            else:
+                                # broadcast a dummy tensor to save the communication cost
+                                ray_broadcast_tensor_dict(
+                                    {"not_ready_sync_model": torch.ones((1)).cpu()},
+                                    1,
+                                    device=self.device,
+                                    group_name=f"sync_model_producer_{j}_pp_{i}",
+                                    offload_to_cpu=True,
+                                )
+                            self.profiler.exit(f"sync_model_producer_{j}_pp_{i}")
+            else:
+                if signal.get("consumer", None) == "ready_sync_model":
+                    self.profiler.enter("sync_model_consumer")
+                    ray.get(self.shared_signal_actor.set_signal.remote("consumer", "not_ready_sync_model"))
+                    # Broadcast the model state dict from consumer to shared variable actor
+                    self.state_dict_cpu = ray_broadcast_tensor_dict(
+                        None,
+                        0,
+                        device=self.device,
+                        group_name="sync_model_consumer",
+                        offload_to_cpu=True,
+                    )
+                    self.profiler.exit("sync_model_consumer")
+                    self.weight_version[0] += 1
+                for i in range(self.num_producers):
+                    if signal.get(f"producer_{i}", None) == "ready_sync_model":
+                        self.profiler.enter(f"sync_model_producer_{i}")
+                        # Broadcast the model state dict to all producers
+                        ray.get(self.shared_signal_actor.set_signal.remote(f"producer_{i}", "not_ready_sync_model"))
+                        if self.producer_weight_version[0][f"producer_{i}"] < self.weight_version[0]:
+                            self.producer_weight_version[0][f"producer_{i}"] = self.weight_version[0]
+                            ray_broadcast_tensor_dict(
+                                self.state_dict_cpu,
+                                1,
+                                device=self.device,
+                                group_name=f"sync_model_producer_{i}",
+                                offload_to_cpu=True,
+                            )
+                        else:
+                            # broadcast a dummy tensor to save the communication cost
+                            ray_broadcast_tensor_dict(
+                                {"not_ready_sync_model": torch.ones((1)).cpu()},
+                                1,
+                                device=self.device,
+                                group_name=f"sync_model_producer_{i}",
+                                offload_to_cpu=True,
+                            )
+                        self.profiler.exit(f"sync_model_producer_{i}")
+            if signal.get("consumer", None) == "terminate":
+                self.profiler.log("terminate sync model worker")
+                break
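The reworked SharedVariableActor behaves as a single-writer, multi-reader queue: append_data pushes one rollout batch, and get_data hands back a deep copy of the head element, popping it only after number_of_readers consumers have fetched it; the same actor also serves as the signal board polled by the sync-model loop. Below is a minimal usage sketch of those semantics. It assumes a local Ray runtime, that comm.py is importable as coati.distributed.comm, and a purely illustrative batch layout.

# Minimal sketch; assumptions: local Ray runtime, comm.py importable as
# coati.distributed.comm, illustrative batch contents.
import ray
import torch

from coati.distributed.comm import SharedVariableActor

ray.init()

# Two consumers share the queue, so the head batch is popped only after
# both of them have read it.
queue = SharedVariableActor.remote(number_of_readers=2)

batch = {"input_ids": torch.zeros(4, 16, dtype=torch.long)}
ray.get(queue.append_data.remote(batch))
print(ray.get(queue.get_queued_data_size.remote()))  # 4 (rows of input_ids)

first = ray.get(queue.get_data.remote())   # read 1 of 2: head stays queued
second = ray.get(queue.get_data.remote())  # read 2 of 2: head is popped
assert first is not None and second is not None
assert ray.get(queue.get_data.remote()) is None  # queue is now drained

# The same actor doubles as the signal board used for model syncing.
ray.get(queue.set_signal.remote("consumer", "ready_sync_model"))
print(ray.get(queue.get_signal.remote()))  # {'consumer': 'ready_sync_model'}

Compared with the old double-queue design, the per-head access count tied to number_of_readers replaces the hard-coded data_access_count == 4 check, so the consumer count no longer needs to be baked into the actor.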

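For context, a hypothetical wiring sketch of the new SharedVariableActorNCCL relay follows. The relay receives fresh weights from the consumer over the sync_model_consumer group(s) and re-broadcasts them to each producer only when that producer's recorded weight version is stale, sending a one-element dummy dict otherwise. Only the constructor signature, setup()/loop(), and the signal names come from the diff; the import path, the num_gpus resource option, and the way the consumer and producer processes join the matching collective groups from their own ranks (not shown) are assumptions.

# Hypothetical wiring sketch; assumptions: import path, the num_gpus option,
# and that consumer/producer processes join the matching sync_model_*
# collective groups from their own ranks (not shown here).
import ray

from coati.distributed.comm import SharedVariableActor, SharedVariableActorNCCL

ray.init()

num_producers = 2
shared_signal_actor = SharedVariableActor.remote(number_of_readers=num_producers)

# Relay actor bridging the trainer (consumer) and the rollout workers (producers).
relay = SharedVariableActorNCCL.options(num_gpus=1).remote(
    consumer_pp_size=1,
    num_producers=num_producers,
    shared_signal_actor=shared_signal_actor,
    enable_profiling=False,
)

ray.get(relay.setup.remote())   # joins the sync_model_* groups (peers must init on their side)
loop_ref = relay.loop.remote()  # polls the signal board once per second

# The consumer announces fresh weights with:
#   shared_signal_actor.set_signal.remote("consumer", "ready_sync_model")
# and producer i requests them with:
#   shared_signal_actor.set_signal.remote(f"producer_{i}", "ready_sync_model")

# Stop the relay once training is done.
ray.get(shared_signal_actor.set_signal.remote("consumer", "terminate"))
ray.get(loop_ref)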