
Commit ed546ae

yuki-97, zpqiu, and parthchadha authored

feat: streaming each dtensor in refit (#176)

Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Alex Qiu <alexq@nvidia.com>
Signed-off-by: Parth Chadha <pchadha@nvidia.com>
Co-authored-by: Alex Qiu <alexq@nvidia.com>
Co-authored-by: Parth Chadha <pchadha@nvidia.com>
1 parent 5c62657 commit ed546ae

File tree

13 files changed: +510, -179 lines

examples/configs/grpo_math_1B.yaml

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ policy:
   precision: "bfloat16"
   fsdp_offload_enabled: false
   activation_checkpointing_enabled: false
+  refit_buffer_size_gb: 4 # used for refitting inference engine, the unit is GB

   dtensor_cfg:
     enabled: false

examples/configs/grpo_math_8B.yaml

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ policy:
   precision: "bfloat16"
   fsdp_offload_enabled: false
   activation_checkpointing_enabled: false
+  refit_buffer_size_gb: 4 # used for refitting inference engine, the unit is GB

   optimizer:
     name: "torch.optim.AdamW"

nemo_reinforcer/algorithms/grpo.py

Lines changed: 36 additions & 6 deletions

@@ -279,13 +279,34 @@ def setup(
 def refit_policy_generation(
     policy: PolicyInterface,
     policy_generation: GenerationInterface,
+    refit_buffer_size_gb: int, # GB
 ):
     """Refit the policy generation interface with the latest policy weights."""
     policy.offload_before_refit()
-    ipc_handles = policy.get_weights_ipc_handles()
-    policy_generation.prepare_for_generation()
-    policy_generation.update_weights(ipc_handles)
+    policy_generation.prepare_for_generation(tags=["weights"])
+    # Streaming update weights to save memory
+    state_dict_info = policy.prepare_weights_for_ipc()
+    # group keys to save time
+    available_bytes = refit_buffer_size_gb * (1024**3)
+    split_keys, keys = [], []
+    for key, size_in_bytes in state_dict_info:
+        if size_in_bytes > available_bytes:
+            if keys:
+                split_keys.append(keys)
+                keys = []
+            available_bytes = refit_buffer_size_gb * (1024**3)
+
+        keys.append(key)
+        available_bytes -= size_in_bytes
+
+    if len(keys) > 0:
+        split_keys.append(keys)
+    # do update
+    for keys in split_keys:
+        ipc_handles = policy.get_weights_ipc_handles(keys)
+        policy_generation.update_weights(ipc_handles)
     policy.offload_after_refit()
+    policy_generation.prepare_for_generation(tags=["kv_cache"])


 # ===============================================================================

@@ -321,12 +342,13 @@ def grpo_train(
     consumed_samples = grpo_save_state["consumed_samples"]
     val_period = master_config["grpo"]["val_period"]
     val_at_start = master_config["grpo"]["val_at_start"]
+    refit_buffer_size_gb = master_config["policy"]["refit_buffer_size_gb"]

     # Run validation at the start if configured
     if val_at_start and step == 0:
         print("\n🔍 Running initial validation...")
         if NEED_REFIT and POLICY_GENERATION_STALE:
-            refit_policy_generation(policy, policy_generation)
+            refit_policy_generation(policy, policy_generation, refit_buffer_size_gb)
             POLICY_GENERATION_STALE = False
         else:
             policy_generation.prepare_for_generation()

@@ -368,7 +390,11 @@ def grpo_train(
         print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
         with timer.time("prepare_for_generation"):
             if NEED_REFIT and POLICY_GENERATION_STALE:
-                refit_policy_generation(policy, policy_generation)
+                refit_policy_generation(
+                    policy,
+                    policy_generation,
+                    refit_buffer_size_gb,
+                )
                 POLICY_GENERATION_STALE = False
             else:
                 policy_generation.prepare_for_generation()

@@ -476,7 +502,11 @@ def grpo_train(
         # Run validation if it's a validation step
         if val_period > 0 and (step + 1) % val_period == 0:
             if NEED_REFIT and POLICY_GENERATION_STALE:
-                refit_policy_generation(policy, policy_generation)
+                refit_policy_generation(
+                    policy,
+                    policy_generation,
+                    refit_buffer_size_gb,
+                )
                 POLICY_GENERATION_STALE = False
             else:
                 policy_generation.prepare_for_generation()
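The loop above packs parameter names into groups whose combined size stays within the configured byte budget, so each IPC transfer stays bounded; a tensor larger than the budget gets a group of its own. A minimal, self-contained sketch of that grouping logic (group_keys_by_budget is a hypothetical helper name; the commit inlines this logic in refit_policy_generation):

# Hedged sketch of the bucketing step used by the streaming refit above.
from typing import List, Tuple


def group_keys_by_budget(
    state_dict_info: List[Tuple[str, int]], budget_bytes: int
) -> List[List[str]]:
    split_keys, keys = [], []
    available = budget_bytes
    for key, size in state_dict_info:
        if size > available:
            if keys:
                split_keys.append(keys)
                keys = []
            available = budget_bytes
        keys.append(key)
        available -= size
    if keys:
        split_keys.append(keys)
    return split_keys


# Example: a 4 GB budget over tensors of 3 GB, 2 GB, and 1 GB.
GB = 1024**3
info = [("w1", 3 * GB), ("w2", 2 * GB), ("w3", 1 * GB)]
print(group_keys_by_budget(info, 4 * GB))  # [['w1'], ['w2', 'w3']]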

nemo_reinforcer/models/generation/vllm.py

Lines changed: 9 additions & 3 deletions

@@ -454,8 +454,14 @@ def sleep(self):
         gc.collect()
         torch.cuda.empty_cache()

-    def wake_up(self):
-        self.llm.wake_up()
+    def wake_up(self, **kwargs):
+        # tags like ["weights", "kv_cache"]
+        # We can call this function with just tags=["weights"] while doing refit to
+        # avoid spiking memory with the kv_cache while the training fwk is awake.
+        if "tags" in kwargs:
+            self.llm.wake_up(tags=kwargs["tags"])
+        else:
+            self.llm.wake_up()


 class VllmGeneration(GenerationInterface):

@@ -622,7 +628,7 @@ def prepare_for_generation(self, *args, **kwargs):
         try:
             # Use run_all_workers_single_data for methods that don't need data
             futures = self.worker_group.run_all_workers_single_data(
-                "wake_up", only_on="tied_leader"
+                "wake_up", only_on="tied_leader", **kwargs
             )
             # Wait for all futures to complete
             results = ray.get(futures)
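wake_up now forwards an optional tags argument so refit can wake only the weights first and restore the KV cache after the transfer finishes. A minimal sketch of that forwarding pattern; DummyEngine is a hypothetical stand-in for the vLLM engine object:

# Hedged sketch of the optional-tags wake_up wrapper shown in the diff.
from typing import List, Optional


class DummyEngine:
    def wake_up(self, tags: Optional[List[str]] = None) -> None:
        print(f"waking up: {tags or ['weights', 'kv_cache']}")


class Worker:
    def __init__(self) -> None:
        self.llm = DummyEngine()

    def wake_up(self, **kwargs) -> None:
        # Forward tags only when the caller provided them.
        if "tags" in kwargs:
            self.llm.wake_up(tags=kwargs["tags"])
        else:
            self.llm.wake_up()


w = Worker()
w.wake_up(tags=["weights"])   # during refit: bring back weights only
w.wake_up(tags=["kv_cache"])  # after weight transfer: restore the KV cache
w.wake_up()                   # default: wake everything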

nemo_reinforcer/models/generation/vllm_backend.py

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ def update_weights_from_ipc_handles(self, ipc_handles):
         weights = []

         # Process each handle to get the tensor
-        for name, handle in handles.items():
+        for name, handle in handles:
             func, args = handle
             list_args = list(args)
             # Update device ID to match the current device
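With this change the per-device payload is a list of (name, handle) pairs rather than a name-to-handle dict, so the backend iterates the pairs directly. A rough, same-process sketch of packing and unpacking that structure (the real rebuild runs inside the vLLM worker and patches the device index before reconstructing each tensor):

# Hedged sketch of the new list-of-(name, handle) payload layout.
# reduce_tensor is the torch helper used in the diff; func(*list_args)
# normally runs in the receiving vLLM worker process.
import torch
from torch.multiprocessing.reductions import reduce_tensor

tensors = {"layer.weight": torch.randn(4, 4), "layer.bias": torch.randn(4)}

# Sender side: a list of (name, handle) pairs instead of a dict.
all_handles = [(name, reduce_tensor(t.detach())) for name, t in tensors.items()]

# Receiver side: iterate the pairs directly, as in the updated loop.
for name, handle in all_handles:
    func, args = handle
    list_args = list(args)
    # The real backend updates the device index in list_args here before
    # reconstructing the tensor with func(*list_args).
    print(name, func.__name__, len(list_args))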

nemo_reinforcer/models/policy/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -44,3 +44,4 @@ class PolicyConfig(TypedDict):
     max_grad_norm: Optional[Union[float, int]]
     fsdp_offload_enabled: bool
     activation_checkpointing_enabled: bool
+    refit_buffer_size_gb: int
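A trimmed-down sketch of the schema change for reference; only refit_buffer_size_gb comes from this commit, the other fields are illustrative:

# Hedged sketch: a cut-down PolicyConfig-style TypedDict with the new field.
from typing import TypedDict


class PolicyConfigSketch(TypedDict):
    precision: str
    activation_checkpointing_enabled: bool
    refit_buffer_size_gb: int


cfg: PolicyConfigSketch = {
    "precision": "bfloat16",
    "activation_checkpointing_enabled": False,
    "refit_buffer_size_gb": 4,
}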

nemo_reinforcer/models/policy/dtensor_policy_worker.py

Lines changed: 44 additions & 40 deletions

@@ -174,7 +174,9 @@ def __init__(
         if self.cpu_offload:
             self.model = self.move_buffer_to_device(self.model, "cpu")

-        self._held_model_params = None
+        # used for streaming update inference engine weights
+        self._held_sharded_state_dict_reference = None
+        self._held_streamed_param_reference = None

         if init_reference_model:
             self.reference_model_state_dict = get_cpu_state_dict(

@@ -235,6 +237,9 @@ def __init__(
     def is_alive(self):
         return True

+    def reset_peak_memory_stats(self):
+        torch.cuda.reset_peak_memory_stats()
+
     def get_gpu_info(self):
         """Return information about the GPU being used by this worker."""
         return get_gpu_info(self.model)

@@ -554,50 +559,45 @@ def report_device_id(self) -> str:
         return get_device_uuid(device_idx)

     @torch.no_grad()
-    def get_weight_ipc_handles(self, offload_model=True):
-        from torch.multiprocessing.reductions import reduce_tensor
-
+    def prepare_weights_for_ipc(self):
         self.model = self.move_to_cuda(self.model)
-        params = self.model.state_dict()
+        self._held_sharded_state_dict_reference = self.model.state_dict()
+        # Collect info for streaming multiple tensors
+        state_dict_info = []
+        for name, tensor in self._held_sharded_state_dict_reference.items():
+            # dtensor's numel will return complete tensor instead of only local tensor
+            size_in_bytes = tensor.element_size() * tensor.numel()
+            state_dict_info.append((name, size_in_bytes))
+        return state_dict_info

-        # Create a copy of parameters in the desired dtype (bfloat16 or float32)
-        dtype_params = {}
-        for name, param in params.items():
-            if isinstance(param, DTensor):
-                param = param.full_tensor()
+    @torch.no_grad()
+    def get_weights_ipc_handles(self, keys):
+        from torch.multiprocessing.reductions import reduce_tensor

+        converted_params = {}
+        for key in keys:
+            # Get full_tensor for dtensor (GPU > 1)
+            tensor = self._held_sharded_state_dict_reference[key]
+            if isinstance(tensor, DTensor):
+                full_tensor = tensor.full_tensor()
+            else:
+                full_tensor = tensor
             # Convert parameters to the configured dtype
-            dtype_params[name] = param.to(
-                device="cuda", dtype=self.dtype, non_blocking=True
-            )
+            converted_params[key] = full_tensor.to(self.dtype, non_blocking=True)

-        for name, buffer in self.model.named_buffers():
-            if isinstance(buffer, DTensor):
-                buffer = buffer.full_tensor()
+        # Temporary record the full tensor for cleanup
+        # It is needed for cleanup the last full_tensor in the refit process
+        self._held_streamed_param_reference = converted_params

-            dtype_params[name] = buffer.to(
-                device="cuda", dtype=self.dtype, non_blocking=True
-            )
-
-        torch.cuda.synchronize()
-
-        # Replace the original params with the converted ones
-        params = dtype_params
-
-        # hold on to the params so we can explicitly delete them after refit
-        self._held_model_params = params
-
-        data = {}
+        # Get device UUID for IPC
         device_uuid = self.report_device_id()
-        for name, p in params.items():
-            data[name] = reduce_tensor(p.detach())
-
-        if offload_model or self.cpu_offload:
-            self.model = self.move_to_cpu(self.model)
-            gc.collect()
-            torch.cuda.empty_cache()
+        # Create handles for the tensors
+        all_handles = []
+        for key, p in converted_params.items():
+            handle = reduce_tensor(p.detach())
+            all_handles.append((key, handle))

-        return {device_uuid: data}
+        return {device_uuid: all_handles}

     def prepare_for_lp_inference(self):
         if not self.cpu_offload:

@@ -655,9 +655,13 @@ def offload_after_refit(self):
         torch.randn(1).cuda() # wake up torch allocator
         self.offload_before_refit() # rerun the old offload function

-        if self._held_model_params is not None:
-            del self._held_model_params
-            self._held_model_params = None
+        # Clean up the held tensors
+        if self._held_sharded_state_dict_reference is not None:
+            del self._held_sharded_state_dict_reference
+            self._held_sharded_state_dict_reference = None
+        if self._held_streamed_param_reference is not None:
+            del self._held_streamed_param_reference
+            self._held_streamed_param_reference = None

         gc.collect()
         torch.cuda.empty_cache()
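Splitting the old get_weight_ipc_handles into prepare_weights_for_ipc (size metadata only) and get_weights_ipc_handles (materialize just the requested keys) is what keeps peak memory bounded during refit. A CPU-only sketch of that two-phase shape with plain tensors standing in for DTensors; SimplePolicyWorker is hypothetical:

# Hedged sketch of the two-phase streaming interface, without DTensor/IPC
# so it runs anywhere. SimplePolicyWorker is hypothetical.
import torch


class SimplePolicyWorker:
    def __init__(self, model: torch.nn.Module, dtype=torch.bfloat16):
        self.model = model
        self.dtype = dtype
        self._held_state_dict = None
        self._held_converted = None

    def prepare_weights_for_ipc(self):
        # Phase 1: hold the state dict and report (name, size_in_bytes) pairs.
        self._held_state_dict = self.model.state_dict()
        return [
            (name, t.element_size() * t.numel())
            for name, t in self._held_state_dict.items()
        ]

    def get_weights_ipc_handles(self, keys):
        # Phase 2: materialize only the requested keys in the target dtype.
        # The real worker calls DTensor.full_tensor() and reduce_tensor() here.
        self._held_converted = {
            k: self._held_state_dict[k].to(self.dtype) for k in keys
        }
        return {"device-0": list(self._held_converted.items())}


worker = SimplePolicyWorker(torch.nn.Linear(8, 8))
print(worker.prepare_weights_for_ipc())  # e.g. [('weight', 256), ('bias', 32)]
print(worker.get_weights_ipc_handles(["bias"]).keys())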

nemo_reinforcer/models/policy/fsdp1_policy_worker.py

Lines changed: 61 additions & 27 deletions

@@ -149,7 +149,11 @@ def do_fsdp(model):
         self.reference_model = do_fsdp(self.reference_model)
         self.reference_model = self.manual_offload_to_cpu(self.reference_model)
         self.model = self.manual_load_to_gpu(self.model)
-        self._held_reference_model_params = None
+
+        # used for streaming update inference engine weights
+        self._held_sharded_state_dict_reference = None
+        self._held_streamed_param_reference = None
+
         # register_fsdp_forward_method(self.model, "generate")
         if init_optimizer:
             optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"])

@@ -205,6 +209,9 @@ def do_fsdp(model):
     def is_alive(self):
         return True

+    def reset_peak_memory_stats(self):
+        torch.cuda.reset_peak_memory_stats()
+
     def get_gpu_info(self):
         """Return information about the GPU being used by this worker."""
         return get_gpu_info(self.model)

@@ -720,38 +727,61 @@ def report_device_id(self) -> str:
         return get_device_uuid(device_idx)

     @torch.no_grad()
-    def get_weight_ipc_handles(self, offload_model=True):
-        from torch.multiprocessing.reductions import reduce_tensor
+    def prepare_weights_for_ipc(self):
+        from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType

         # If the model is not FSDP, then we need to manually move it to the GPU
         # For an FSDP model, model.state_dict() will move the params to the GPU
-        if not isinstance(self.model, torch.distributed.fsdp.FullyShardedDataParallel):
+        if not isinstance(self.model, FullyShardedDataParallel):
             self.model = self.manual_load_to_gpu(self.model)
+            self._held_sharded_state_dict_reference = self.model.state_dict()
+        else:
+            # Get sharded state dict instead of full state dict for FSDP1
+            with FullyShardedDataParallel.state_dict_type(
+                self.model,
+                state_dict_type=StateDictType.SHARDED_STATE_DICT,
+                state_dict_config=ShardedStateDictConfig(),
+            ):
+                self._held_sharded_state_dict_reference = self.model.state_dict()
+
+        # Collect info for streaming multiple tensors
+        state_dict_info = []
+        for name, tensor in self._held_sharded_state_dict_reference.items():
+            # dtensor's numel will return complete tensor instead of only local tensor
+            size_in_bytes = tensor.element_size() * tensor.numel()
+            state_dict_info.append((name, size_in_bytes))
+
+        return state_dict_info

-        # TODO @sahilj: do this without an allgather (maybe FSDP2)
-        params = self.model.state_dict()
+    @torch.no_grad()
+    def get_weights_ipc_handles(self, keys):
+        from torch.distributed.tensor import DTensor
+        from torch.multiprocessing.reductions import reduce_tensor

-        # Create a copy of parameters in the desired dtype (bfloat16 or float32)
-        dtype_params = {}
-        for name, param in params.items():
+        converted_params = {}
+        for key in keys:
+            # Get full_tensor for dtensor (GPU > 1)
+            tensor = self._held_sharded_state_dict_reference[key]
+            if isinstance(tensor, DTensor):
+                full_tensor = tensor.full_tensor()
+            else:
+                full_tensor = tensor
             # Convert parameters to the configured dtype
-            dtype_params[name] = param.to(self.dtype, non_blocking=True)
-
-        # Replace the original params with the converted ones
-        params = dtype_params
-        # For FSDP1, params may get GC'ed before sending to vllm,
-        # so we need to hold a reference to them
-        self._held_reference_model_params = params
-        data = {}
+            converted_params[key] = full_tensor.to(self.dtype, non_blocking=True)
+
+        # Temporary record the full tensor for cleanup
+        # It is needed for cleanup the last full_tensor in the refit process
+        self._held_streamed_param_reference = converted_params
+
+        # Get device UUID for IPC
         device_uuid = self.report_device_id()
-        for name, p in params.items():
-            data[name] = reduce_tensor(p.detach())
+        # Create handles for the tensors
+        all_handles = []
+        for key, p in converted_params.items():
+            handle = reduce_tensor(p.detach())
+            all_handles.append((key, handle))

-        if offload_model:
-            self.model = self.manual_offload_to_cpu(self.model)
-            gc.collect()
-            torch.cuda.empty_cache()
-        return {device_uuid: data}
+        return {device_uuid: all_handles}

     def prepare_for_lp_inference(self):
         self.model = self.manual_load_to_gpu(self.model)

@@ -802,9 +832,13 @@ def offload_after_refit(self):
         torch.randn(1).cuda() # wake up torch allocator
         self.offload_before_refit() # rerun the old offload function

-        if self._held_reference_model_params is not None:
-            del self._held_reference_model_params
-            self._held_reference_model_params = None
+        # Clean up the held tensors
+        if self._held_sharded_state_dict_reference is not None:
+            del self._held_sharded_state_dict_reference
+            self._held_sharded_state_dict_reference = None
+        if self._held_streamed_param_reference is not None:
+            del self._held_streamed_param_reference
+            self._held_streamed_param_reference = None

         gc.collect()
         torch.cuda.empty_cache()
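For FSDP1, the sharded state dict is requested explicitly so state_dict() does not allgather full tensors up front; full tensors are reconstructed per key while streaming. A rough sketch of that pattern, assuming torch.distributed is already initialized and an FSDP-wrapped model built elsewhere; it mirrors the FSDP1 state_dict_type API used in the diff:

# Hedged sketch: collect sharded state-dict sizes, then gather one key on demand.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.tensor import DTensor


def sharded_state_dict_sizes(fsdp_model: FullyShardedDataParallel):
    # Sharded state dict keeps per-rank shards; no allgather happens here.
    with FullyShardedDataParallel.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(),
    ):
        sharded = fsdp_model.state_dict()
    return sharded, [(k, t.element_size() * t.numel()) for k, t in sharded.items()]


def materialize(sharded: dict, key: str, dtype=torch.bfloat16) -> torch.Tensor:
    # Only the requested key is gathered into a full tensor, as in the diff.
    t = sharded[key]
    full = t.full_tensor() if isinstance(t, DTensor) else t
    return full.to(dtype, non_blocking=True)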
