Skip to content

Commit 762917a

Browse files
DarkLight1337diegocastanibm
authored andcommitted
[Core] Use individual MM items in P0/P1 cache and model runner (vllm-project#22570)
Signed-off-by: DarkLight1337 <[email protected]> Signed-off-by: Diego-Castan <[email protected]>
1 parent dcbf193 commit 762917a

24 files changed

+548
-485
lines changed

tests/multimodal/test_utils.py

Lines changed: 79 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import mimetypes
66
import os
77
from tempfile import NamedTemporaryFile, TemporaryDirectory
8-
from typing import TYPE_CHECKING, NamedTuple, Optional
8+
from typing import TYPE_CHECKING, NamedTuple
99

1010
import numpy as np
1111
import pytest
@@ -19,14 +19,12 @@
1919
initialize_model_parallel)
2020
from vllm.multimodal.image import convert_image_mode
2121
from vllm.multimodal.inputs import PlaceholderRange
22-
from vllm.multimodal.utils import (MediaConnector,
23-
merge_and_sort_multimodal_metadata,
22+
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
2423
run_dp_sharded_vision_model)
2524
from vllm.platforms import current_platform
2625
from vllm.utils import get_open_port, update_environment_variables
2726

2827
if TYPE_CHECKING:
29-
from vllm.multimodal.hasher import MultiModalHashDict
3028
from vllm.multimodal.inputs import MultiModalPlaceholderDict
3129

3230
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
@@ -178,54 +176,45 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
178176
assert metadata_sync == metadata_async
179177

180178

181-
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
179+
# Used for `test_argsort_mm_positions`.
182180
class TestCase(NamedTuple):
183181
mm_positions: "MultiModalPlaceholderDict"
184-
mm_hashes: Optional["MultiModalHashDict"]
185-
expected_modalities: list[str]
186-
expected_ranges: list[PlaceholderRange]
187-
expected_hashes: Optional[list[str]]
182+
expected_modality_idxs: list[tuple[str, int]]
188183

189184

190-
def test_merge_and_sort_multimodal_metadata():
185+
def test_argsort_mm_positions():
191186

192187
test_cases = [
193-
# Single modality should return result as is but flattened
188+
# Single modality
189+
## Internally sorted
194190
TestCase(
195191
mm_positions={
196192
"image": [
197193
PlaceholderRange(offset=0, length=2),
198194
PlaceholderRange(offset=3, length=2),
199195
]
200196
},
201-
mm_hashes={"image": ["hash1", "hash2"]},
202-
expected_modalities=["image", "image"],
203-
expected_ranges=[
204-
PlaceholderRange(offset=0, length=2),
205-
PlaceholderRange(offset=3, length=2),
197+
expected_modality_idxs=[
198+
("image", 0),
199+
("image", 1),
206200
],
207-
expected_hashes=["hash1", "hash2"],
208201
),
209-
210-
# Single modality without hashes return None for mm hash.
202+
## Internally unsorted
211203
TestCase(
212204
mm_positions={
213205
"image": [
206+
PlaceholderRange(offset=3, length=2),
214207
PlaceholderRange(offset=0, length=2),
215-
PlaceholderRange(offset=2, length=2),
216208
]
217209
},
218-
mm_hashes=None,
219-
expected_modalities=["image", "image"],
220-
expected_ranges=[
221-
PlaceholderRange(offset=0, length=2),
222-
PlaceholderRange(offset=2, length=2),
210+
expected_modality_idxs=[
211+
("image", 1),
212+
("image", 0),
223213
],
224-
expected_hashes=None,
225214
),
226215

227-
# Multiple modalities with hashes should return sorted modalities
228-
# and flattened ranges and hashes.
216+
# Two modalities
217+
## Internally sorted
229218
TestCase(
230219
mm_positions={
231220
"image": [
@@ -237,47 +226,54 @@ def test_merge_and_sort_multimodal_metadata():
237226
PlaceholderRange(offset=2, length=3),
238227
]
239228
},
240-
mm_hashes={
241-
"image": ["image_hash1", "image_hash2"],
242-
"audio": ["audio_hash1", "audio_hash2"],
243-
},
244-
expected_modalities=["audio", "audio", "image", "image"],
245-
expected_ranges=[
246-
PlaceholderRange(offset=0, length=2),
247-
PlaceholderRange(offset=2, length=3),
248-
PlaceholderRange(offset=7, length=4),
249-
PlaceholderRange(offset=11, length=5),
229+
expected_modality_idxs=[
230+
("audio", 0),
231+
("audio", 1),
232+
("image", 0),
233+
("image", 1),
250234
],
251-
expected_hashes=[
252-
"audio_hash1", "audio_hash2", "image_hash1", "image_hash2"
235+
),
236+
## Interleaved, internally sorted
237+
TestCase(
238+
mm_positions={
239+
"image": [
240+
PlaceholderRange(offset=0, length=4),
241+
PlaceholderRange(offset=8, length=2),
242+
],
243+
"audio": [
244+
PlaceholderRange(offset=5, length=2),
245+
PlaceholderRange(offset=11, length=4),
246+
]
247+
},
248+
expected_modality_idxs=[
249+
("image", 0),
250+
("audio", 0),
251+
("image", 1),
252+
("audio", 1),
253253
],
254254
),
255-
256-
# Multiple modalities without hashes should return sorted modalities
257-
# and flattened ranges and None.
255+
## Interleaved, internally unsorted
258256
TestCase(
259257
mm_positions={
260258
"image": [
261-
PlaceholderRange(offset=7, length=4),
262-
PlaceholderRange(offset=11, length=5),
259+
PlaceholderRange(offset=8, length=2),
260+
PlaceholderRange(offset=0, length=4),
263261
],
264262
"audio": [
265-
PlaceholderRange(offset=0, length=2),
266-
PlaceholderRange(offset=2, length=3),
263+
PlaceholderRange(offset=11, length=4),
264+
PlaceholderRange(offset=5, length=2),
267265
]
268266
},
269-
mm_hashes=None,
270-
expected_modalities=["audio", "audio", "image", "image"],
271-
expected_ranges=[
272-
PlaceholderRange(offset=0, length=2),
273-
PlaceholderRange(offset=2, length=3),
274-
PlaceholderRange(offset=7, length=4),
275-
PlaceholderRange(offset=11, length=5),
267+
expected_modality_idxs=[
268+
("image", 1),
269+
("audio", 1),
270+
("image", 0),
271+
("audio", 0),
276272
],
277-
expected_hashes=None,
278273
),
279274

280275
# Three modalities
276+
## Internally sorted
281277
TestCase(
282278
mm_positions={
283279
"image": [
@@ -293,72 +289,16 @@ def test_merge_and_sort_multimodal_metadata():
293289
PlaceholderRange(offset=12, length=6),
294290
]
295291
},
296-
mm_hashes={
297-
"image": ["image_hash1", "image_hash2"],
298-
"audio": ["audio_hash1"],
299-
"video": ["video_hash1", "video_hash2", "video_hash3"]
300-
},
301-
expected_modalities=[
302-
"audio", "video", "video", "video", "image", "image"
303-
],
304-
expected_ranges=[
305-
PlaceholderRange(offset=0, length=2),
306-
PlaceholderRange(offset=3, length=4),
307-
PlaceholderRange(offset=7, length=5),
308-
PlaceholderRange(offset=12, length=6),
309-
PlaceholderRange(offset=15, length=7),
310-
PlaceholderRange(offset=22, length=8),
311-
],
312-
expected_hashes=[
313-
"audio_hash1", "video_hash1", "video_hash2", "video_hash3",
314-
"image_hash1", "image_hash2"
315-
],
316-
),
317-
]
318-
319-
for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
320-
expected_hashes) in test_cases:
321-
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
322-
mm_positions, mm_hashes)
323-
324-
assert modalities == expected_modalities
325-
assert ranges == expected_ranges
326-
assert hashes == expected_hashes
327-
328-
329-
def test_merge_and_sort_multimodal_metadata_with_interleaving():
330-
331-
test_cases = [
332-
333-
# <image> <audio> <image> <audio>
334-
TestCase(
335-
mm_positions={
336-
"image": [
337-
PlaceholderRange(offset=0, length=4),
338-
PlaceholderRange(offset=8, length=2),
339-
],
340-
"audio": [
341-
PlaceholderRange(offset=5, length=2),
342-
PlaceholderRange(offset=11, length=4),
343-
]
344-
},
345-
mm_hashes={
346-
"image": ["image_hash1", "image_hash2"],
347-
"audio": ["audio_hash1", "audio_hash2"],
348-
},
349-
expected_modalities=["image", "audio", "image", "audio"],
350-
expected_ranges=[
351-
PlaceholderRange(offset=0, length=4),
352-
PlaceholderRange(offset=5, length=2),
353-
PlaceholderRange(offset=8, length=2),
354-
PlaceholderRange(offset=11, length=4),
355-
],
356-
expected_hashes=[
357-
"image_hash1", "audio_hash1", "image_hash2", "audio_hash2"
292+
expected_modality_idxs=[
293+
("audio", 0),
294+
("video", 0),
295+
("video", 1),
296+
("video", 2),
297+
("image", 0),
298+
("image", 1),
358299
],
359300
),
360-
361-
# <image> <image> <audio> <video> <image>
301+
## Interleaved, internally sorted
362302
TestCase(
363303
mm_positions={
364304
"image": [
@@ -373,58 +313,43 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
373313
PlaceholderRange(offset=8, length=5),
374314
]
375315
},
376-
mm_hashes=None,
377-
expected_modalities=["image", "image", "audio", "video", "image"],
378-
expected_ranges=[
379-
PlaceholderRange(offset=0, length=2),
380-
PlaceholderRange(offset=2, length=3),
381-
PlaceholderRange(offset=5, length=2),
382-
PlaceholderRange(offset=8, length=5),
383-
PlaceholderRange(offset=20, length=4),
316+
expected_modality_idxs=[
317+
("image", 0),
318+
("image", 1),
319+
("audio", 0),
320+
("video", 0),
321+
("image", 2),
384322
],
385-
expected_hashes=None,
386323
),
387-
388-
# <image> <audio> <video> <image> with hashes
324+
## Interleaved, internally sunorted
389325
TestCase(
390326
mm_positions={
391327
"image": [
392328
PlaceholderRange(offset=0, length=2),
393-
PlaceholderRange(offset=18, length=4),
329+
PlaceholderRange(offset=20, length=4),
330+
PlaceholderRange(offset=2, length=3),
394331
],
395332
"audio": [
396-
PlaceholderRange(offset=6, length=2),
333+
PlaceholderRange(offset=5, length=2),
397334
],
398335
"video": [
399-
PlaceholderRange(offset=10, length=5),
336+
PlaceholderRange(offset=8, length=5),
400337
]
401338
},
402-
mm_hashes={
403-
"image": ["image_hash1", "image_hash2"],
404-
"audio": ["audio_hash1"],
405-
"video": ["video_hash1"],
406-
},
407-
expected_modalities=["image", "audio", "video", "image"],
408-
expected_ranges=[
409-
PlaceholderRange(offset=0, length=2),
410-
PlaceholderRange(offset=6, length=2),
411-
PlaceholderRange(offset=10, length=5),
412-
PlaceholderRange(offset=18, length=4),
413-
],
414-
expected_hashes=[
415-
"image_hash1", "audio_hash1", "video_hash1", "image_hash2"
339+
expected_modality_idxs=[
340+
("image", 0),
341+
("image", 2),
342+
("audio", 0),
343+
("video", 0),
344+
("image", 1),
416345
],
417346
),
418347
]
419348

420-
for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
421-
expected_hashes) in test_cases:
422-
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
423-
mm_positions, mm_hashes)
349+
for mm_positions, expected_modality_idxs in test_cases:
350+
modality_idxs = argsort_mm_positions(mm_positions)
424351

425-
assert modalities == expected_modalities
426-
assert ranges == expected_ranges
427-
assert hashes == expected_hashes
352+
assert modality_idxs == expected_modality_idxs
428353

429354

430355
class SimpleLinearModel(torch.nn.Module):

0 commit comments

Comments
 (0)