1
1
import copy
2
- import time
3
2
from typing import Any , Dict
4
3
5
4
import ray
6
5
import ray .util .collective as cc
7
6
import torch
8
7
import torch .distributed .distributed_c10d as c10d
9
- from coati .distributed .profiling_utils import CustomProfiler
10
8
from packaging .version import Version
11
9
12
- from colossalai .utils import get_current_device
13
-
14
10
15
11
def ray_broadcast_object (obj : Any , src : int = 0 , device = None , group_name : str = "default" ) -> Any :
16
12
rank = cc .get_rank (group_name )
@@ -42,6 +38,7 @@ def ray_broadcast_tensor_dict(
42
38
src : int = 0 ,
43
39
device = None ,
44
40
group_name : str = "default" ,
41
+ backend : str = "nccl" ,
45
42
offload_to_cpu : bool = False ,
46
43
) -> Dict [str , torch .Tensor ]:
47
44
rank = cc .get_rank (group_name )
@@ -62,7 +59,13 @@ def ray_broadcast_tensor_dict(
62
59
tensor = tensor_dict [k ]
63
60
else :
64
61
tensor = torch .empty (shape , dtype = dtype , device = device )
62
+ if backend == "gloo" and dtype == torch .bfloat16 :
63
+ # Gloo does not support bfloat16, convert to float16
64
+ tensor = tensor .view (torch .float16 )
65
65
cc .broadcast (tensor , src , group_name )
66
+ if backend == "gloo" and dtype == torch .bfloat16 :
67
+ # Convert back to bfloat16 if it was converted to float16
68
+ tensor = tensor .view (torch .bfloat16 )
66
69
if rank != src :
67
70
if offload_to_cpu :
68
71
out_dict [k ] = tensor .cpu ()
@@ -77,155 +80,42 @@ def ray_broadcast_tensor_dict(
77
80
class SharedVariableActor:
    """Shared store for queued data batches and string signals.

    Every batch appended to the queue is tagged with a unique, increasing
    uid. Readers fetch a batch by uid; once a batch has been fetched
    ``number_of_readers`` times it is removed from the queue.
    """

    def __init__(self, number_of_readers: int = 1):
        """Create an empty store.

        Args:
            number_of_readers: how many times each queued batch must be
                fetched through ``get_data`` before it is dropped.
        """
        # Each entry is a mutable list: [data_uid, data, access_count].
        self.data_queue = []
        # uid that will be assigned to the next appended batch.
        self.data_uid = 0
        self.number_of_readers = number_of_readers
        self.signals = {}
        # Not read or written by any method of this class.
        # NOTE(review): presumably maintained by external callers -- confirm.
        self.signal_procs_meet_count = {}

    def get_queued_data_size(self):
        """Return the sum of ``data["input_ids"].size(0)`` over all queued batches."""
        return sum(entry[1]["input_ids"].size(0) for entry in self.data_queue)

    def append_data(self, data):
        """Queue ``data`` under the next free uid and return ``True``."""
        self.data_queue.append([self.data_uid, data, 0])  # [data_uid, data, access_count]
        self.data_uid += 1
        return True

    def get_data(self, data_uid: int):
        """Return a deep copy of the batch stored under ``data_uid``.

        Each successful call counts as one access. When the access count
        reaches ``number_of_readers`` the batch is removed from the queue.

        Returns:
            A deep copy of the batch, or ``None`` when the queue is empty or
            holds no batch with the requested uid.
        """
        for index, entry in enumerate(self.data_queue):
            if entry[0] != data_uid:
                continue
            # uids come from a strictly increasing counter, so at most one
            # entry can match and the search can stop here.
            entry[2] += 1
            payload = copy.deepcopy(entry[1])
            if entry[2] == self.number_of_readers:
                # Accessed by all readers: drop it from the queue.
                self.data_queue.pop(index)
            return payload
        return None

    def set_signal(self, key: str, signal: str):
        """Store ``signal`` under ``key``, replacing any previous value."""
        self.signals[key] = signal

    def get_signal(self):
        """Return the dictionary holding all signals set so far."""
        return self.signals
113
-
114
-
115
- @ray .remote
116
- class SharedVariableActorNCCL :
117
- def __init__ (
118
- self , consumer_pp_size , num_producers , shared_signal_actor : SharedVariableActor , enable_profiling : bool = True
119
- ):
120
- self .consumer_pp_size = consumer_pp_size
121
- self .state_dict_cpu = {i : {"not_ready_sync_model" : torch .ones ((1 )).cpu ()} for i in range (self .consumer_pp_size )}
122
- self .num_producers = num_producers
123
- self .shared_signal_actor = shared_signal_actor
124
- self .device = get_current_device ()
125
- self .profiler = CustomProfiler (f"D" , disabled = not enable_profiling )
126
- self .weight_version = {i : 0 for i in range (self .consumer_pp_size )}
127
- self .producer_weight_version = {
128
- j : {f"producer_{ i } " : 0 for i in range (self .num_producers )} for j in range (self .consumer_pp_size )
129
- }
130
-
131
- def setup (self ):
132
- if self .consumer_pp_size == 1 :
133
- cc .init_collective_group (2 , 1 , group_name = "sync_model_consumer" )
134
- for i in range (self .num_producers ):
135
- cc .init_collective_group (2 , 1 , group_name = f"sync_model_producer_{ i } " )
136
- else :
137
- for i in range (self .consumer_pp_size ):
138
- cc .init_collective_group (2 , 1 , group_name = f"sync_model_consumer_pp_{ i } " )
139
- for i in range (self .num_producers ):
140
- for j in range (self .consumer_pp_size ):
141
- cc .init_collective_group (2 , 1 , group_name = f"sync_model_producer_{ i } _pp_{ j } " )
142
-
143
- def loop (self ):
144
- while True :
145
- time .sleep (1 )
146
- signal = ray .get (self .shared_signal_actor .get_signal .remote ())
147
- if self .consumer_pp_size > 1 :
148
- for i in range (self .consumer_pp_size ):
149
- if signal .get (f"consumer_pp_{ i } " , None ) == "ready_sync_model" :
150
- self .profiler .enter (f"sync_model_consumer_pp_{ i } " )
151
- ray .get (self .shared_signal_actor .set_signal .remote (f"consumer_pp_{ i } " , "not_ready_sync_model" ))
152
- # Broadcast the model state dict from consumer to shared variable actor
153
- self .state_dict_cpu [i ] = ray_broadcast_tensor_dict (
154
- None ,
155
- 0 ,
156
- device = self .device ,
157
- group_name = f"sync_model_consumer_pp_{ i } " ,
158
- offload_to_cpu = True ,
159
- )
160
- self .profiler .exit (f"sync_model_consumer_pp_{ i } " )
161
- self .weight_version [i ] += 1
162
- for j in range (self .num_producers ):
163
- for i in range (self .consumer_pp_size ):
164
- if signal .get (f"producer_{ j } _pp_{ i } " , None ) == "ready_sync_model" :
165
- self .profiler .enter (f"sync_model_producer_{ j } _pp_{ i } " )
166
- # Broadcast the model state dict to all producers
167
- ray .get (
168
- self .shared_signal_actor .set_signal .remote (
169
- f"producer_{ j } _pp_{ i } " , "not_ready_sync_model"
170
- )
171
- )
172
- if self .producer_weight_version [i ][f"producer_{ j } " ] < self .weight_version [i ]:
173
- self .producer_weight_version [i ][f"producer_{ j } " ] = self .weight_version [i ]
174
- ray_broadcast_tensor_dict (
175
- self .state_dict_cpu [i ],
176
- 1 ,
177
- device = self .device ,
178
- group_name = f"sync_model_producer_{ j } _pp_{ i } " ,
179
- offload_to_cpu = True ,
180
- )
181
- else :
182
- # broadcast a dummy tensor to save the communication cost
183
- ray_broadcast_tensor_dict (
184
- {"not_ready_sync_model" : torch .ones ((1 )).cpu ()},
185
- 1 ,
186
- device = self .device ,
187
- group_name = f"sync_model_producer_{ j } _pp_{ i } " ,
188
- offload_to_cpu = True ,
189
- )
190
- self .profiler .exit (f"sync_model_producer_{ j } _pp_{ i } " )
191
- else :
192
- if signal .get ("consumer" , None ) == "ready_sync_model" :
193
- self .profiler .enter ("sync_model_consumer" )
194
- ray .get (self .shared_signal_actor .set_signal .remote ("consumer" , "not_ready_sync_model" ))
195
- # Broadcast the model state dict from consumer to shared variable actor
196
- self .state_dict_cpu = ray_broadcast_tensor_dict (
197
- None ,
198
- 0 ,
199
- device = self .device ,
200
- group_name = "sync_model_consumer" ,
201
- offload_to_cpu = True ,
202
- )
203
- self .profiler .exit ("sync_model_consumer" )
204
- self .weight_version [0 ] += 1
205
- for i in range (self .num_producers ):
206
- if signal .get (f"producer_{ i } " , None ) == "ready_sync_model" :
207
- self .profiler .enter (f"sync_model_producer_{ i } " )
208
- # Broadcast the model state dict to all producers
209
- ray .get (self .shared_signal_actor .set_signal .remote (f"producer_{ i } " , "not_ready_sync_model" ))
210
- if self .producer_weight_version [0 ][f"producer_{ i } " ] < self .weight_version [0 ]:
211
- self .producer_weight_version [0 ][f"producer_{ i } " ] = self .weight_version [0 ]
212
- ray_broadcast_tensor_dict (
213
- self .state_dict_cpu ,
214
- 1 ,
215
- device = self .device ,
216
- group_name = f"sync_model_producer_{ i } " ,
217
- offload_to_cpu = True ,
218
- )
219
- else :
220
- # broadcast a dummy tensor to save the communication cost
221
- ray_broadcast_tensor_dict (
222
- {"not_ready_sync_model" : torch .ones ((1 )).cpu ()},
223
- 1 ,
224
- device = self .device ,
225
- group_name = f"sync_model_producer_{ i } " ,
226
- offload_to_cpu = True ,
227
- )
228
- self .profiler .exit (f"sync_model_producer_{ i } " )
229
- if signal .get ("consumer" , None ) == "terminate" :
230
- self .profiler .log ("terminate sync model worker" )
231
- break
0 commit comments