Commit 7b3fad8

feat: refit metadata optimization (#686)
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Co-authored-by: Yuki Huang <yukih@nvidia.com>
Co-authored-by: yuki <48991475+yuki-666@users.noreply.github.com>
1 parent def7682 commit 7b3fad8

File tree: 9 files changed, +124 −106 lines changed

.gitmodules

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 [submodule "3rdparty/NeMo"]
 	path = 3rdparty/NeMo-workspace/NeMo
 	url = https://github.com/NVIDIA/NeMo.git
-	branch = ashors/nemorl-qwen3
+	branch = zhiyul/yukih/prepare-refit-info
 	shallow = true
 [submodule "3rdparty/Megatron-LM"]
 	path = 3rdparty/Megatron-LM-workspace/Megatron-LM

3rdparty/NeMo-workspace/NeMo

Submodule NeMo updated from 33259f2 to 8ddf438

nemo_rl/algorithms/grpo.py

Lines changed: 46 additions & 37 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast
 
@@ -400,6 +401,7 @@ def refit_policy_generation(
     policy_generation: GenerationInterface,
     colocated_inference: bool,
     _refit_buffer_size_gb: Optional[int] = None,
+    timer: Optional[Timer] = None,
 ) -> None:
     """Refit the policy generation interface with the latest policy weights.
 
@@ -414,43 +416,50 @@
         policy.offload_before_refit()
     policy_generation.prepare_for_generation(tags=["weights"])
 
-    # update weights
-    update_success = False
-    if colocated_inference:
-        # get model param keys, which is grouped by size
-        grouped_param_keys = policy.prepare_weights_for_ipc(
-            _refit_buffer_size_gb=_refit_buffer_size_gb
-        )
-        total_num_keys = sum(len(k) for k in grouped_param_keys)
-        print(
-            f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
-        )
-        # do update
-        for keys in grouped_param_keys:
-            ipc_handles = policy.get_weights_ipc_handles(keys)
-            update_success = policy_generation.update_weights_from_ipc_handles(
-                ipc_handles
+    # Create a context manager that does nothing when timer is None
+    timer_context = (
+        timer.time("prepare_for_generation/transfer_and_update_weights")
+        if timer is not None
+        else nullcontext()
+    )
+    with timer_context:
+        # update weights
+        update_success = False
+        if colocated_inference:
+            # get model param keys, which is grouped by size
+            grouped_param_keys = policy.prepare_weights_for_ipc(
+                _refit_buffer_size_gb=_refit_buffer_size_gb
             )
-            if not update_success:
-                break
-    else:
-        # update weights through nccl
-        futures_train = policy.broadcast_weights_for_collective()
-        futures_inference = policy_generation.update_weights_from_collective()
-        # wait for all futures to complete
-        ray.get(futures_train)
-        results = ray.get(futures_inference)
-        update_success = all(result for result in results if result is not None)
-
-    # check if update is successful
-    if not update_success:
-        error_tag = "cuda-ipc" if colocated_inference else "nccl"
-        error_message = (
-            "❌ Error: Updating weights for the generation policy failed during refit.\n"
-            f"This often indicates an issue with {error_tag} or "
-            "a problem within the generation backend (e.g., vLLM worker).\n"
-        )
-        raise RuntimeError(error_message)
+            total_num_keys = sum(len(k) for k in grouped_param_keys)
+            print(
+                f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
+            )
+            # do update
+            for keys in grouped_param_keys:
+                ipc_handles = policy.get_weights_ipc_handles(keys)
+                update_success = policy_generation.update_weights_from_ipc_handles(
+                    ipc_handles
+                )
+                if not update_success:
+                    break
+        else:
+            # update weights through nccl
+            futures_train = policy.broadcast_weights_for_collective()
+            futures_inference = policy_generation.update_weights_from_collective()
+            # wait for all futures to complete
+            ray.get(futures_train)
+            results = ray.get(futures_inference)
+            update_success = all(result for result in results if result is not None)
+
+        # check if update is successful
+        if not update_success:
+            error_tag = "cuda-ipc" if colocated_inference else "nccl"
+            error_message = (
+                "❌ Error: Updating weights for the generation policy failed during refit.\n"
+                f"This often indicates an issue with {error_tag} or "
+                "a problem within the generation backend (e.g., vLLM worker).\n"
+            )
            raise RuntimeError(error_message)
 
     if colocated_inference:
         policy.offload_after_refit()

@@ -544,7 +553,7 @@ def grpo_train(
         with timer.time("prepare_for_generation"):
             if NEED_REFIT and POLICY_GENERATION_STALE:
                 refit_policy_generation(
-                    policy, policy_generation, colocated_inference
+                    policy, policy_generation, colocated_inference, timer=timer
                 )
                 POLICY_GENERATION_STALE = False
             else:
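Note on the pattern above: `nullcontext()` keeps the refit body identical whether or not a `Timer` is passed, so the timed and untimed call sites cannot drift apart. A minimal, self-contained sketch of the same idea follows; the `Timer` class here is a toy stand-in for illustration, not the actual `nemo_rl` Timer implementation:

import time
from contextlib import contextmanager, nullcontext
from typing import Optional


class Timer:
    """Toy timer: .time(name) returns a context manager that accumulates seconds."""

    def __init__(self) -> None:
        self.totals: dict[str, float] = {}

    @contextmanager
    def time(self, name: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] = (
                self.totals.get(name, 0.0) + time.perf_counter() - start
            )


def refit(timer: Optional[Timer] = None) -> None:
    # nullcontext() keeps the body identical whether or not a timer was passed.
    ctx = timer.time("transfer_and_update_weights") if timer is not None else nullcontext()
    with ctx:
        sum(range(1_000_000))  # stand-in for the actual weight transfer


refit()            # untimed call site
t = Timer()
refit(timer=t)     # timed call site
print(t.totals)

This is the same shape as the diff's `timer.time(...) if timer is not None else nullcontext()`.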

nemo_rl/models/generation/vllm_backend.py

Lines changed: 28 additions & 19 deletions

@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from collections import defaultdict
 from typing import Any, Iterable, Optional
 
 import torch
+from torch.multiprocessing.reductions import rebuild_cuda_tensor
 
 try:
     import vllm  # noqa: F401

@@ -136,7 +138,7 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
         try:
             is_tensor_packed = local_device_ipc_handles[0]
             if is_tensor_packed:
-                _, all_handles, tensor_metadata = local_device_ipc_handles
+                _, all_handles, list_keys = local_device_ipc_handles
             else:
                 _, name_and_handle_list = local_device_ipc_handles
 
@@ -152,33 +154,40 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
                # Extract packed tensor from IPC handle
                dtype_to_packed_tensor = {}
                for dtype, tensor_handle in all_handles:
-                    func, args = tensor_handle
+                    func = rebuild_cuda_tensor
+                    args = tensor_handle[0]
                    list_args = list(args)
                    list_args[6] = device_id
                    tensor = func(*list_args)
                    dtype_to_packed_tensor[dtype] = tensor
 
-                # Unpack tensor to weights. Here we only return a view of the tensor to avoid
-                # using extra memory.
-                for key, metadata in tensor_metadata.items():
-                    # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias)
-                    if isinstance(metadata, tuple):
-                        # use dtype of current step
-                        offset, dtype = metadata
-                        shape, _, size = self.state_dict_info[key]
-                        # update record
-                        self.state_dict_info[key] = (shape, dtype, size)
-                    else:
-                        offset = metadata
-                        shape, dtype, size = self.state_dict_info[key]
-                    tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view(
-                        *shape
+                weights = []
+                dtype_to_offset = defaultdict(lambda: 0)
+                for key in list_keys:
+                    shape, dtype, size = self.state_dict_info[key]
+                    weights.append(
+                        (
+                            key,
+                            dtype_to_packed_tensor[dtype][
+                                dtype_to_offset[dtype] : dtype_to_offset[dtype] + size
+                            ].view(*shape),
+                        )
                     )
-                    weights.append((key, tensor))
+                    dtype_to_offset[dtype] += size
+
+                expected_sizes = {
+                    dtype: tensor.numel()
+                    for dtype, tensor in dtype_to_packed_tensor.items()
+                }
+                assert dtype_to_offset == expected_sizes, (
+                    f"Packed tensor size mismatch: expected sizes from keys list {expected_sizes} != actual packed tensor sizes {dtype_to_offset}. "
+                    f"This indicates the keys list order doesn't match the order used when packing tensors."
+                )
            else:
                # Process each handle to get the tensor
                for name, handle in name_and_handle_list:
-                    func, args = handle
+                    func = rebuild_cuda_tensor
+                    args = handle[0]
                    list_args = list(args)
                    list_args[6] = device_id
                    tensor = func(*list_args)
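This hunk is the core of the metadata optimization: instead of shipping a per-key `tensor_metadata` dict (key → offset, sometimes with a dtype override) on every refit, the sender now ships only an ordered key list, and the receiver recomputes per-dtype offsets from its cached `state_dict_info` (shape, dtype, size per key), populated once by the prepare-refit-info step. Below is a runnable, CPU-only sketch of that pack/unpack round trip; tensor names and the packing step are illustrative, and the real path shares packed CUDA tensors via IPC handles rebuilt with `rebuild_cuda_tensor` after patching the device index at `args[6]`:

from collections import defaultdict

import torch

state_dict = {
    "w1": torch.randn(2, 3),
    "w2": torch.randn(4),
    "bias": torch.arange(5, dtype=torch.int64),
}

# One-time "prepare refit info" step: the receiver caches static metadata.
state_dict_info = {k: (v.shape, v.dtype, v.numel()) for k, v in state_dict.items()}

# Sender: pack per dtype, in key order, and ship only (packed tensors, keys).
by_dtype = defaultdict(list)
for key, t in state_dict.items():
    by_dtype[t.dtype].append(t.reshape(-1))
dtype_to_packed = {d: torch.cat(ts) for d, ts in by_dtype.items()}
list_keys = list(state_dict.keys())

# Receiver: rebuild zero-copy views using running per-dtype offsets.
weights = []
dtype_to_offset = defaultdict(int)
for key in list_keys:
    shape, dtype, size = state_dict_info[key]
    off = dtype_to_offset[dtype]
    weights.append((key, dtype_to_packed[dtype][off : off + size].view(*shape)))
    dtype_to_offset[dtype] += size

assert all(torch.equal(w, state_dict[k]) for k, w in weights)

The assert added in the diff checks the same invariant this sketch relies on: after walking the key list, each per-dtype offset must equal the packed tensor's element count, otherwise the key order diverged from the order used when packing.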

nemo_rl/models/megatron/refit_utils.py

Lines changed: 0 additions & 1 deletion

@@ -156,7 +156,6 @@ def gather_params(model, keys: list[str], key_to_global_keys: dict[str, list[str
         if k is not None:
             gathered_params[k] = p
 
-    print(f"Time taken to gather params: {time.perf_counter() - st}")
     return gathered_params
 
 
nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 2 additions & 3 deletions

@@ -68,6 +68,7 @@
 from nemo_rl.models.policy.utils import (
     configure_expandable_segments,
     get_gpu_info,
+    get_handle_from_tensor,
     get_runtime_env_for_policy_worker,
     import_class_from_path,
     is_vllm_v1_engine_enabled,

@@ -1235,8 +1236,6 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
 
     @torch.no_grad()
     def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
-        from torch.multiprocessing.reductions import reduce_tensor
-
         assert self._held_sharded_state_dict_reference is not None, (
             "prepare_weights_for_ipc must be called before get_weights_ipc_handles"
         )

@@ -1266,7 +1265,7 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
         # Create handles for the tensors
         all_handles = []
         for key, p in converted_params.items():
-            handle = reduce_tensor(p.detach())
+            handle = get_handle_from_tensor(p)
             all_handles.append((key, handle))
 
         # (pack_tensor_for_ipc: bool, handles: list)
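For reference, a plausible shape of the `get_handle_from_tensor` helper, inferred from the two sides of this diff rather than from the helper's actual source in `nemo_rl.models.policy.utils`: `reduce_tensor` on a CUDA tensor returns `(rebuild_cuda_tensor, args)`, and since the vLLM side now hardcodes `rebuild_cuda_tensor` and reads `handle[0]`, only the args tuple needs to cross the process boundary:

# Hedged sketch (assumption, not the repo's actual implementation).
import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor, reduce_tensor


def get_handle_from_tensor(tensor: torch.Tensor) -> tuple:
    """Return only the serialized args of a CUDA tensor's IPC handle.

    The receiver rebuilds with rebuild_cuda_tensor(*handle[0]) after
    patching the device index at position 6, so the rebuild function
    itself no longer has to be pickled and shipped per tensor.
    """
    func, args = reduce_tensor(tensor.detach())  # requires a CUDA tensor
    assert func is rebuild_cuda_tensor, "expected a CUDA tensor handle"
    return (args,)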
