@@ -915,9 +915,8 @@ def test_vllm_generate_text(cluster, tokenizer):
 
 @pytest.mark.timeout(180)
 @pytest.mark.parametrize("tensor_parallel_size", [1, 2])
-@pytest.mark.parametrize("enable_dtensor", [True, False])
 def test_vllm_weight_update_and_prefix_cache_reset(
-    cluster, tokenizer, tensor_parallel_size, enable_dtensor
+    cluster, tokenizer, tensor_parallel_size
 ):
     """Test that the vLLM prefix cache is correctly reset when weights change."""
     from nemo_rl.models.policy.lm_policy import Policy
@@ -1021,8 +1020,7 @@ def test_vllm_weight_update_and_prefix_cache_reset(
     torch.cuda.empty_cache()
 
 
-@pytest.mark.parametrize("enable_dtensor", [True, False])
-def test_vllm_weight_update_memory(cluster, tokenizer, enable_dtensor):
+def test_vllm_weight_update_memory(cluster, tokenizer):
     """Test that vLLM streaming weight update and can save memory."""
     from nemo_rl.models.policy.lm_policy import Policy
 
@@ -1081,23 +1079,16 @@ def test_vllm_weight_update_memory(cluster, tokenizer, enable_dtensor):
     assert current_reserved == 0.0, "Memory should be 0 after refit completed"
     # memory threshold: memory during non-streaming weight update on 0.6B model on 2 GPUs
     # memory during streaming weight update should less than this baseline threshold
-    if enable_dtensor:
-        assert peak_allocated < 4005, "Peak allocated memory should < 4005 MB"
-        assert peak_reserved < 4016, "Peak reserved memory should < 4016 MB"
-    else:
-        assert peak_allocated < 5736, "Peak allocated memory should < 5736 MB"
-        assert peak_reserved < 5748, "Peak reserved memory should < 5748 MB"
+    assert peak_allocated < 4005, "Peak allocated memory should < 4005 MB"
+    assert peak_reserved < 4016, "Peak reserved memory should < 4016 MB"
 
     # Clean up
     vllm_policy.shutdown()
     lm_policy.shutdown()
 
 
 @pytest.mark.parametrize("is_eval", [True, False])
-@pytest.mark.parametrize("enable_dtensor", [True, False])
-def test_vllm_generation_with_stop(
-    cluster, test_input_data, tokenizer, is_eval, enable_dtensor
-):
+def test_vllm_generation_with_stop(cluster, test_input_data, tokenizer, is_eval):
     """Test vLLM generation with stop."""
     from nemo_rl.models.policy.lm_policy import Policy
 
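For context on what this diff changes, here is a minimal sketch (not from the diff; the test name and body are hypothetical) of how stacked `pytest.mark.parametrize` decorators combine: they take the cross-product of their value lists, so removing the two-value `enable_dtensor` axis halves the number of collected test cases for each of the tests above.

```python
import pytest


@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
@pytest.mark.parametrize("enable_dtensor", [True, False])  # the axis this diff removes
def test_case_matrix(tensor_parallel_size, enable_dtensor):
    # pytest collects the cross-product of both decorators: 2 x 2 = 4 cases,
    # (1, True), (1, False), (2, True), (2, False). Dropping the
    # enable_dtensor decorator would leave just the 2 tensor_parallel_size cases.
    assert tensor_parallel_size in (1, 2)
    assert enable_dtensor in (True, False)
```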