Skip to content

Commit 4f7cde7

Browse files
Adds json_count_leaves utility function (vllm-project#23899)
Signed-off-by: aditchawdhary <[email protected]>
1 parent 67c1490 commit 4f7cde7

File tree

3 files changed

+72
-10
lines changed

3 files changed

+72
-10
lines changed

tests/utils_/test_utils.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,9 @@ def test_duplicate_dict_args(caplog_vllm, parser):
379379
def test_supports_kw(callable,kw_name,requires_kw_only,
380380
allow_var_kwargs,is_supported):
381381
assert supports_kw(
382-
callable=callable,
383-
kw_name=kw_name,
384-
requires_kw_only=requires_kw_only,
382+
callable=callable,
383+
kw_name=kw_name,
384+
requires_kw_only=requires_kw_only,
385385
allow_var_kwargs=allow_var_kwargs
386386
) == is_supported
387387

@@ -948,6 +948,36 @@ def test_join_host_port():
948948
assert join_host_port("::1", 5555) == "[::1]:5555"
949949

950950

951+
def test_json_count_leaves():
952+
"""Test json_count_leaves function from jsontree utility."""
953+
from vllm.utils.jsontree import json_count_leaves
954+
955+
# Single leaf values
956+
assert json_count_leaves(42) == 1
957+
assert json_count_leaves("hello") == 1
958+
assert json_count_leaves(None) == 1
959+
960+
# Empty containers
961+
assert json_count_leaves([]) == 0
962+
assert json_count_leaves({}) == 0
963+
assert json_count_leaves(()) == 0
964+
965+
# Flat structures
966+
assert json_count_leaves([1, 2, 3]) == 3
967+
assert json_count_leaves({"a": 1, "b": 2}) == 2
968+
assert json_count_leaves((1, 2, 3)) == 3
969+
970+
# Nested structures
971+
nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
972+
assert json_count_leaves(nested_dict) == 3
973+
974+
nested_list = [1, [2, 3], 4]
975+
assert json_count_leaves(nested_list) == 4
976+
977+
mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
978+
assert json_count_leaves(mixed_nested) == 4
979+
980+
951981
def test_convert_ids_list_to_tokens():
952982
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
953983
token_ids = tokenizer.encode("Hello, world!")

vllm/multimodal/cache.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
from vllm.logger import init_logger
1212
from vllm.utils import GiB_bytes, LRUCache
13-
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
13+
from vllm.utils.jsontree import (json_count_leaves, json_map_leaves,
14+
json_reduce_leaves)
1415

1516
from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem,
1617
MultiModalKwargs, MultiModalKwargsItem,
@@ -127,11 +128,32 @@ def get_item_size(
127128
)
128129

129130
if debug:
130-
logger.debug("Calculated size of %s to be %.2f GiB", type(value),
131-
size / GiB_bytes)
131+
leaf_count = json_count_leaves(value)
132+
logger.debug(
133+
"Calculated size of %s to be %.2f GiB (%d leaves)",
134+
type(value),
135+
size / GiB_bytes,
136+
leaf_count,
137+
)
132138

133139
return size
134140

141+
@classmethod
142+
def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
143+
"""
144+
Get the number of leaf elements in a multi-modal cache value.
145+
146+
This provides a measure of structural complexity that can be useful
147+
for debugging cache performance and understanding data patterns.
148+
149+
Args:
150+
value: The multi-modal cache value to analyze.
151+
152+
Returns:
153+
The number of leaf elements in the nested structure.
154+
"""
155+
return json_count_leaves(value)
156+
135157
@classmethod
136158
def get_lru_cache(
137159
cls,
@@ -184,7 +206,7 @@ def get_and_update_item(
184206
"""
185207
Possibly update a multi-modal item based on whether it is
186208
in the underlying cache.
187-
209+
188210
This update is done out-of-place and updates the cache eviction order.
189211
190212
Args:
@@ -262,7 +284,7 @@ def is_cached(self, mm_hashes: list[str]) -> list[bool]:
262284
in the underlying cache.
263285
264286
This **DOES NOT** update the cache eviction order.
265-
287+
266288
Args:
267289
mm_hashes: The hash of each item to check.
268290

vllm/utils/jsontree.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""Helper functions to work with nested JSON structures."""
4+
45
from collections.abc import Iterable
56
from functools import reduce
67
from typing import Callable, TypeVar, Union, overload
78

89
_T = TypeVar("_T")
910
_U = TypeVar("_U")
1011

11-
JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"],
12-
tuple["JSONTree[_T]", ...], _T]
12+
JSONTree = Union[
13+
dict[str, "JSONTree[_T]"],
14+
list["JSONTree[_T]"],
15+
tuple["JSONTree[_T]", ...],
16+
_T,
17+
]
1318
"""A nested JSON structure where the leaves need not be JSON-serializable."""
1419

1520

@@ -78,3 +83,8 @@ def json_reduce_leaves(
7883
json_iter_leaves(value),
7984
initial,
8085
)
86+
87+
88+
def json_count_leaves(value: JSONTree[_T]) -> int:
89+
"""Count the number of leaves in a nested JSON structure."""
90+
return sum(1 for _ in json_iter_leaves(value))

0 commit comments

Comments
 (0)