
Commit 2883746

fix model_weights_signal (#4092)
* fix model_weights_signal
1 parent: 2485333 · commit: 2883746

File tree: 1 file changed (+17 −5 lines changed)


fastdeploy/worker/worker_process.py

Lines changed: 17 additions & 5 deletions
@@ -248,6 +248,11 @@ def init_health_status(self) -> None:
             create=False,
         )
 
+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distrubuted Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
@@ -257,15 +262,19 @@ def event_loop_normal(self) -> None:
         req_ids = []
         num_running_requests = 0
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if local_rank == 0:
                 if self.model_weights_status.value[0] != 0:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )
 
             self.insert_step = False
             req_dicts = None
@@ -293,7 +302,9 @@ def event_loop_normal(self) -> None:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != 0:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
@@ -305,6 +316,7 @@ def event_loop_normal(self) -> None:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = 0
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
 
             if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
