Skip to content

Commit 84f92bf

Browse files
add test for distillation
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
1 parent 67cd851 commit 84f92bf

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

tests/unit/algorithms/test_distillation.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,115 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node():
415415
setup(master_config, tokenizer, dataset, None)
416416

417417

418+
def test_distillation_setup_non_colocated_smoke(monkeypatch):
    """Smoke test: calling setup with a non-colocated config should succeed."""
    from unittest.mock import MagicMock, patch

    import nemo_rl.algorithms.distillation as distil_mod

    # Minimal non-colocated config: generation gets its own resources
    # (8 GPUs on one node) while the training cluster declares 2 nodes.
    master_config = {
        "policy": {
            "generation": {
                "backend": "vllm",
                "colocated": {
                    "enabled": False,
                    "resources": {
                        # Dedicated inference node with all 8 GPUs.
                        "gpus_per_node": 8,
                        "num_nodes": 1,
                    },
                },
            },
            "dtensor_cfg": {"enabled": False},
            "model_name": "test-policy",
        },
        "teacher": {
            "model_name": "test-teacher",
            "dtensor_cfg": {"enabled": False},
        },
        "loss_fn": {
            "kl_type": "forward",
            "mixed_kl_weight": 0.5,
            "zero_outside_topk": False,
        },
        "distillation": {
            "seed": 42,
            "topk_logits_k": 64,
            "num_prompts_per_step": 1,
            "val_period": 0,
            "val_at_start": False,
        },
        "data": {"shuffle": False},
        "logger": {},
        "checkpointing": {},
        "cluster": {"num_nodes": 2, "gpus_per_node": 8},
    }

    fake_tokenizer = MagicMock()
    fake_dataset = MagicMock()
    fake_dataset.__len__ = MagicMock(return_value=1)

    # Bypass the tokenizer/vocab equality check performed inside setup().
    monkeypatch.setenv("NRL_SKIP_DISTILLATION_TOKENIZER_CHECK", "1")

    master_addr_port = ("127.0.0.1", 12345)

    # Lightweight stand-ins for the heavy Ray-backed components; each
    # implements only the methods setup() is expected to invoke.
    class DummyCluster:
        def __init__(self, *args, **kwargs):
            pass

        def world_size(self):
            return 1

        def get_master_address_and_port(self):
            return master_addr_port

    class DummyPolicy:
        def __init__(self, *args, **kwargs):
            pass

        def prepare_refit_info(self):
            return {}

        def init_collective(self, *args, **kwargs):
            return [MagicMock()]

    class DummyVllmGeneration:
        def __init__(self, *args, **kwargs):
            pass

        def finish_generation(self):
            return None

        def prepare_refit_info(self, *args, **kwargs):
            return None

        def init_collective(self, *args, **kwargs):
            return [MagicMock()]

    with (
        patch.object(distil_mod, "RayVirtualCluster", DummyCluster),
        patch.object(distil_mod, "Logger"),
        patch.object(distil_mod, "CheckpointManager") as ckpt_mgr_mock,
        patch.object(distil_mod, "StatefulDataLoader"),
        patch.object(distil_mod, "Policy", DummyPolicy),
        patch.object(distil_mod, "VllmGeneration", DummyVllmGeneration),
        patch.object(distil_mod, "ray") as ray_mock,
    ):
        # No checkpoint to resume from; ray.get resolves to nothing.
        ckpt_mgr_mock.return_value.get_latest_checkpoint_path.return_value = None
        ray_mock.get = MagicMock(return_value=None)

        # The call under test — must complete without raising.
        setup_result = distil_mod.setup(
            master_config, fake_tokenizer, fake_dataset, None
        )

        # Basic shape check of the returned value.
        assert isinstance(setup_result, tuple)
526+
418527
def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node():
419528
"""Test that non-colocated inference requires explicit gpus_per_node when cluster.num_nodes>1."""
420529
from unittest.mock import MagicMock, patch

0 commit comments

Comments
 (0)