
Commit bd7e4b0

Authored by SahilJain314, terrykong, and parthchadha
fix: Mixed Prec memory improvements and better default configs (converge-able) (#32)
Signed-off-by: Sahil Jain <sahilj@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Parth Chadha <pchadha@nvidia.com>
1 parent c49571e commit bd7e4b0

File tree: 5 files changed, +45 -7 lines

- examples/configs/grpo_math_1B.yaml
- examples/configs/grpo_math_8B.yaml
- nemo_reinforcer/models/generation/vllm.py
- nemo_reinforcer/models/policy/hf_policy.py
- tests/unit/models/generation/test_vllm_generation.py

examples/configs/grpo_math_1B.yaml

Lines changed: 4 additions & 4 deletions
@@ -1,14 +1,14 @@
 # GRPO Algorithm Configuration
 grpo:
   num_prompts_per_step: 32
-  num_generations_per_prompt: 8
+  num_generations_per_prompt: 16
   max_num_steps: 1000000
   normalize_rewards: true
   use_leave_one_out_baseline: true
   val_period: 10
-  val_at_start: true
+  val_at_start: false
   max_val_samples: 256
-  val_batch_size: 16
+  val_batch_size: 256

 loss_fn:
   reference_policy_kl_penalty: 0.01
@@ -24,7 +24,7 @@ checkpointing:

 policy:
   model_name: "meta-llama/Llama-3.2-1B-Instruct"
-  train_global_batch_size: 32
+  train_global_batch_size: 512
   train_micro_batch_size: 4
   generation_batch_size: 32
   logprob_batch_size: 4
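
The new defaults are internally consistent: 32 prompts per step with 16 generations per prompt yields 512 rollouts per step, which is exactly the new train_global_batch_size. A quick sanity check in plain Python (assuming, as is usual for GRPO, that each step trains on every rollout it generated):

# Sanity check of the updated grpo_math_1B.yaml defaults.
# Assumption: one GRPO step trains on all of the rollouts it produced.
num_prompts_per_step = 32
num_generations_per_prompt = 16   # was 8
train_global_batch_size = 512     # was 32

rollouts_per_step = num_prompts_per_step * num_generations_per_prompt
assert rollouts_per_step == train_global_batch_size  # 32 * 16 == 512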

examples/configs/grpo_math_8B.yaml

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@ defaults: "grpo_math_1B.yaml"

 policy:
   model_name: "meta-llama/Llama-3.1-8B-Instruct"
-  train_global_batch_size: 32
+  train_global_batch_size: 512
   train_micro_batch_size: 1
   generation_batch_size: 32
   logprob_batch_size: 2
@@ -13,7 +13,7 @@ policy:
   optimizer:
     name: "torch.optim.AdamW"
     kwargs:
-      lr: 5.0e-6
+      lr: 3.0e-7
       weight_decay: 0.01
       betas: [0.9, 0.999]
       eps: 1e-8
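
This file overrides only a handful of keys and inherits the rest from grpo_math_1B.yaml through its defaults entry. A minimal sketch of how such an overlay is typically resolved (illustrative only; this is not the repo's config loader, and resolve_config / deep_update are hypothetical names):

# Sketch: resolve a config that names its base file via a `defaults` key.
import yaml

def deep_update(base: dict, override: dict) -> dict:
    """Recursively merge `override` into `base`; override wins on conflicts."""
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            base[key] = deep_update(base[key], value)
        else:
            base[key] = value
    return base

def resolve_config(path: str) -> dict:
    with open(path) as f:
        cfg = yaml.safe_load(f)
    base = cfg.pop("defaults", None)
    return deep_update(resolve_config(base), cfg) if base else cfg

# e.g. resolve_config("examples/configs/grpo_math_8B.yaml") yields the 1B
# defaults with model_name, train_global_batch_size, and lr overridden.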

nemo_reinforcer/models/generation/vllm.py

Lines changed: 1 addition & 0 deletions
@@ -166,6 +166,7 @@ def __init__(

         self.llm = LLM(
             model=self.model_name,
+            load_format="dummy",
             tensor_parallel_size=self.tensor_parallel_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
             enable_prefix_caching=True,
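
load_format="dummy" makes vLLM allocate the model with placeholder weights instead of reading a checkpoint from disk, cutting start-up time and peak memory; the real weights arrive afterwards through the refit path (see the hf_policy and test changes below). A minimal sketch of the idea, assuming a vLLM build that supports the dummy load format; update_weights_from_trainer is a hypothetical stand-in for the actual refit call:

# Sketch: start the engine without reading weights, then refit them later.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    load_format="dummy",           # allocate parameters, skip checkpoint I/O
    tensor_parallel_size=1,
    gpu_memory_utilization=0.6,
    enable_prefix_caching=True,
)

# Generating now would produce garbage; the weights must first be refit, e.g.:
# update_weights_from_trainer(llm, ipc_handles)   # hypothetical helper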

nemo_reinforcer/models/policy/hf_policy.py

Lines changed: 30 additions & 1 deletion
@@ -107,11 +107,17 @@ def __init__(
         def do_fsdp(model):
             # Create a device mesh with 'world_size' GPUs in a 1D arrangement.
             mesh = init_device_mesh("cuda", (world_size,))
+            mp_policy = MixedPrecision(
+                param_dtype=self.dtype,
+                reduce_dtype=torch.float32,
+                buffer_dtype=torch.float32,
+            )

             return FullyShardedDataParallel(
                 model,
                 device_mesh=mesh,
                 auto_wrap_policy=size_based_auto_wrap_policy,
+                mixed_precision=mp_policy,
             )

         self.model.to("cuda")
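
The new MixedPrecision policy is the standard FSDP recipe: gathered parameters and activations use the configured low-precision dtype (self.dtype, e.g. bfloat16), while gradient reductions and buffers stay in float32 for numerical stability. A self-contained sketch of the same pattern on a toy model (not the repo's wiring; assumes an initialized process group and a CUDA device):

# Sketch: FSDP mixed precision -- bf16 parameters, fp32 reductions and buffers.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision

def shard_with_mixed_precision(model: nn.Module) -> FullyShardedDataParallel:
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,   # dtype used for compute and communication
        reduce_dtype=torch.float32,   # gradient all-reduce stays in fp32
        buffer_dtype=torch.float32,   # buffers (e.g. running stats) stay in fp32
    )
    return FullyShardedDataParallel(model.cuda(), mixed_precision=mp_policy)

# Usage, after torch.distributed.init_process_group(...):
#     fsdp_model = shard_with_mixed_precision(nn.Linear(1024, 1024))
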
@@ -676,16 +682,31 @@ def report_device_id(self) -> str:
         self.device_uuid = current_platform.get_device_uuid(torch.cuda.current_device())
         return self.device_uuid

-    def get_weight_ipc_handles(self):
+    @torch.no_grad()
+    def get_weight_ipc_handles(self, offload_model=True):
         from torch.multiprocessing.reductions import reduce_tensor

         # TODO @sahilj: do this without an allgather (maybe FSDP2)
         params = self.model.state_dict()
+
+        # Create a copy of parameters in the desired dtype (bfloat16 or float32)
+        dtype_params = {}
+        for name, param in params.items():
+            # Convert parameters to the configured dtype
+            dtype_params[name] = param.to(self.dtype, non_blocking=True)
+
+        # Replace the original params with the converted ones
+        params = dtype_params
         self._held_reference_model_params = params
         data = {}
         self.device_uuid = self.report_device_id()
         for name, p in params.items():
             data[name] = reduce_tensor(p.detach())
+
+        if offload_model:
+            self.model = self.move_to_cpu(self.model)
+            gc.collect()
+            torch.cuda.empty_cache()
         return {self.device_uuid: data}

     def prepare_for_lp_inference(self):
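
get_weight_ipc_handles now runs under torch.no_grad(), casts the gathered parameters to self.dtype before exporting them (so the generation engine receives e.g. bfloat16 copies rather than fp32), keeps those copies alive in self._held_reference_model_params, and can offload the training model once the handles exist. The export relies on torch.multiprocessing.reductions.reduce_tensor, which turns a CUDA tensor into a picklable IPC description that another process on the same GPU can rebuild without a copy. A standalone sketch of that mechanism (illustrative; a single tensor instead of a full state dict):

# Sketch: export a CUDA tensor over IPC so a co-located process can map it.
import torch
from torch.multiprocessing.reductions import reduce_tensor

weight = torch.randn(4, 4, device="cuda", dtype=torch.float32)
weight_bf16 = weight.to(torch.bfloat16)   # cast copy; keep a reference alive,
                                          # otherwise the handle would dangle
rebuild_fn, rebuild_args = reduce_tensor(weight_bf16.detach())

# In the consumer process (same physical GPU):
#     shared = rebuild_fn(*rebuild_args)   # maps the same memory, no copy
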
@@ -707,13 +728,19 @@ def prepare_for_training(self, *args, **kwargs):

         torch.cuda.empty_cache()

+    @torch.no_grad()
     def offload_before_refit(self):
         """Offload the optimizer and buffers to the CPU."""
+        torch.randn(1).cuda()  # wake up torch allocator
         if hasattr(self, "optimizer") and self.optimizer is not None:
             for state in self.optimizer.state.values():
                 for k, v in state.items():
                     if torch.is_tensor(v):
                         state[k] = v.to("cpu")
+
+        for buffer in self.model.buffers():
+            buffer.data = buffer.data.to("cpu")
+
         gc.collect()
         torch.cuda.empty_cache()

@@ -724,10 +751,12 @@ def offload_before_refit(self):
             f"GPU Memory after optimizer offload: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved"
         )

+    @torch.no_grad()
     def offload_after_refit(self):
         # Offload as much as possible on the CPU
         self.model = self.move_to_cpu(self.model)
         self.model.eval()
+        torch.randn(1).cuda()  # wake up torch allocator
         self.offload_before_refit()  # rerun the old offload function

         if self._held_reference_model_params is not None:

tests/unit/models/generation/test_vllm_generation.py

Lines changed: 8 additions & 0 deletions
@@ -231,10 +231,16 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
     # Create both policies
     print("Creating vLLM policy...")
     vllm_policy = VllmGeneration(cluster, vllm_config)
+    vllm_policy.finish_generation()

     print("Creating HF policy...")
     hf_policy = HfPolicy(cluster, hf_config)

+    print(f"refitting vllm policy...")
+    ipc_handles = hf_policy.get_weights_ipc_handles()
+    vllm_policy.prepare_for_generation()
+    vllm_policy.update_weights(ipc_handles)
+
     # Step 1: Use vLLM for generation
     print("Using vLLM policy for fast generation...")
     generation_results = vllm_policy.generate(test_input_data)
@@ -262,6 +268,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
         }
     )
     # Get logprobs from HF policy
+    hf_policy.prepare_for_lp_inference()
     fprop_results = hf_policy.get_logprobs(fprop_logprob_data)
     # Zero out logprobs for input tokens

@@ -327,6 +334,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
     print(f"Training loss: {results['loss']}")

     hf_policy.finish_training()
+    hf_policy.offload_after_refit()

     # Step 4: Use vLLM for generation again to complete the workflow
     print("Using vLLM for generation again...")
