Skip to content

Commit 561a0ba

Browse files
authored
[CI] Fix flaky test v1/worker/test_gpu_model_runner.py::test_kv_cache_stride_order (vllm-project#24640)
Signed-off-by: Chen Zhang <[email protected]>
1 parent f592b31 commit 561a0ba

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4-
import random
5-
64
import numpy as np
75
import pytest
86
import torch
@@ -409,29 +407,30 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
409407
model_runner.model_config.get_head_size()
410408
]
411409
# TODO mla test
412-
default_stride = list(range(5))
410+
default_stride = tuple(range(5))
413411
# Permutation that gets you back to expected kv shape
414-
rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
415-
416-
def rnd_stride_order():
417-
return rnd_stride
418-
419-
# Patch the attention backend class and re-trigger the KV cache creation.
420-
for attn_group in model_runner._attn_group_iterator():
421-
attn_backend = attn_group.backend
422-
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
423-
rnd_stride_order)
424-
425-
model_runner.attn_groups = []
426-
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
427-
428-
# Shape is unchanged, but layout may differ
429-
kv_cache_shape = model_runner.kv_caches[0].shape
430-
assert list(kv_cache_shape) == expected_kv_cache_shape
431-
if default_stride == rnd_stride:
432-
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
433-
else:
434-
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
412+
for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):
413+
414+
def rnd_stride_order(test_stride=test_stride):
415+
return test_stride
416+
417+
# Patch the attention backend class and re-trigger the KV cache creation
418+
for attn_group in model_runner._attn_group_iterator():
419+
attn_backend = attn_group.backend
420+
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
421+
rnd_stride_order)
422+
423+
model_runner.attn_groups = []
424+
model_runner.kv_caches = []
425+
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
426+
427+
# Shape is unchanged, but layout may differ
428+
kv_cache_shape = model_runner.kv_caches[0].shape
429+
assert list(kv_cache_shape) == expected_kv_cache_shape
430+
if default_stride == test_stride:
431+
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
432+
else:
433+
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
435434

436435

437436
def test_update_config(model_runner):

0 commit comments

Comments
 (0)