Skip to content

Commit 1f5544a

Browse files
committed
Adding set of assumptions from vllm to protect against interface changes
1 parent e8c7bbf commit 1f5544a

File tree

1 file changed

+256
-0
lines changed

1 file changed

+256
-0
lines changed
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
#!/usr/bin/env python3
2+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Assumption tests for KVBM connector's expectations of vLLM interfaces.
7+
8+
These unit tests validate that KVBM's assumptions about vLLM's internal
9+
interfaces remain stable across vLLM releases. They do NOT test functional
10+
correctness of KVBM or vLLM logic, but rather ensure the API contract remains
11+
intact to prevent silent breakage.
12+
13+
Inspired by vLLM's test_lmcache_integration.py approach to interface testing.
14+
"""
15+
16+
from typing import Any
17+
18+
import pytest
19+
20+
# Skip if vLLM is not available
21+
pytest.importorskip("vllm", reason="vLLM not available")
22+
23+
# ruff: noqa: E402
24+
# Imports must be after pytest.importorskip() to handle missing vLLM gracefully
25+
from vllm.config import (
26+
CacheConfig,
27+
KVTransferConfig,
28+
ModelConfig,
29+
ParallelConfig,
30+
VllmConfig,
31+
)
32+
from vllm.lora.request import LoRARequest
33+
from vllm.sampling_params import SamplingParams
34+
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
35+
from vllm.v1.request import Request
36+
37+
# Test markers
38+
pytestmark = [
39+
pytest.mark.kvbm,
40+
pytest.mark.gpu_0,
41+
pytest.mark.nightly,
42+
pytest.mark.pre_merge,
43+
]
44+
45+
46+
def _get_obj_name(obj: Any) -> str:
47+
"""Get a readable name for an object (class name or repr)."""
48+
return getattr(obj, "__name__", None) or obj.__class__.__name__
49+
50+
51+
def _assert_attr_exists(obj: Any, attr: str) -> None:
52+
"""Assert that an attribute exists on an object or dataclass."""
53+
obj_name = _get_obj_name(obj)
54+
# Check __dataclass_fields__ directly - works for both classes and instances,
55+
# and handles decorated dataclasses (e.g., @config @dataclass)
56+
dataclass_fields = getattr(obj, "__dataclass_fields__", None)
57+
if dataclass_fields is not None:
58+
assert attr in dataclass_fields, f"Dataclass {obj_name} missing field '{attr}'"
59+
else:
60+
assert hasattr(obj, attr), f"Object {obj_name} missing attribute '{attr}'"
61+
62+
63+
def _get_property_return_type(prop: property) -> Any:
64+
"""Extract return type from a property's fget annotations."""
65+
fget = prop.fget
66+
if fget is None or not hasattr(fget, "__annotations__"):
67+
return None
68+
annotations = fget.__annotations__
69+
if "return" not in annotations:
70+
return None
71+
return_type = annotations["return"]
72+
# Handle Optional types by extracting the inner type
73+
if hasattr(return_type, "__origin__") and return_type.__origin__ is type(None):
74+
return_type = return_type.__args__[0]
75+
return return_type
76+
77+
78+
def _assert_instance_of(obj: Any, attr: str, value: Any, expected_type: Any) -> None:
79+
"""Assert that value matches expected type, handling properties specially."""
80+
prop = type(obj).__dict__.get(attr)
81+
82+
if isinstance(prop, property):
83+
return_type = _get_property_return_type(prop)
84+
if return_type is not None:
85+
is_match = return_type == expected_type or (
86+
isinstance(return_type, type) and issubclass(return_type, expected_type)
87+
)
88+
assert (
89+
is_match
90+
), f"Property '{attr}' return type {return_type} is not {expected_type}"
91+
return
92+
93+
assert isinstance(
94+
value, expected_type
95+
), f"Attribute '{attr}' value {type(value)} is not instance of {expected_type}"
96+
97+
98+
def _get_type_origin(t: Any) -> Any:
99+
"""Extract the origin type from a potentially parameterized generic.
100+
101+
e.g., list[int] -> list, set[str] -> set, dict[str, Any] -> dict
102+
"""
103+
origin = getattr(t, "__origin__", None)
104+
return origin if origin is not None else t
105+
106+
107+
def _check_dataclass_field_type(obj: type, attr: str, expected_type: Any) -> None:
108+
"""Check dataclass field type annotation matches expected type."""
109+
field = getattr(obj, "__dataclass_fields__")[attr]
110+
field_type = field.type
111+
112+
# Handle generic types (e.g., list[int] -> list, set[str] -> set)
113+
field_type_origin = _get_type_origin(field_type)
114+
expected_type_origin = _get_type_origin(expected_type)
115+
116+
obj_name = _get_obj_name(obj)
117+
118+
# First check exact match (including parameterized generics)
119+
if field_type == expected_type:
120+
return
121+
122+
# Then check origin types match (e.g., set[str] vs set[int] both have origin set)
123+
if field_type_origin == expected_type_origin:
124+
return
125+
126+
# Finally check subclass relationship (only works with actual types, not generics)
127+
if isinstance(field_type_origin, type) and isinstance(expected_type_origin, type):
128+
if issubclass(field_type_origin, expected_type_origin):
129+
return
130+
131+
raise AssertionError(
132+
f"Dataclass {obj_name}.{attr} type {field_type} is not {expected_type}"
133+
)
134+
135+
136+
def assumes(obj: Any, attr: str, is_callable: bool = False, is_instance_of: Any = None):
137+
"""
138+
Helper function to validate interface assumptions.
139+
140+
Checks that an object has the expected attribute with correct type and callability.
141+
Used to guard against breaking changes in vLLM's internal interfaces.
142+
143+
Args:
144+
obj: The object to check
145+
attr: The attribute name to validate
146+
is_callable: If True, verify the attribute is callable
147+
is_instance_of: If provided, verify the attribute is an instance of this type
148+
"""
149+
_assert_attr_exists(obj, attr)
150+
151+
# For dataclass classes (not instances), fields with default_factory don't exist
152+
# as class attributes, so check field type annotation instead of getattr
153+
dataclass_fields = getattr(obj, "__dataclass_fields__", None)
154+
is_dataclass_class = dataclass_fields is not None and isinstance(obj, type)
155+
156+
if is_dataclass_class:
157+
if is_instance_of is not None:
158+
_check_dataclass_field_type(obj, attr, is_instance_of)
159+
# Note: is_callable check not supported for dataclass class fields
160+
return
161+
162+
value = getattr(obj, attr)
163+
164+
if is_callable:
165+
assert callable(
166+
value
167+
), f"Attribute '{attr}' on {_get_obj_name(obj)} is not callable"
168+
169+
if is_instance_of is not None:
170+
_assert_instance_of(obj, attr, value, is_instance_of)
171+
172+
173+
def test_config_interface():
174+
assumes(VllmConfig, "model_config")
175+
assumes(VllmConfig, "cache_config")
176+
assumes(VllmConfig, "parallel_config")
177+
assumes(VllmConfig, "kv_transfer_config")
178+
assumes(VllmConfig, "kv_events_config")
179+
180+
assumes(KVTransferConfig, "kv_role")
181+
assumes(KVTransferConfig, "kv_load_failure_policy")
182+
assumes(KVTransferConfig, "kv_connector_module_path")
183+
assumes(KVTransferConfig, "engine_id")
184+
assumes(KVTransferConfig, "kv_connector")
185+
assumes(KVTransferConfig, "kv_connector_extra_config")
186+
187+
assumes(ModelConfig, "dtype")
188+
189+
assumes(ParallelConfig, "world_size")
190+
assumes(ParallelConfig, "data_parallel_rank")
191+
192+
assumes(CacheConfig, "cache_dtype")
193+
assumes(CacheConfig, "block_size")
194+
assumes(CacheConfig, "gpu_memory_utilization")
195+
assumes(CacheConfig, "enable_prefix_caching")
196+
197+
198+
def test_scheduler_output_interface():
199+
"""
200+
Test SchedulerOutput interface expectations for KVBM vLLM integration.
201+
Protects against interface changes in vLLM's SchedulerOutput object.
202+
"""
203+
assumes(SchedulerOutput, "finished_req_ids", is_instance_of=set[str])
204+
assumes(SchedulerOutput, "scheduled_new_reqs", is_instance_of=list[NewRequestData])
205+
assumes(SchedulerOutput, "num_scheduled_tokens", is_instance_of=dict)
206+
assumes(SchedulerOutput, "total_num_scheduled_tokens")
207+
208+
209+
def test_request_interface():
210+
"""
211+
Test Request interface expectations for KVBM vLLM integration.
212+
Protects against interface changes in vLLM's Request object.
213+
"""
214+
req = Request(
215+
request_id="test_request",
216+
prompt_token_ids=[1, 2, 3],
217+
sampling_params=SamplingParams(max_tokens=10),
218+
pooling_params=None,
219+
eos_token_id=100,
220+
lora_request=LoRARequest(
221+
lora_name="test_lora", lora_int_id=1, lora_path="test_path"
222+
),
223+
cache_salt="test_salt",
224+
)
225+
226+
assumes(req, "request_id", is_instance_of=str)
227+
assumes(req, "all_token_ids") # ConstantList
228+
assumes(req, "num_tokens", is_instance_of=int)
229+
assumes(req, "num_computed_tokens", is_instance_of=int)
230+
assumes(req, "cache_salt", is_instance_of=str)
231+
assumes(req, "lora_request", is_instance_of=LoRARequest)
232+
assumes(req, "priority", is_instance_of=int)
233+
assumes(req, "sampling_params", is_instance_of=SamplingParams)
234+
235+
236+
def test_new_request_interface():
237+
"""
238+
Test NewRequestData interface expectations for KVBM vLLM integration.
239+
Protects against interface changes in vLLM's NewRequestData object.
240+
"""
241+
assumes(NewRequestData, "req_id", is_instance_of=str)
242+
assumes(NewRequestData, "block_ids", is_instance_of=tuple[list[int], ...])
243+
assumes(NewRequestData, "prompt_token_ids", is_instance_of=(list[int] | None))
244+
assumes(NewRequestData, "num_computed_tokens", is_instance_of=int)
245+
246+
247+
def test_cached_request_interface():
248+
assumes(CachedRequestData, "resumed_req_ids", is_instance_of=set[str])
249+
assumes(CachedRequestData, "req_ids", is_instance_of=list[str])
250+
assumes(CachedRequestData, "new_token_ids", is_instance_of=list[list[int]])
251+
assumes(
252+
CachedRequestData,
253+
"new_block_ids",
254+
is_instance_of=list[tuple[list[int], ...] | None],
255+
)
256+
assumes(CachedRequestData, "num_computed_tokens", is_instance_of=list[int])

0 commit comments

Comments
 (0)