@@ -101,6 +101,9 @@ class QueueManager(BaseManager):
101
101
self .finish_request_barrier = [
102
102
threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
103
103
]
104
+ self .worker_process_tp_barrier = [
105
+ threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
106
+ ]
104
107
105
108
self .finish_add_cache_task_barrier = [
106
109
threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
@@ -193,6 +196,10 @@ class QueueManager(BaseManager):
193
196
"get_finish_add_cache_task_barrier" ,
194
197
callable = lambda idx : self .finish_add_cache_task_barrier [idx ],
195
198
)
199
+ QueueManager .register (
200
+ "get_worker_process_tp_barrier" ,
201
+ callable = lambda idx : self .worker_process_tp_barrier [idx ],
202
+ )
196
203
self .manager : BaseManager = QueueManager (address = self .address , authkey = self .authkey )
197
204
self .manager .start ()
198
205
else :
@@ -217,6 +224,7 @@ class QueueManager(BaseManager):
217
224
QueueManager .register ("get_connect_rdma_tasks" )
218
225
QueueManager .register ("get_connect_rdma_tasks_responses" )
219
226
QueueManager .register ("get_connect_task_lock" )
227
+ QueueManager .register ("get_worker_process_tp_barrier" )
220
228
self .manager = QueueManager (address = self .address , authkey = self .authkey )
221
229
self ._connect_with_retry ()
222
230
@@ -239,6 +247,7 @@ class QueueManager(BaseManager):
239
247
self .finish_add_cache_task_barrier = self .manager .get_finish_add_cache_task_barrier (
240
248
self .local_data_parallel_id
241
249
)
250
+ self .worker_process_tp_barrier = self .manager .get_worker_process_tp_barrier (self .local_data_parallel_id )
242
251
self .finished_req_queue = self .manager .get_finish_request_queue (self .local_data_parallel_id )
243
252
self .finished_add_cache_task_queue = self .manager .get_finish_add_cache_task_queue (
244
253
self .local_data_parallel_id
0 commit comments