|
39 | 39 | }, |
40 | 40 | "dtype": "bfloat16", |
41 | 41 | "max_new_tokens": 5, # Small number of tokens for testing |
42 | | - "temperature": 0.8, |
| 42 | + # Set temperature=1.0 to ensure consistent probability scaling when comparing vLLM and HF policy outputs. |
| 43 | + # Note: greedy=True is used only in tests, for deterministic behavior; it is not used in real training. |
| 44 | + # In vLLM, enabling greedy=True disables temperature scaling (temperature is overridden to None). |
| 45 | + # The HF policy worker does not currently support greedy=True for get_logprobs. |
| 46 | + # Using temperature=1.0 allows us to meaningfully test the average probability multiplicative error between the two implementations, |
| 47 | + # while still maintaining the deterministic behavior. |
| 48 | + "temperature": 1.0, |
43 | 49 | "top_p": 1.0, |
44 | 50 | "top_k": None, |
45 | 51 | "stop_token_ids": None, |
@@ -326,6 +332,43 @@ def test_vllm_missing_required_config_key(cluster): |
326 | 332 | print(f"Successfully caught missing config key with error: {error_message}") |
327 | 333 |
|
328 | 334 |
|
def test_vllm_top_p_top_k_validation(cluster):
    """Check the threshold-based validation of ``top_p`` and ``top_k``.

    Three configurations derived from ``basic_vllm_test_config`` are tried:

    * ``top_p=0.99`` together with ``top_k=8000`` must be accepted. Going by
      the error messages asserted below ("values < 0.99" / "values < 8000"),
      these are the boundary values themselves, i.e. the smallest settings
      that are still allowed.
    * ``top_p=0.9`` must be rejected with a ``ValueError``.
    * ``top_k=7999`` must be rejected with a ``ValueError``.

    Args:
        cluster: Cluster fixture that is passed straight to ``VllmGeneration``.
    """
    # Boundary case: values exactly at the thresholds must be accepted.
    config_at_thresholds = deepcopy(basic_vllm_test_config)
    config_at_thresholds["top_p"] = 0.99  # Exactly at the top_p threshold
    config_at_thresholds["top_k"] = 8000  # Exactly at the top_k threshold

    # No try/except wrapper is needed: any exception raised here already fails
    # the test, and pytest reports it with the original traceback.
    # NOTE(review): the created instance is never released in this test --
    # confirm whether VllmGeneration needs an explicit shutdown to free workers.
    VllmGeneration(cluster, config_at_thresholds)
    print("Successfully initialized with top_p=0.99 and top_k=8000")

    # A top_p below the threshold must be rejected.
    config_below_thresholds = deepcopy(basic_vllm_test_config)
    config_below_thresholds["top_p"] = 0.9  # Below the top_p threshold

    with pytest.raises(ValueError) as excinfo:
        VllmGeneration(cluster, config_below_thresholds)

    error_message = str(excinfo.value)
    assert "top_p sampling with values < 0.99 is not supported" in error_message
    print(f"Successfully caught low top_p value with error: {error_message}")

    # A top_k just below the threshold must be rejected as well.
    config_low_top_k = deepcopy(basic_vllm_test_config)
    config_low_top_k["top_k"] = 7999  # One below the top_k threshold

    with pytest.raises(ValueError) as excinfo:
        VllmGeneration(cluster, config_low_top_k)

    error_message = str(excinfo.value)
    assert "top_k sampling with values < 8000 is not supported" in error_message
    print(f"Successfully caught low top_k value with error: {error_message}")
| 371 | + |
329 | 372 | def test_vllm_policy_generation(policy, test_input_data, tokenizer): |
330 | 373 | """Test vLLM policy generation capabilities.""" |
331 | 374 | # Test generation |
|
0 commit comments