@@ -248,6 +248,11 @@ def init_health_status(self) -> None:
             create=False,
         )
 
+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distrubuted Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
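The helper added above exists because paddle.distributed.broadcast operates on tensors: it packs the host-side signal value into a one-element int32 tensor, broadcasts it from src within group, and hands back a plain Python int. A minimal standalone sketch of the same pattern, with an illustrative function name that is not part of the patch:

    import paddle
    import paddle.distributed as dist

    def broadcast_scalar_signal(value: int, src: int = 0, group=None) -> int:
        # Collectives work on tensors, so wrap the scalar in a one-element int32 tensor.
        signal = paddle.full(shape=[1], fill_value=value, dtype="int32")
        # Every rank in `group` must reach this call; rank `src` supplies the value.
        dist.broadcast(signal, src=src, group=group)
        # Return a plain Python int so callers can keep their signal in a numpy buffer.
        return int(signal.item())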
@@ -257,15 +262,19 @@ def event_loop_normal(self) -> None:
         req_ids = []
         num_running_requests = 0
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if local_rank == 0:
                 if self.model_weights_status.value[0] != 0:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )
 
             self.insert_step = False
             req_dicts = None
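Two things change in the loop itself: the signal buffer becomes a plain numpy array (a device tensor is only created transiently inside the new helper when a collective actually runs; the module is assumed to already import numpy as np), and the tensor-parallel broadcast is now guarded by tensor_parallel_size > 1, so no collective is issued for single-rank TP. A hedged sketch of the resulting fan-out, with `worker` standing in for the worker object and the attribute names mirroring the diff:

    import numpy as np

    def fan_out_weights_signal(worker) -> None:
        # Illustrative wrapper name; `worker` stands in for the Paddle worker object.
        if worker.model_weights_signal is None:
            # Host-side buffer (per the diff: np.zeros instead of a paddle tensor).
            worker.model_weights_signal = np.zeros([1], dtype=np.int32)

        local_rank = worker.local_rank % worker.parallel_config.tensor_parallel_size
        if local_rank == 0 and worker.model_weights_status.value[0] != 0:
            # Only TP-rank 0 reads the shared status value.
            worker.model_weights_signal[0] = int(worker.model_weights_status.value[0])

        load_cfg = worker.fd_config.load_config
        parallel_cfg = worker.parallel_config
        if load_cfg.dynamic_load_weight and parallel_cfg.enable_expert_parallel:
            worker.model_weights_signal[0] = worker._broadcast_model_weights_signal(
                src=0, group=parallel_cfg.ep_group
            )
        # Guard added by the patch: skip the TP collective when tensor_parallel_size == 1.
        if load_cfg.dynamic_load_weight and parallel_cfg.tensor_parallel_size > 1:
            worker.model_weights_signal[0] = worker._broadcast_model_weights_signal(
                src=0, group=parallel_cfg.tp_group
            )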
@@ -293,7 +302,9 @@ def event_loop_normal(self) -> None:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != 0:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
@@ -305,6 +316,7 @@ def event_loop_normal(self) -> None:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = 0
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
 
             if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")