25 | 25 |
26 | 26 | # Define basic vLLM test config |
27 | 27 | basic_vllm_test_config: VllmConfig = { |
| 28 | +    "backend": "vllm",
28 | 29 |     "model_name": "meta-llama/Llama-3.2-1B",  # Small model for testing
29 | 30 |     "dtype": "bfloat16",
30 | 31 |     "max_new_tokens": 10,
|
39 | 40 | } |
40 | 41 |
41 | 42 |
| 43 | +def configure_vllm_with_tokenizer(vllm_config, tokenizer):
| 44 | +    """Apply tokenizer-specific configuration to a vLLM config."""
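|    | +    # Skip vLLM's tokenizer init (prompts are pre-tokenized with the HF tokenizer)
|    | +    # and load dummy random weights instead of a real checkpoint to keep tests fast.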
| 45 | +    vllm_config["vllm_cfg"]["skip_tokenizer_init"] = True
| 46 | +    vllm_config["vllm_cfg"]["load_format"] = "dummy"
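|    | +    # Forward the tokenizer's special tokens so generation pads and stops correctly.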
| 47 | +    vllm_config["pad_token"] = tokenizer.pad_token_id
| 48 | +    vllm_config["stop_token_ids"] = [tokenizer.eos_token_id]
| 49 | +    return vllm_config
| 50 | +
| 51 | +
42 | 52 | @pytest.fixture(scope="module") |
43 | 53 | def check_vllm_available(): |
44 | 54 | """Skip tests if vLLM is not installed.""" |
@@ -74,9 +84,12 @@ def tokenizer(): |
74 | 84 |
75 | 85 |
76 | 86 | @pytest.fixture(scope="function") |
77 | | -def policy(cluster, check_vllm_available): |
| 87 | +def policy(cluster, tokenizer, check_vllm_available): |
78 | 88 |     """Initialize the vLLM policy."""
79 |    | -    policy = VllmGeneration(cluster, basic_vllm_test_config)
| 89 | +    # Create separate configs for each policy
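|    | +    # NOTE: dict.copy() is shallow, so the nested vllm_cfg dict is still shared
|    | +    # with basic_vllm_test_config and is mutated in place by the helper below.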
| 90 | +    vllm_config = basic_vllm_test_config.copy()
| 91 | +    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)
| 92 | +    policy = VllmGeneration(cluster, vllm_config)
80 | 93 |     yield policy
81 | 94 |
82 | 95 |     # Ensure the policy is properly shut down
@@ -121,6 +134,30 @@ def test_input_data(tokenizer): |
121 | 134 |     )
122 | 135 |
123 | 136 |
| 137 | +def test_vllm_missing_required_config_key(cluster, check_vllm_available): |
| 138 | +    """Test that an assertion error is raised when a required config key is missing."""
| 139 | +    # Create a config missing a required key by removing 'model_name'
| 140 | +    incomplete_config = basic_vllm_test_config.copy()
| 141 | +    del incomplete_config["model_name"]  # Remove a required key
| 142 | +
| 143 | +    # Ensure skip_tokenizer_init and load_format are present, since they are
| 144 | +    # checked against VllmConfig.__annotations__
| 145 | +    incomplete_config["skip_tokenizer_init"] = True
| 146 | +    incomplete_config["load_format"] = "auto"
| 147 | +
| 148 | +    # Initializing VllmGeneration with the incomplete config should raise an AssertionError
| 149 | +    with pytest.raises(AssertionError) as excinfo:
| 150 | +        VllmGeneration(cluster, incomplete_config)
| 151 | +
| 152 | +    # Verify the error message contains information about the missing key
| 153 | +    error_message = str(excinfo.value)
| 154 | +    assert "Missing required keys in VllmConfig" in error_message
| 155 | +    assert "model_name" in error_message, (
| 156 | +        "Error should mention the missing 'model_name' key"
| 157 | +    )
| 158 | +    print(f"Successfully caught missing config key with error: {error_message}")
| 159 | + |
| 160 | + |
124 | 161 | def test_vllm_policy_generation(policy, test_input_data, tokenizer): |
125 | 162 |     """Test vLLM policy generation capabilities."""
126 | 163 |     # Test generation
@@ -171,6 +208,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): |
|
172 | 209 |     # Create separate configs for each policy
173 | 210 |     vllm_config = basic_vllm_test_config.copy()
| 211 | +    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)
174 | 212 |
175 | 213 |     # Create HF-specific config with required parameters
176 | 214 |     hf_config = {
@@ -359,6 +397,7 @@ def test_vllm_policy_tensor_parallel(cluster, tokenizer): |
359 | 397 |     """Test vLLM policy with tensor parallelism > 1."""
360 | 398 |     # Configure with tensor_parallel_size=2
361 | 399 |     tp_config = basic_vllm_test_config.copy()
| 400 | +    tp_config = configure_vllm_with_tokenizer(tp_config, tokenizer)
362 | 401 |     tp_config["tensor_parallel_size"] = 2
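|    | +    # tensor_parallel_size=2 shards the model across two GPUs, so this test
|    | +    # requires at least two visible devices on the cluster.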
363 | 402 |
364 | 403 |     # Ensure we specify the distributed executor backend
@@ -420,6 +459,7 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size): |
420 | 459 |
421 | 460 |     # Create separate configs for each policy
422 | 461 |     vllm_config = basic_vllm_test_config.copy()
| 462 | +    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)
423 | 463 |     vllm_config["tensor_parallel_size"] = tensor_parallel_size
424 | 464 |
425 | 465 |     # Add vllm_kwargs only if using tensor parallelism