Commit 9082f62

[xpu] use cpu barrier (#4181)

1 parent: 813befa

This change adds a manager-backed threading.Barrier (worker_process_tp_barrier) to the engine worker queue and, on XPU, routes the worker event loop's tensor-parallel synchronization through it instead of paddle.distributed.barrier.

File tree

2 files changed: +17 lines, -2 lines

fastdeploy/inter_communicator/engine_worker_queue.py

Lines changed: 9 additions & 0 deletions
@@ -101,6 +101,9 @@ class QueueManager(BaseManager):
             self.finish_request_barrier = [
                 threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
             ]
+            self.worker_process_tp_barrier = [
+                threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
+            ]
 
             self.finish_add_cache_task_barrier = [
                 threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)

@@ -193,6 +196,10 @@ class QueueManager(BaseManager):
                 "get_finish_add_cache_task_barrier",
                 callable=lambda idx: self.finish_add_cache_task_barrier[idx],
             )
+            QueueManager.register(
+                "get_worker_process_tp_barrier",
+                callable=lambda idx: self.worker_process_tp_barrier[idx],
+            )
             self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
             self.manager.start()
         else:

@@ -217,6 +224,7 @@ class QueueManager(BaseManager):
             QueueManager.register("get_connect_rdma_tasks")
             QueueManager.register("get_connect_rdma_tasks_responses")
             QueueManager.register("get_connect_task_lock")
+            QueueManager.register("get_worker_process_tp_barrier")
             self.manager = QueueManager(address=self.address, authkey=self.authkey)
             self._connect_with_retry()

@@ -239,6 +247,7 @@ class QueueManager(BaseManager):
             self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
                 self.local_data_parallel_id
             )
+            self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id)
             self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
             self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
                 self.local_data_parallel_id
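The server side above creates one threading.Barrier per local data-parallel group and exposes it through the QueueManager (a multiprocessing BaseManager subclass), so each worker process can fetch a proxy and block on wait() with no device collective involved. Below is a minimal, self-contained sketch of that pattern; DemoManager, NUM_CLIENTS, ADDRESS, and AUTHKEY are illustrative stand-ins, not FastDeploy names, and a fork-based start method (Linux) is assumed since the lambda passed to register() is not picklable under spawn.

# Sketch only: a Barrier owned by the server process, shared via a
# BaseManager proxy, mirroring QueueManager.register(...) in the diff.
import threading
from multiprocessing import Process
from multiprocessing.managers import BaseManager

NUM_CLIENTS = 2                      # stands in for num_client (TP degree)
ADDRESS = ("127.0.0.1", 50077)       # hypothetical address/authkey
AUTHKEY = b"demo"


class DemoManager(BaseManager):
    pass


def worker(rank: int) -> None:
    # Client side: register the name only, connect, then fetch a proxy.
    DemoManager.register("get_tp_barrier")
    manager = DemoManager(address=ADDRESS, authkey=AUTHKEY)
    manager.connect()
    barrier = manager.get_tp_barrier()
    print(f"rank {rank} waiting")
    barrier.wait()  # blocks until all NUM_CLIENTS ranks arrive
    print(f"rank {rank} released")


if __name__ == "__main__":
    # Server side: create the real Barrier and register a callable that
    # returns it, as engine_worker_queue.py does per data-parallel group.
    tp_barrier = threading.Barrier(NUM_CLIENTS)
    DemoManager.register("get_tp_barrier", callable=lambda: tp_barrier)
    server = DemoManager(address=ADDRESS, authkey=AUTHKEY)
    server.start()

    procs = [Process(target=worker, args=(r,)) for r in range(NUM_CLIENTS)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    server.shutdown()

All ranks of a tensor-parallel group must wait on the same proxy; the list indexed by local_data_parallel_id in the diff gives each data-parallel group its own barrier for exactly that reason.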

fastdeploy/worker/worker_process.py

Lines changed: 8 additions & 2 deletions
@@ -256,6 +256,12 @@ def _broadcast_model_weights_signal(self, src: int, group) -> int:
         paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
         return model_weights_signal_tensor.item()
 
+    def _tp_barrier_wait(self):
+        if current_platform.is_xpu():
+            self.task_queue.worker_process_tp_barrier.wait()
+        else:
+            paddle.distributed.barrier(self.parallel_config.tp_group)
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distributed Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.

@@ -299,7 +305,7 @@ def event_loop_normal(self) -> None:
 
             if self.parallel_config.tensor_parallel_size > 1:
                 # Synchronize the signal for other workers
-                paddle.distributed.barrier(self.parallel_config.tp_group)
+                self._tp_barrier_wait()
 
             if self.fd_config.load_config.dynamic_load_weight:
                 if self.parallel_config.enable_expert_parallel:

@@ -350,7 +356,7 @@ def event_loop_normal(self) -> None:
 
                 if (not self.parallel_config.use_ep) and (not self.worker.model_runner.not_need_stop()):
                     if self.ranks > 1:
-                        paddle.distributed.barrier(self.parallel_config.tp_group)
+                        self._tp_barrier_wait()
 
                     time.sleep(0.001)
                     continue
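On XPU, _tp_barrier_wait substitutes the manager-backed CPU barrier for the device collective. Judging from the surrounding code, the barrier here guards host-side control flow (the model-weights signal and the stop check) rather than device work, and the swap also depends on threading.Barrier being cyclic: it resets after all parties pass, so the event loop can call wait() once per iteration exactly as it called paddle.distributed.barrier. A small thread-based sketch of that reuse (threads stand in for TP worker processes; NUM_RANKS is illustrative):

# Sketch only: threading.Barrier resets after every generation, keeping
# all ranks in lockstep across repeated loop iterations.
import threading

NUM_RANKS = 4
barrier = threading.Barrier(NUM_RANKS)


def rank_loop(rank: int, iterations: int = 3) -> None:
    for step in range(iterations):
        # ... per-rank work for this step would run here ...
        barrier.wait()  # releases once all NUM_RANKS arrive, then resets
        print(f"rank {rank} passed step {step}")


threads = [threading.Thread(target=rank_loop, args=(r,)) for r in range(NUM_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

The invariant either flavor requires is the same: every rank must reach the barrier the same number of times per iteration, which is why both call sites in event_loop_normal go through the single _tp_barrier_wait helper.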
