diff --git a/tests/integration_tests/test_vllm_policy_correctness.py b/tests/integration_tests/test_vllm_policy_correctness.py index b512591ba..e2da9b068 100644 --- a/tests/integration_tests/test_vllm_policy_correctness.py +++ b/tests/integration_tests/test_vllm_policy_correctness.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import asyncio +import pytest from forge.actors.policy import Policy from vllm import SamplingParams @@ -28,6 +28,7 @@ N_SAMPLES = 1 +@pytest.mark.asyncio async def test_same_output(): """Compare outputs between vLLM and Policy service""" test_prompts = [ @@ -96,15 +97,14 @@ async def test_same_output(): for vllm_output, policy_output in zip(vllm_outputs, policy_outputs): assert vllm_output != "" assert policy_output != "" - if vllm_output != policy_output: - print(f"āŒ Got different results: {vllm_output} vs. {policy_output}") - print("āœ… Outputs are the same!") + assert vllm_output == policy_output finally: if policy is not None: await policy.shutdown() +@pytest.mark.asyncio async def test_cache_usage(): """Test that KV cache usage is consistent between vLLM and Policy service. @@ -232,16 +232,8 @@ async def test_cache_usage(): for vllm_output, policy_output in zip(vllm_outputs, policy_outputs): assert vllm_output != "" assert policy_output != "" - if vllm_output != policy_output: - print(f"āŒ Got different results: {vllm_output} vs. {policy_output}") - - print("\nāœ… Prefix cache usage is the same!") + assert vllm_output == policy_output finally: if policy is not None: await policy.shutdown() - - -if __name__ == "__main__": - asyncio.run(test_same_output()) - asyncio.run(test_cache_usage())