|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +""" |
| 6 | +Assumption tests for KVBM connector's expectations of vLLM interfaces. |
| 7 | +
|
| 8 | +These unit tests validate that KVBM's assumptions about vLLM's internal |
| 9 | +interfaces remain stable across vLLM releases. They do NOT test functional |
| 10 | +correctness of KVBM or vLLM logic, but rather ensure the API contract remains |
| 11 | +intact to prevent silent breakage. |
| 12 | +
|
| 13 | +Inspired by vLLM's test_lmcache_integration.py approach to interface testing. |
| 14 | +""" |
| 15 | + |
| 16 | +from typing import Any |
| 17 | + |
| 18 | +import pytest |
| 19 | + |
| 20 | +# Skip if vLLM is not available |
| 21 | +pytest.importorskip("vllm", reason="vLLM not available") |
| 22 | + |
| 23 | +# ruff: noqa: E402 |
| 24 | +# Imports must be after pytest.importorskip() to handle missing vLLM gracefully |
| 25 | +from vllm.config import ( |
| 26 | + CacheConfig, |
| 27 | + KVTransferConfig, |
| 28 | + ModelConfig, |
| 29 | + ParallelConfig, |
| 30 | + VllmConfig, |
| 31 | +) |
| 32 | +from vllm.lora.request import LoRARequest |
| 33 | +from vllm.sampling_params import SamplingParams |
| 34 | +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput |
| 35 | +from vllm.v1.request import Request |
| 36 | + |
| 37 | +# Test markers |
| 38 | +pytestmark = [ |
| 39 | + pytest.mark.kvbm, |
| 40 | + pytest.mark.gpu_0, |
| 41 | + pytest.mark.nightly, |
| 42 | + pytest.mark.pre_merge, |
| 43 | +] |
| 44 | + |
| 45 | + |
| 46 | +def _get_obj_name(obj: Any) -> str: |
| 47 | + """Get a readable name for an object (class name or repr).""" |
| 48 | + return getattr(obj, "__name__", None) or obj.__class__.__name__ |
| 49 | + |
| 50 | + |
def _assert_attr_exists(obj: Any, attr: str) -> None:
    """Raise ``AssertionError`` unless ``attr`` is present on ``obj``.

    When ``obj`` (a class or an instance) exposes ``__dataclass_fields__``,
    presence is decided by membership in that mapping; this also covers
    dataclasses wrapped by further decorators. Any other object is checked
    with ``hasattr``.
    """
    label = _get_obj_name(obj)
    declared_fields = getattr(obj, "__dataclass_fields__", None)

    if declared_fields is None:
        assert hasattr(obj, attr), f"Object {label} missing attribute '{attr}'"
        return

    assert attr in declared_fields, f"Dataclass {label} missing field '{attr}'"
| 61 | + |
| 62 | + |
| 63 | +def _get_property_return_type(prop: property) -> Any: |
| 64 | + """Extract return type from a property's fget annotations.""" |
| 65 | + fget = prop.fget |
| 66 | + if fget is None or not hasattr(fget, "__annotations__"): |
| 67 | + return None |
| 68 | + annotations = fget.__annotations__ |
| 69 | + if "return" not in annotations: |
| 70 | + return None |
| 71 | + return_type = annotations["return"] |
| 72 | + # Handle Optional types by extracting the inner type |
| 73 | + if hasattr(return_type, "__origin__") and return_type.__origin__ is type(None): |
| 74 | + return_type = return_type.__args__[0] |
| 75 | + return return_type |
| 76 | + |
| 77 | + |
| 78 | +def _assert_instance_of(obj: Any, attr: str, value: Any, expected_type: Any) -> None: |
| 79 | + """Assert that value matches expected type, handling properties specially.""" |
| 80 | + prop = type(obj).__dict__.get(attr) |
| 81 | + |
| 82 | + if isinstance(prop, property): |
| 83 | + return_type = _get_property_return_type(prop) |
| 84 | + if return_type is not None: |
| 85 | + is_match = return_type == expected_type or ( |
| 86 | + isinstance(return_type, type) and issubclass(return_type, expected_type) |
| 87 | + ) |
| 88 | + assert ( |
| 89 | + is_match |
| 90 | + ), f"Property '{attr}' return type {return_type} is not {expected_type}" |
| 91 | + return |
| 92 | + |
| 93 | + assert isinstance( |
| 94 | + value, expected_type |
| 95 | + ), f"Attribute '{attr}' value {type(value)} is not instance of {expected_type}" |
| 96 | + |
| 97 | + |
| 98 | +def _get_type_origin(t: Any) -> Any: |
| 99 | + """Extract the origin type from a potentially parameterized generic. |
| 100 | +
|
| 101 | + e.g., list[int] -> list, set[str] -> set, dict[str, Any] -> dict |
| 102 | + """ |
| 103 | + origin = getattr(t, "__origin__", None) |
| 104 | + return origin if origin is not None else t |
| 105 | + |
| 106 | + |
def _check_dataclass_field_type(obj: type, attr: str, expected_type: Any) -> None:
    """Verify the declared type of dataclass field ``attr`` on ``obj``.

    The field's annotation is accepted when any of the following holds:

    1. it equals ``expected_type`` exactly (parameters included);
    2. both share the same origin (``set[str]`` and ``set[int]`` both reduce
       to ``set``);
    3. both origins are real classes and the field's origin is a subclass of
       the expected origin.

    Raises:
        AssertionError: If none of the above holds.
    """
    declared = getattr(obj, "__dataclass_fields__")[attr].type

    # Cheapest check first: identical annotations need no further analysis.
    if declared == expected_type:
        return

    declared_origin = _get_type_origin(declared)
    expected_origin = _get_type_origin(expected_type)
    if declared_origin == expected_origin:
        return

    # issubclass() only accepts real classes, so guard both operands.
    both_are_classes = isinstance(declared_origin, type) and isinstance(
        expected_origin, type
    )
    if both_are_classes and issubclass(declared_origin, expected_origin):
        return

    raise AssertionError(
        f"Dataclass {_get_obj_name(obj)}.{attr} type {declared} is not {expected_type}"
    )
| 134 | + |
| 135 | + |
def assumes(
    obj: Any, attr: str, is_callable: bool = False, is_instance_of: Any = None
) -> None:
    """
    Validate one assumption about an external interface.

    Asserts that ``attr`` exists on ``obj`` and, optionally, that it is
    callable and/or of a given type. Intended as a guard against breaking
    changes in vLLM's internal interfaces.

    Args:
        obj: The class or instance to check
        attr: The attribute name to validate
        is_callable: If True, verify the attribute is callable (ignored when
            ``obj`` is a dataclass class)
        is_instance_of: If provided, verify the attribute is of this type; for
            a dataclass class the field's type annotation is checked instead
            of a value
    """
    _assert_attr_exists(obj, attr)

    # On a dataclass *class*, fields declared with default_factory are not
    # class attributes, so getattr() cannot be used; validate the field's
    # annotation instead and stop there.
    is_dataclass_class = (
        isinstance(obj, type) and getattr(obj, "__dataclass_fields__", None) is not None
    )
    if is_dataclass_class:
        if is_instance_of is not None:
            _check_dataclass_field_type(obj, attr, is_instance_of)
        # Note: is_callable check not supported for dataclass class fields
        return

    current = getattr(obj, attr)

    if is_callable:
        assert callable(
            current
        ), f"Attribute '{attr}' on {_get_obj_name(obj)} is not callable"

    if is_instance_of is not None:
        _assert_instance_of(obj, attr, current, is_instance_of)
| 171 | + |
| 172 | + |
def test_config_interface():
    """
    Test config interface expectations for KVBM vLLM integration.
    Checks that each listed attribute still exists on vLLM's config classes.
    """
    expected_attrs = (
        (
            VllmConfig,
            (
                "model_config",
                "cache_config",
                "parallel_config",
                "kv_transfer_config",
                "kv_events_config",
            ),
        ),
        (
            KVTransferConfig,
            (
                "kv_role",
                "kv_load_failure_policy",
                "kv_connector_module_path",
                "engine_id",
                "kv_connector",
                "kv_connector_extra_config",
            ),
        ),
        (ModelConfig, ("dtype",)),
        (ParallelConfig, ("world_size", "data_parallel_rank")),
        (
            CacheConfig,
            (
                "cache_dtype",
                "block_size",
                "gpu_memory_utilization",
                "enable_prefix_caching",
            ),
        ),
    )

    for config_cls, attr_names in expected_attrs:
        for attr_name in attr_names:
            assumes(config_cls, attr_name)
| 196 | + |
| 197 | + |
def test_scheduler_output_interface():
    """
    Test SchedulerOutput interface expectations for KVBM vLLM integration.
    Guards against changes to the fields of vLLM's SchedulerOutput.
    """
    # Field name -> expected annotation; None means "existence only".
    expected_fields = {
        "finished_req_ids": set[str],
        "scheduled_new_reqs": list[NewRequestData],
        "num_scheduled_tokens": dict,
        "total_num_scheduled_tokens": None,
    }

    for field_name, field_type in expected_fields.items():
        assumes(SchedulerOutput, field_name, is_instance_of=field_type)
| 207 | + |
| 208 | + |
def test_request_interface():
    """
    Test Request interface expectations for KVBM vLLM integration.
    Builds a minimal Request and checks the attributes it exposes, guarding
    against interface changes in vLLM's Request object.
    """
    sampling = SamplingParams(max_tokens=10)
    lora = LoRARequest(lora_name="test_lora", lora_int_id=1, lora_path="test_path")
    req = Request(
        request_id="test_request",
        prompt_token_ids=[1, 2, 3],
        sampling_params=sampling,
        pooling_params=None,
        eos_token_id=100,
        lora_request=lora,
        cache_salt="test_salt",
    )

    # (attribute, expected type); None means "existence only".
    expectations = (
        ("request_id", str),
        ("all_token_ids", None),  # ConstantList
        ("num_tokens", int),
        ("num_computed_tokens", int),
        ("cache_salt", str),
        ("lora_request", LoRARequest),
        ("priority", int),
        ("sampling_params", SamplingParams),
    )
    for attr_name, expected_type in expectations:
        assumes(req, attr_name, is_instance_of=expected_type)
| 234 | + |
| 235 | + |
def test_new_request_interface():
    """
    Test NewRequestData interface expectations for KVBM vLLM integration.
    Guards against changes to the fields of vLLM's NewRequestData.
    """
    expected_fields = {
        "req_id": str,
        "block_ids": tuple[list[int], ...],
        "prompt_token_ids": list[int] | None,
        "num_computed_tokens": int,
    }

    for field_name, field_type in expected_fields.items():
        assumes(NewRequestData, field_name, is_instance_of=field_type)
| 245 | + |
| 246 | + |
def test_cached_request_interface():
    """
    Test CachedRequestData interface expectations for KVBM vLLM integration.
    Guards against changes to the fields of vLLM's CachedRequestData.
    """
    expected_fields = {
        "resumed_req_ids": set[str],
        "req_ids": list[str],
        "new_token_ids": list[list[int]],
        "new_block_ids": list[tuple[list[int], ...] | None],
        "num_computed_tokens": list[int],
    }

    for field_name, field_type in expected_fields.items():
        assumes(CachedRequestData, field_name, is_instance_of=field_type)
0 commit comments