@@ -623,29 +623,16 @@ def configure_worker_fixed_seed(num_gpus, bundle_indices=None):
    torch.cuda.empty_cache()


-@pytest.mark.timeout(360)
-@pytest.mark.asyncio
-@pytest.mark.parametrize("async_engine", [True, False])
-async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine):
-    """1. Use vLLM for generation
-    2. Use HF policy for training and logprob computation
+async def run_hf_train_process(
+    lm_policy, vllm_policy, tokenizer, async_engine, colocated
+):
+    """Validates that the two policies can work together.

-    This test validates that the two policies can work together.
+    1. Use vLLM for generation
+    2. Use HF policy for training and logprob computation
    """
-    from nemo_rl.models.policy.lm_policy import Policy
    from tests.unit.test_utils import SimpleNLLLoss

-    # Create separate configs for each policy
-    vllm_config = deepcopy(basic_vllm_test_config)
-    vllm_config["vllm_cfg"]["async_engine"] = async_engine
-    vllm_config = configure_generation_config(vllm_config, tokenizer)
-
-    dtensor_config = deepcopy(basic_dtensor_test_config)
-    dtensor_config["train_global_batch_size"] = 4
-
-    vllm_policy = None
-    lm_policy = None
-
    try:
        prompts = [
            "Write a story about a magical forest",
@@ -677,22 +664,8 @@ async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine
            }
        )

-        # Create both policies
-        print("Creating vLLM policy...")
-        vllm_policy = VllmGeneration(cluster, vllm_config)
-        vllm_policy.finish_generation()
-
-        print("Creating DTensor policy...")
-        lm_policy = Policy(cluster, dtensor_config, tokenizer)
-
-        print("preparing refit info...")
-        state_dict_info = lm_policy.prepare_refit_info()
-        vllm_policy.prepare_refit_info(state_dict_info)
-
        print("refitting vllm policy...")
-        refit_policy_generation(
-            lm_policy, vllm_policy, vllm_config["colocated"]["enabled"]
-        )
+        refit_policy_generation(lm_policy, vllm_policy, colocated)

        # Step 1: Use vLLM for generation
        print("Using vLLM policy for fast generation...")
@@ -794,7 +767,7 @@ async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine
        print(f"Training loss: {results['loss']}")

        lm_policy.finish_training()
-        lm_policy.offload_after_refit()
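+        # Refit again after training (replacing the old offload-only step) so that
+        # the follow-up generation below runs with the updated weights.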
+        refit_policy_generation(lm_policy, vllm_policy, colocated)

        # Step 4: Use vLLM for generation again to complete the workflow
        print("Using vLLM for generation again...")
@@ -821,6 +794,82 @@ async def test_vllm_generation_with_hf_training(cluster, tokenizer, async_engine
            lm_policy.shutdown()


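+# The (async_engine, cpu_offload) pairs below exercise the async engine without
+# CPU offload and the sync engine with CPU offload, not the full cross product.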
+@pytest.mark.timeout(300)
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("async_engine", "cpu_offload"), [(True, False), (False, True)]
+)
+async def test_vllm_generation_with_hf_training_colocated(
+    cluster, tokenizer, async_engine, cpu_offload
+):
+    """This test validates that the DTensor policy can work together with a colocated vLLM policy."""
+    # Create VllmGeneration Policy
+    print("Creating vLLM policy...")
+    vllm_config = deepcopy(basic_vllm_test_config)
+    vllm_config["vllm_cfg"]["async_engine"] = async_engine
+    vllm_config = configure_generation_config(vllm_config, tokenizer)
+    vllm_policy = VllmGeneration(cluster, vllm_config)
+    vllm_policy.finish_generation()
+
+    # Create Policy
+    print("Creating DTensor policy...")
+    dtensor_config = deepcopy(basic_dtensor_test_config)
+    dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload
+    dtensor_config["train_global_batch_size"] = 4
+    lm_policy = Policy(cluster, dtensor_config, tokenizer)
+
+    # Prepare refit info
+    print("Preparing refit info...")
+    state_dict_info = lm_policy.prepare_refit_info()
+    vllm_policy.prepare_refit_info(state_dict_info)
+
+    # Test
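+    # The trailing True is the colocated flag consumed by run_hf_train_process.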
+    await run_hf_train_process(lm_policy, vllm_policy, tokenizer, async_engine, True)
+
+
+@pytest.mark.timeout(300)
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("async_engine", "cpu_offload"), [(True, False), (False, True)]
+)
+async def test_vllm_generation_with_hf_training_non_colocated(
+    policy_cluster_separate, tokenizer, async_engine, cpu_offload
+):
+    """This test validates that the DTensor policy can work together with a non-colocated vLLM policy."""
+    generation_cluster_separate = get_generation_cluster_separate(1)
+
+    # Create VllmGeneration Policy
+    print("Creating vLLM policy...")
+    vllm_config = deepcopy(basic_vllm_test_config)
+    vllm_config["vllm_cfg"]["async_engine"] = async_engine
+    vllm_config["colocated"]["enabled"] = False
+    vllm_config = configure_generation_config(vllm_config, tokenizer)
+    vllm_policy = VllmGeneration(generation_cluster_separate, vllm_config)
+    vllm_policy.finish_generation()
+
+    # Create Policy
+    print("Creating DTensor policy...")
+    dtensor_config = deepcopy(basic_dtensor_test_config)
+    dtensor_config["generation"]["colocated"]["enabled"] = False
+    dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload
+    dtensor_config["train_global_batch_size"] = 4
+    lm_policy = Policy(policy_cluster_separate, dtensor_config, tokenizer)
+
+    # Refit
+    # initialize collective communication for update weights
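+    # Both the training policy and the vLLM policy join a shared collective
+    # (world_size=2 in this setup), which is then used to transfer updated weights.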
+    ip, port = policy_cluster_separate.get_master_address_and_port()
+    futures_train = lm_policy.init_collective(ip, port, world_size=2)
+    futures_inference = vllm_policy.init_collective(ip, port, world_size=2)
+    ray.get(futures_train + futures_inference)
+
+    # prepare refit info
+    state_dict_info = lm_policy.prepare_refit_info()
+    vllm_policy.prepare_refit_info(state_dict_info)
+
+    # Test
+    await run_hf_train_process(lm_policy, vllm_policy, tokenizer, async_engine, False)
+
+
def test_vllm_policy_tensor_parallel(cluster, tokenizer):
    """Test vLLM policy with tensor parallelism > 1."""
    # Configure with tensor_parallel_size=2