
Commit e36f488

fix: Fix missing import (#222)
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
1 parent 98b7a90 · commit e36f488

2 files changed: 66 additions & 18 deletions

nemo_reinforcer/models/policy/fsdp1_policy_worker.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
 import torch
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import (
+    CPUOffload,
     FullyShardedDataParallel,
     MixedPrecision,
 )
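
For context, CPUOffload is the torch.distributed.fsdp flag that keeps sharded parameters in host memory between uses, which is what the previously missing import provides. A minimal sketch of how it is typically passed when wrapping a module (a generic model launched under torchrun with NCCL is assumed here, not the worker's actual wrapping code):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel

# Assumes launch via torchrun, which sets the env vars init_process_group reads.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# offload_params=True trades step time for GPU memory headroom by keeping
# sharded parameters (and their gradients) on the CPU when not in use.
model = nn.Linear(4096, 4096).cuda()
fsdp_model = FullyShardedDataParallel(
    model,
    cpu_offload=CPUOffload(offload_params=True),
)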

tests/unit/models/policy/test_fsdp1_worker.py

Lines changed: 65 additions & 18 deletions
@@ -261,18 +261,32 @@ def test_hf_policy_init(policy_setup, num_gpus):
 
 
 @pytest.fixture
-def training_setup(tokenizer, num_gpus):
-    """Setup and teardown specifically for training tests."""
+def training_setup(tokenizer, request, num_gpus):
+    """
+    Setup and teardown specifically for training tests.
+
+    When used without parameterization, uses the default config.
+    When parameterized, takes any config updates as a dictionary in request.param
+    and applies them to the basic config.
+    """
     policy = None
     cluster = None
     data = None
     loss_fn = None
 
+    # Get config updates from request.param if available
+    config_updates = {}
+    config_suffix = ""
+    if hasattr(request, "param") and request.param is not None:
+        config_updates = request.param
+        config_suffix = "-" + "-".join([f"{k}={v}" for k, v in config_updates.items()])
+
     try:
         # Create resources with unique name
-        cluster_name = f"test-train-{num_gpus}gpu"
+        cluster_name = f"test-train-{num_gpus}gpu{config_suffix}"
         print(
-            f"Creating training virtual cluster '{cluster_name}' for {num_gpus} GPUs..."
+            f"Creating training virtual cluster '{cluster_name}' for {num_gpus} GPUs"
+            f"{' with config updates: ' + str(config_updates) if config_updates else ''}"
         )
 
         cluster = RayVirtualCluster(
@@ -283,7 +297,10 @@ def training_setup(tokenizer, num_gpus):
            max_colocated_worker_groups=1,
        )
 
-        config = basic_llama_test_config
+        # Create a config with optional modifications
+        config = deepcopy(basic_llama_test_config)
+        if config_updates:
+            config.update(config_updates)
 
        print("Creating training HfPolicy...")
        policy = HfPolicy(
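
One reading of the deepcopy change above: calling update on the shared basic_llama_test_config dict in place would leak one test's overrides into later tests, so the fixture copies it first and leaves the module-level default untouched. A small illustration (the dict here is a stand-in, not the real config):

from copy import deepcopy

base = {"fsdp_offload_enabled": False, "activation_checkpointing_enabled": False}

cfg = deepcopy(base)  # independent copy; nested values are copied as well
cfg.update({"fsdp_offload_enabled": True})

assert base["fsdp_offload_enabled"] is False  # the shared default is unchanged
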
@@ -341,8 +358,23 @@ def get_max_gpu_utilization(policy):
 
 
 @pytest.mark.timeout(180)
-@pytest.mark.parametrize("num_gpus", [1, 2], ids=["1gpu", "2gpu"])
-def test_hf_policy_training(training_setup, tracker, num_gpus):
+@pytest.mark.parametrize(
+    "num_gpus, training_setup, config_name",
+    [
+        (1, None, "default"),
+        (2, None, "default"),
+        (2, {"fsdp_offload_enabled": True}, "fsdp_offload"),
+        (2, {"activation_checkpointing_enabled": True}, "activation_checkpointing"),
+    ],
+    indirect=["training_setup"],
+    ids=[
+        "1gpu_default",
+        "2gpu_default",
+        "2gpu_fsdp_offload",
+        "2gpu_activation_checkpointing",
+    ],
+)
+def test_hf_policy_training(training_setup, tracker, num_gpus, config_name):
     def verify_loss_tensor(loss_tensor):
         assert not torch.isnan(loss_tensor).any(), "Loss should not be NaN"
         assert not torch.isinf(loss_tensor).any(), "Loss should not be Inf"
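
As a quick aside on the pytest mechanism used here, indirect=["training_setup"] routes that column of the parametrize table into the fixture's request.param instead of handing it straight to the test function. A minimal, self-contained sketch of the same pattern (all names below are illustrative, not from the repository):

import pytest

@pytest.fixture
def setup(request):
    # With indirect parametrization, the value from the parametrize list
    # arrives as request.param; fall back to an empty override dict otherwise.
    overrides = getattr(request, "param", None) or {}
    return {"lr": 1e-4, **overrides}

@pytest.mark.parametrize(
    "setup, name",
    [(None, "default"), ({"lr": 1e-5}, "low_lr")],
    indirect=["setup"],
    ids=["default", "low_lr"],
)
def test_config(setup, name):
    assert "lr" in setup
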
@@ -357,7 +389,9 @@ def verify_loss_tensor(loss_tensor):
     assert loss_fn is not None, "Loss function was not created properly"
 
     # Call prepare_for_training if available
-    print("\nPreparing for training...")
+    print(
+        f"\nPreparing for training with {num_gpus} GPU(s) and {config_name} config..."
+    )
     policy.prepare_for_training()
 
     losses = []
@@ -370,7 +404,9 @@ def verify_loss_tensor(loss_tensor):
         verify_loss_tensor(loss_tensor)
         losses.append(loss_tensor[-1].item())
 
-        print(f"Training loss: {results['loss']}")
+        print(
+            f"Training loss with {num_gpus} GPU(s) and {config_name} config: {results['loss']}"
+        )
 
     policy.finish_training()
     assert losses[0] > losses[-1], "Loss should decrease over training iterations"
@@ -379,35 +415,46 @@ def verify_loss_tensor(loss_tensor):
         policy
     )
     print(
-        f"Max GPU Utilization after training: {after_training_mem_allocated:,.1f} MB allocated, "
+        f"Max GPU Utilization after training with {num_gpus} GPU(s) and {config_name} config: {after_training_mem_allocated:,.1f} MB allocated, "
         f"{after_training_mem_reserved:,.1f} MB reserved"
     )
     tracker.track(
-        f"after_training_mem_allocated_{num_gpus}gpu", after_training_mem_allocated
+        f"{num_gpus}gpu_{config_name}_after_training_mem_allocated",
+        after_training_mem_allocated,
     )
     tracker.track(
-        f"after_training_mem_reserved_{num_gpus}gpu", after_training_mem_reserved
+        f"{num_gpus}gpu_{config_name}_after_training_mem_reserved",
+        after_training_mem_reserved,
     )
 
     policy.offload_after_refit()
     after_offload_mem_allocated, after_offload_mem_reserved = get_max_gpu_utilization(
         policy
     )
     print(
-        f"Max GPU Utilization after offload: {after_offload_mem_allocated:,.1f} MB allocated, "
+        f"Max GPU Utilization after offload with {num_gpus} GPU(s) and {config_name} config: {after_offload_mem_allocated:,.1f} MB allocated, "
         f"{after_offload_mem_reserved:,.1f} MB reserved"
     )
     tracker.track(
-        f"after_offload_mem_allocated_{num_gpus}gpu", after_offload_mem_allocated
+        f"{num_gpus}gpu_{config_name}_after_offload_mem_allocated",
+        after_offload_mem_allocated,
     )
     tracker.track(
-        f"after_offload_mem_reserved_{num_gpus}gpu", after_offload_mem_reserved
+        f"{num_gpus}gpu_{config_name}_after_offload_mem_reserved",
+        after_offload_mem_reserved,
     )
 
     # Compare memory after offload to memory after training
-    assert after_training_mem_allocated > 10_000, (
-        "Memory after training should be more than 10GB"
-    )
+    if config_name == "fsdp_offload":
+        # With FSDP offload, memory usage after training should already be low
+        assert after_training_mem_allocated < 1_200, (
+            "FSDP offload after training should be less than 1.2GB"
+        )
+    else:
+        assert after_training_mem_allocated > 10_000, (
+            f"Memory after training with {config_name} config should be more than 10GB"
+        )
+
     assert after_offload_mem_allocated < 1_200, (
         "Memory after offload should be less than 1.2GB"
     )
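
For reference, peak-memory figures like the MB values asserted above are typically read from torch.cuda's peak-memory counters. A hedged sketch of such a helper (this is not the repository's get_max_gpu_utilization, which gathers the numbers from the policy's Ray workers):

import torch

def max_gpu_utilization_mb():
    # Peak allocated/reserved CUDA memory for this process, in MB.
    # Call torch.cuda.reset_peak_memory_stats() beforehand to scope the measurement.
    allocated = torch.cuda.max_memory_allocated() / 1024**2
    reserved = torch.cuda.max_memory_reserved() / 1024**2
    return allocated, reserved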
