@@ -415,6 +415,115 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node():
415415 setup (master_config , tokenizer , dataset , None )
416416
417417
def test_distillation_setup_non_colocated_smoke(monkeypatch):
    """Smoke test: calling setup with a non-colocated config should succeed.

    Builds a two-node cluster config whose generation backend is
    non-colocated (inference gets its own node with all 8 GPUs), replaces
    every heavyweight collaborator (Ray cluster, policy workers, vLLM
    engine, checkpointing, dataloader) with lightweight stand-ins, and
    verifies that ``setup`` wires everything together without raising.
    """
    from unittest.mock import MagicMock, patch

    import nemo_rl.algorithms.distillation as distil_mod

    # Two-node cluster; inference runs non-colocated on one node,
    # using all 8 GPUs of that node.
    master_config = {
        "policy": {
            "generation": {
                "backend": "vllm",
                "colocated": {
                    "enabled": False,
                    "resources": {
                        "gpus_per_node": 8,  # inference on 8 GPUs
                        "num_nodes": 1,
                    },
                },
            },
            "dtensor_cfg": {
                "enabled": False,
            },
            "model_name": "test-policy",
        },
        "teacher": {
            "model_name": "test-teacher",
            "dtensor_cfg": {
                "enabled": False,
            },
        },
        "loss_fn": {
            "kl_type": "forward",
            "mixed_kl_weight": 0.5,
            "zero_outside_topk": False,
        },
        "distillation": {
            "seed": 42,
            "topk_logits_k": 64,
            "num_prompts_per_step": 1,
            "val_period": 0,
            "val_at_start": False,
        },
        "data": {"shuffle": False},
        "logger": {},
        "checkpointing": {},
        "cluster": {"num_nodes": 2, "gpus_per_node": 8},
    }

    tokenizer = MagicMock()
    dataset = MagicMock()
    dataset.__len__ = MagicMock(return_value=1)

    # Skip tokenizer/vocab equality check inside setup
    monkeypatch.setenv("NRL_SKIP_DISTILLATION_TOKENIZER_CHECK", "1")

    ip_port = ("127.0.0.1", 12345)

    class DummyCluster:
        # Stand-in for RayVirtualCluster: no Ray processes are launched.
        def __init__(self, *args, **kwargs):
            pass

        def world_size(self):
            return 1

        def get_master_address_and_port(self):
            return ip_port

    class DummyPolicy:
        # Stand-in for Policy: no training workers are created.
        def __init__(self, *args, **kwargs):
            pass

        def prepare_refit_info(self):
            return {}

        def init_collective(self, *args, **kwargs):
            return [MagicMock()]

    class DummyVllmGeneration:
        # Stand-in for VllmGeneration: no vLLM engine is started.
        def __init__(self, *args, **kwargs):
            pass

        def finish_generation(self):
            return None

        def prepare_refit_info(self, *args, **kwargs):
            return None

        def init_collective(self, *args, **kwargs):
            return [MagicMock()]

    with (
        patch.object(distil_mod, "RayVirtualCluster", DummyCluster),
        patch.object(distil_mod, "Logger"),
        patch.object(distil_mod, "CheckpointManager") as mock_ckpt_mgr,
        patch.object(distil_mod, "StatefulDataLoader"),
        patch.object(distil_mod, "Policy", DummyPolicy),
        patch.object(distil_mod, "VllmGeneration", DummyVllmGeneration),
        patch.object(distil_mod, "ray") as mock_ray,
    ):
        # No prior checkpoint: setup must take the fresh-start path.
        mock_ckpt_mgr.return_value.get_latest_checkpoint_path.return_value = None
        mock_ray.get = MagicMock(return_value=None)

        # Should not raise
        result = distil_mod.setup(master_config, tokenizer, dataset, None)

        # Basic shape check of returned tuple
        assert isinstance(result, tuple)
526+
418527def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node ():
419528 """Test that non-colocated inference requires explicit gpus_per_node when cluster.num_nodes>1."""
420529 from unittest .mock import MagicMock , patch
0 commit comments