|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -import asyncio |
| 7 | +import pytest |
8 | 8 |
|
9 | 9 | from forge.actors.policy import Policy |
10 | 10 | from vllm import SamplingParams |
|
28 | 28 | N_SAMPLES = 1 |
29 | 29 |
|
30 | 30 |
|
| 31 | +@pytest.mark.asyncio |
31 | 32 | async def test_same_output(): |
32 | 33 | """Compare outputs between vLLM and Policy service""" |
33 | 34 | test_prompts = [ |
@@ -96,15 +97,14 @@ async def test_same_output(): |
96 | 97 | for vllm_output, policy_output in zip(vllm_outputs, policy_outputs): |
97 | 98 | assert vllm_output != "" |
98 | 99 | assert policy_output != "" |
99 | | - if vllm_output != policy_output: |
100 | | - print(f"❌ Got different results: {vllm_output} vs. {policy_output}") |
101 | | - print("✅ Outputs are the same!") |
| 100 | + assert vllm_output == policy_output |
102 | 101 |
|
103 | 102 | finally: |
104 | 103 | if policy is not None: |
105 | 104 | await policy.shutdown() |
106 | 105 |
|
107 | 106 |
|
| 107 | +@pytest.mark.asyncio |
108 | 108 | async def test_cache_usage(): |
109 | 109 | """Test that KV cache usage is consistent between vLLM and Policy service. |
110 | 110 |
|
@@ -232,16 +232,8 @@ async def test_cache_usage(): |
232 | 232 | for vllm_output, policy_output in zip(vllm_outputs, policy_outputs): |
233 | 233 | assert vllm_output != "" |
234 | 234 | assert policy_output != "" |
235 | | - if vllm_output != policy_output: |
236 | | - print(f"❌ Got different results: {vllm_output} vs. {policy_output}") |
237 | | - |
238 | | - print("\n✅ Prefix cache usage is the same!") |
| 235 | + assert vllm_output == policy_output |
239 | 236 |
|
240 | 237 | finally: |
241 | 238 | if policy is not None: |
242 | 239 | await policy.shutdown() |
243 | | - |
244 | | - |
245 | | -if __name__ == "__main__": |
246 | | - asyncio.run(test_same_output()) |
247 | | - asyncio.run(test_cache_usage()) |
0 commit comments