@@ -303,9 +303,9 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
                 status[k] = sum(v) / len(v)

             if isinstance(kl_loss, torch.Tensor):
-                status["kl"] = kl_loss.detach().float().mean().item()
+                status["cur_refer_kl"] = kl_loss.detach().float().mean().item()
             else:
-                status["kl"] = float(kl_loss)
+                status["cur_refer_kl"] = float(kl_loss)

             status = self.strategy.all_reduce(status)
             status_list.append(status)
@@ -314,36 +314,102 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
314314 "policy_loss" : status ["policy_loss" ],
315315 # "approx_kl": status["approx_kl"],
316316 "cur_old_kl" : status ["cur_old_kl" ],
317- "cur_refer_kl" : status ["kl " ],
317+ "cur_refer_kl" : status ["cur_refer_kl " ],
318318 "clipfrac" : status ["clipfrac" ],
319319 "lr" : status ["lr" ],
320320 "iter" : self .train_iter ,
321321 })
322322 self .train_iter += 1
323323 return status_list
324324
325+ # def _deepspeed_broadcast(self):
326+ # use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
327+ # if use_prefix_cache:
328+ # self.vllm_engine.reset_prefix_cache()
329+
330+ # torch.cuda.empty_cache()
331+ # model = self.actor.model.module
332+ # count, num_params = 0, len(list(model.named_parameters()))
333+ # for name, param in model.named_parameters():
334+ # count += 1 # empty_cache at last param
335+ # # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
336+ # with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
337+ # shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
338+ # self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params))
339+
     def _deepspeed_broadcast(self):
-        # FIX: Add barrier before vLLM weight update to prevent NCCL deadlock with tp>1
+        # 1. Leading barrier: make sure no rank starts the weight update before the previous training step has finished
         if torch.distributed.is_initialized():
             torch.distributed.barrier()

+        # 2. Only rank 0 resets the prefix cache (a purely local operation, no communication, so it is safe inside the if)
         use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
-        if use_prefix_cache:
+        if use_prefix_cache and torch.distributed.get_rank() == 0:
             self.vllm_engine.reset_prefix_cache()

         torch.cuda.empty_cache()
+
+        # 3. Weight-update logic
         model = self.actor.model.module
-        count, num_params = 0, len(list(model.named_parameters()))
-        for name, param in model.named_parameters():
-            count += 1  # empty_cache at last param
-            # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
-            with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
-                shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
-                self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params))
+        # Note: every rank must build the parameter list so that all ranks enter the loop
+        params_list = list(model.named_parameters())
+        count, num_params = 0, len(params_list)

-        # FIX: Add barrier after vLLM weight update to ensure all ranks complete
+        for name, param in params_list:
+            count += 1
+
+            # [Key fix 1] Every rank must enter this context manager!
+            # modifier_rank=0: the parameter is gathered into its full shape on rank 0 only, so the other ranks do not hold a full copy in GPU memory
+            with deepspeed.zero.GatheredParameters([param],
+                                                   modifier_rank=0,
+                                                   enabled=self.strategy.args.zero_stage == 3):
+
+                # [Key fix 2] Inside the context, only rank 0 holds the full data and pushes it to vLLM
+                if torch.distributed.get_rank() == 0:
+                    # At this point param.data is complete on rank 0; on the other ranks it may be empty or still sharded
+                    shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
+
+                    self.vllm_engine.update_weight(
+                        name,
+                        dtype=param.dtype,
+                        shape=shape,
+                        weight=param.data,
+                        empty_cache=(count == num_params)
+                    )
+
+        # 4. Trailing barrier: wait until rank 0 has finished all of its RPCs before every rank continues
         if torch.distributed.is_initialized():
-            torch.distributed.barrier()
+            torch.distributed.barrier()
+
+    # def _deepspeed_broadcast(self):
+    #     # FIX: Add barrier before vLLM weight update to prevent NCCL deadlock with tp>1
+    #     if torch.distributed.is_initialized():
+    #         torch.distributed.barrier()
+
+    #     # Only rank 0 should reset prefix cache and update vLLM weights
+    #     use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
+    #     if use_prefix_cache and torch.distributed.get_rank() == 0:
+    #         self.vllm_engine.reset_prefix_cache()
+
+    #     torch.cuda.empty_cache()
+
+    #     # Only rank 0 updates vLLM weights to avoid:
+    #     # 1. Redundant collective_rpc calls (8 ranks × 3000 params = 24000 RPC calls)
+    #     # 2. NCCL communication congestion from simultaneous GatheredParameters
+    #     # 3. GPU memory bandwidth competition from simultaneous load_weights
+    #     if torch.distributed.get_rank() == 0:
+    #         model = self.actor.model.module
+    #         count, num_params = 0, len(list(model.named_parameters()))
+    #         for name, param in model.named_parameters():
+    #             count += 1  # empty_cache at last param
+    #             # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
+    #             with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
+    #                 shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
+    #                 self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params))
+
+    #     # FIX: Add barrier after vLLM weight update to ensure all ranks complete
+    #     if torch.distributed.is_initialized():
+    #         torch.distributed.barrier()

     def _broadcast_to_vllm(self):
         use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
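Editor's note: the fix above boils down to one collective pattern. Under ZeRO-3 every rank must enter deepspeed.zero.GatheredParameters (the gather is itself a collective), modifier_rank=0 materializes the full parameter on rank 0 only, and only rank 0 then pushes the weight to the inference engine, bracketed by barriers. Below is a minimal standalone sketch of that pattern, not code from this commit: push_weights_to_engine and engine.update_weight are hypothetical names that mirror the call shape used here.

# Minimal sketch (not part of this commit) of the gather-then-push pattern,
# assuming an initialized torch.distributed group and a ZeRO-3 DeepSpeed model.
import torch.distributed as dist
import deepspeed


def push_weights_to_engine(ds_model, engine, zero_stage=3):
    # Hypothetical `engine.update_weight(...)` stands in for the vLLM-side RPC.
    if dist.is_initialized():
        dist.barrier()  # nobody starts the update until every rank is here

    params = list(ds_model.named_parameters())  # every rank builds the list
    for i, (name, param) in enumerate(params):
        # All ranks must enter the context: under ZeRO-3 it runs an all-gather.
        # modifier_rank=0 keeps the fully materialized copy on rank 0 only.
        with deepspeed.zero.GatheredParameters(
            [param], modifier_rank=0, enabled=(zero_stage == 3)
        ):
            if dist.get_rank() == 0:
                # Only rank 0 sees the full tensor; it alone talks to the engine.
                shape = param.ds_shape if zero_stage == 3 else param.shape
                engine.update_weight(
                    name,
                    dtype=param.dtype,
                    shape=shape,
                    weight=param.data,
                    empty_cache=(i == len(params) - 1),
                )

    if dist.is_initialized():
        dist.barrier()  # let rank 0 finish its RPCs before training resumes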
@@ -357,9 +423,10 @@ def _broadcast_to_vllm(self):
         def _broadcast_param(param, count, num_params):
             if torch.distributed.get_rank() == 0:
                 shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
-                self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params)
-
-            self._model_update_group.broadcast(param.data, src=0, stream=torch.cuda.current_stream())
+                self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params)
+
+            # All ranks must participate in broadcast (collective operation)
+            self._model_update_group.broadcast(param.data, src=0, stream=torch.cuda.current_stream())

         def _handle_cuda_ipc(param, count, num_params):
             from torch.multiprocessing.reductions import reduce_tensor
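Editor's note: the comment added above ("All ranks must participate in broadcast") is the general rule for NCCL collectives: every rank in the process group must issue the call, and only the src rank needs meaningful data beforehand; guarding the call with `if rank == 0` deadlocks the group. A minimal standalone sketch with plain torch.distributed follows; the tensor size and script name are illustrative, not taken from this commit.

# Minimal sketch (not from this commit): broadcast is a collective, so every
# rank calls it. Run with e.g. `torchrun --nproc_per_node=2 broadcast_demo.py`.
import os
import torch
import torch.distributed as dist


def sync_param_from_rank0():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    # Rank 0 holds the real weights; the other ranks allocate an empty buffer
    # of the same shape/dtype that the broadcast will fill in.
    param = torch.empty(1024, device="cuda")
    if dist.get_rank() == 0:
        param.normal_()

    # Every rank must reach this line; wrapping it in `if rank == 0:` would
    # leave the other ranks waiting forever (NCCL deadlock).
    dist.broadcast(param, src=0)

    dist.destroy_process_group()


if __name__ == "__main__":
    sync_param_from_rank0()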