
Commit d6aa72a

polish(pu): polish _deepspeed_broadcast
1 parent 5fd2584 commit d6aa72a

File tree

2 files changed: +88 -18 lines changed


zoo/jericho/priorzero/models/actor.py

Lines changed: 84 additions & 17 deletions
@@ -303,9 +303,9 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
             status[k] = sum(v) / len(v)

         if isinstance(kl_loss, torch.Tensor):
-            status["kl"] = kl_loss.detach().float().mean().item()
+            status["cur_refer_kl"] = kl_loss.detach().float().mean().item()
         else:
-            status["kl"] = float(kl_loss)
+            status["cur_refer_kl"] = float(kl_loss)

         status = self.strategy.all_reduce(status)
         status_list.append(status)
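
Side note on the renamed statistic: `kl_loss` is the KL penalty against the reference policy, which the config change below estimates with the `k3` estimator. A minimal, self-contained sketch of that estimator, for illustration only; the tensor names `logp`/`ref_logp` are hypothetical stand-ins, not identifiers from this repo:

import torch

def k3_kl(logp: torch.Tensor, ref_logp: torch.Tensor) -> torch.Tensor:
    # Schulman's k3 estimator of KL(current || reference):
    # exp(r) - r - 1 with r = ref_logp - logp; non-negative and lower
    # variance than the naive -r (k1) estimator.
    log_ratio = ref_logp - logp
    return torch.exp(log_ratio) - log_ratio - 1

# Illustrative usage: the per-token tensor is reduced to the scalar that
# train_batch logs as "cur_refer_kl".
logp, ref_logp = torch.randn(4, 16), torch.randn(4, 16)
kl_loss = k3_kl(logp, ref_logp)
print({"cur_refer_kl": kl_loss.detach().float().mean().item()})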
@@ -314,36 +314,102 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
             "policy_loss": status["policy_loss"],
             # "approx_kl": status["approx_kl"],
             "cur_old_kl": status["cur_old_kl"],
-            "cur_refer_kl": status["kl"],
+            "cur_refer_kl": status["cur_refer_kl"],
             "clipfrac": status["clipfrac"],
             "lr": status["lr"],
             "iter": self.train_iter,
         })
         self.train_iter += 1
         return status_list

+    # def _deepspeed_broadcast(self):
+    #     use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
+    #     if use_prefix_cache:
+    #         self.vllm_engine.reset_prefix_cache()
+
+    #     torch.cuda.empty_cache()
+    #     model = self.actor.model.module
+    #     count, num_params = 0, len(list(model.named_parameters()))
+    #     for name, param in model.named_parameters():
+    #         count += 1  # empty_cache at last param
+    #         # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
+    #         with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
+    #             shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
+    #             self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params))
+
     def _deepspeed_broadcast(self):
-        # FIX: Add barrier before vLLM weight update to prevent NCCL deadlock with tp>1
+        # 1. Pre-barrier: do not start updating weights before every rank has finished the previous training step
         if torch.distributed.is_initialized():
             torch.distributed.barrier()

+        # 2. Only rank 0 resets the prefix cache (a purely local operation with no communication, so it is safe inside the rank check)
         use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
-        if use_prefix_cache:
+        if use_prefix_cache and torch.distributed.get_rank() == 0:
             self.vllm_engine.reset_prefix_cache()

         torch.cuda.empty_cache()
+
+        # 3. Weight-update loop
         model = self.actor.model.module
-        count, num_params = 0, len(list(model.named_parameters()))
-        for name, param in model.named_parameters():
-            count += 1  # empty_cache at last param
-            # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
-            with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
-                shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
-                self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params))
+        # Note: every rank must build the parameter list so that all ranks enter the loop
+        params_list = list(model.named_parameters())
+        count, num_params = 0, len(params_list)

-        # FIX: Add barrier after vLLM weight update to ensure all ranks complete
+        for name, param in params_list:
+            count += 1
+
+            # [Key fix 1] ALL ranks must enter this context manager!
+            # modifier_rank=0: gather the full parameter for rank 0, so the other ranks do not pay for the full tensor in GPU memory
+            with deepspeed.zero.GatheredParameters([param],
+                                                   modifier_rank=0,
+                                                   enabled=self.strategy.args.zero_stage == 3):
+
+                # [Key fix 2] Inside the context, only rank 0 takes the full data and sends it to vLLM
+                if torch.distributed.get_rank() == 0:
+                    # At this point param.data is complete on rank 0; on the other ranks it may be empty or sharded
+                    shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
+
+                    self.vllm_engine.update_weight(
+                        name,
+                        dtype=param.dtype,
+                        shape=shape,
+                        weight=param.data,
+                        empty_cache=(count == num_params)
+                    )
+
+        # 4. Post-barrier: make sure rank 0 has finished all RPCs before everyone continues together
         if torch.distributed.is_initialized():
-            torch.distributed.barrier()
+            torch.distributed.barrier()
+
+    # def _deepspeed_broadcast(self):
+    #     # FIX: Add barrier before vLLM weight update to prevent NCCL deadlock with tp>1
+    #     if torch.distributed.is_initialized():
+    #         torch.distributed.barrier()
+
+    #     # Only rank 0 should reset prefix cache and update vLLM weights
+    #     use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
+    #     if use_prefix_cache and torch.distributed.get_rank() == 0:
+    #         self.vllm_engine.reset_prefix_cache()
+
+    #     torch.cuda.empty_cache()
+
+    #     # Only rank 0 updates vLLM weights to avoid:
+    #     # 1. Redundant collective_rpc calls (8 ranks × 3000 params = 24000 RPC calls)
+    #     # 2. NCCL communication congestion from simultaneous GatheredParameters
+    #     # 3. GPU memory bandwidth competition from simultaneous load_weights
+    #     if torch.distributed.get_rank() == 0:
+    #         model = self.actor.model.module
+    #         count, num_params = 0, len(list(model.named_parameters()))
+    #         for name, param in model.named_parameters():
+    #             count += 1  # empty_cache at last param
+    #             # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
+    #             with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
+    #                 shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
+    #                 self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params))
+
+    #     # FIX: Add barrier after vLLM weight update to ensure all ranks complete
+    #     if torch.distributed.is_initialized():
+    #         torch.distributed.barrier()

     def _broadcast_to_vllm(self):
         use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
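
For context on the pattern above: `deepspeed.zero.GatheredParameters` performs a collective all-gather, so every rank must enter the context even though only rank 0 consumes the result; `modifier_rank=0` declares rank 0 as the rank that touches the gathered tensor. A minimal standalone sketch of that control flow, assuming `engine.update_weight` as a stand-in for the repo's `self.vllm_engine.update_weight` interface:

import torch.distributed as dist
import deepspeed

def push_weights_to_vllm(model, engine, zero_stage: int) -> None:
    # All ranks build the same parameter list so the collectives inside
    # GatheredParameters stay aligned across the process group.
    params = list(model.named_parameters())
    for i, (name, param) in enumerate(params):
        # Every rank enters the context (the all-gather is collective);
        # modifier_rank=0 marks rank 0 as the one that uses the result.
        with deepspeed.zero.GatheredParameters([param], modifier_rank=0,
                                               enabled=(zero_stage == 3)):
            if dist.get_rank() == 0:
                # ds_shape is the full (unpartitioned) shape under ZeRO-3.
                shape = param.ds_shape if zero_stage == 3 else param.shape
                engine.update_weight(name, dtype=param.dtype, shape=shape,
                                     weight=param.data,
                                     empty_cache=(i == len(params) - 1))
    # Post-barrier mirrors the commit: no rank races ahead of rank 0's RPCs.
    if dist.is_initialized():
        dist.barrier()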
@@ -357,9 +423,10 @@ def _broadcast_to_vllm(self):
         def _broadcast_param(param, count, num_params):
             if torch.distributed.get_rank() == 0:
                 shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
-                self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params)
-
-            self._model_update_group.broadcast(param.data, src=0, stream=torch.cuda.current_stream())
+                self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params)
+
+            # All ranks must participate in broadcast (collective operation)
+            self._model_update_group.broadcast(param.data, src=0, stream=torch.cuda.current_stream())

         def _handle_cuda_ipc(param, count, num_params):
             from torch.multiprocessing.reductions import reduce_tensor
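
The comment added to `_broadcast_param` states the general NCCL rule: `broadcast` is a collective, so every rank in the update group must call it, even though only rank 0 issues the metadata RPC beforehand. A toy, runnable illustration of that rule, with a plain `torch.distributed` group standing in for the repo's custom `_model_update_group`:

import torch
import torch.distributed as dist

def broadcast_param(param: torch.Tensor) -> None:
    # Collective: EVERY rank must reach this call, or the others block
    # forever inside the communicator -- the deadlock the commit's
    # pre/post barriers guard against.
    dist.broadcast(param, src=0)

if __name__ == "__main__":
    # Launch with e.g.: torchrun --nproc_per_node=2 broadcast_demo.py
    dist.init_process_group(backend="gloo")  # gloo so the demo runs on CPU
    rank = dist.get_rank()
    t = torch.full((4,), float(rank))
    broadcast_param(t)  # afterwards t holds rank 0's zeros on every rank
    print(f"rank {rank}: {t.tolist()}")
    dist.destroy_process_group()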

zoo/jericho/priorzero/priorzero_config.py

Lines changed: 4 additions & 1 deletion
@@ -133,7 +133,10 @@ class PriorZeroLLMConfig:
         ),
     }))
     # advantage = target_value - pred_value
-    advantage_type: str = "advantage_running_norm"  # "advantage", "target_reward", "advantage_batch_norm", "advantage_running_norm"
+    # advantage_type: str = "advantage_running_norm"  # "advantage", "target_reward", "advantage_batch_norm", "advantage_running_norm"
+    # TODO========
+    advantage_type: str = "advantage_batch_norm"  # "advantage", "target_reward", "advantage_batch_norm", "advantage_running_norm"
+
     eps_clip_low_high: Tuple[float, float] = (0.2, 0.2)
     rft_kl_coef: float = 0.01
     kl_estimator: str = "k3"
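
The default `advantage_type` flips from `advantage_running_norm` to `advantage_batch_norm` (left marked TODO, so it reads as experimental). As an illustration of the difference, a minimal sketch of per-batch advantage normalization; the function name and epsilon are assumptions, not code from this repo:

import torch

def batch_norm_advantage(adv: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Normalize with the statistics of the CURRENT batch only, as opposed
    # to a running mean/std accumulated across training iterations.
    return (adv - adv.mean()) / (adv.std() + eps)

# Illustrative usage with advantage = target_value - pred_value, per the
# comment in the config:
target_value, pred_value = torch.randn(8), torch.randn(8)
adv = batch_norm_advantage(target_value - pred_value)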
