|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3 | 3 |
|
4 |
| -import random |
5 |
| - |
6 | 4 | import numpy as np
|
7 | 5 | import pytest
|
8 | 6 | import torch
|
@@ -409,29 +407,30 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
|
409 | 407 | model_runner.model_config.get_head_size()
|
410 | 408 | ]
|
411 | 409 | # TODO mla test
|
412 |
| - default_stride = list(range(5)) |
| 410 | + default_stride = tuple(range(5)) |
413 | 411 | # Permutation that gets you back to expected kv shape
|
414 |
| - rnd_stride = tuple(random.sample(default_stride, len(default_stride))) |
415 |
| - |
416 |
| - def rnd_stride_order(): |
417 |
| - return rnd_stride |
418 |
| - |
419 |
| - # Patch the attention backend class and re-trigger the KV cache creation. |
420 |
| - for attn_group in model_runner._attn_group_iterator(): |
421 |
| - attn_backend = attn_group.backend |
422 |
| - monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", |
423 |
| - rnd_stride_order) |
424 |
| - |
425 |
| - model_runner.attn_groups = [] |
426 |
| - model_runner.initialize_kv_cache(model_runner.kv_cache_config) |
427 |
| - |
428 |
| - # Shape is unchanged, but layout may differ |
429 |
| - kv_cache_shape = model_runner.kv_caches[0].shape |
430 |
| - assert list(kv_cache_shape) == expected_kv_cache_shape |
431 |
| - if default_stride == rnd_stride: |
432 |
| - assert all(kv.is_contiguous() for kv in model_runner.kv_caches) |
433 |
| - else: |
434 |
| - assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) |
| 412 | + for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)): |
| 413 | + |
| 414 | + def rnd_stride_order(test_stride=test_stride): |
| 415 | + return test_stride |
| 416 | + |
| 417 | + # Patch the attention backend class and re-trigger the KV cache creation |
| 418 | + for attn_group in model_runner._attn_group_iterator(): |
| 419 | + attn_backend = attn_group.backend |
| 420 | + monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", |
| 421 | + rnd_stride_order) |
| 422 | + |
| 423 | + model_runner.attn_groups = [] |
| 424 | + model_runner.kv_caches = [] |
| 425 | + model_runner.initialize_kv_cache(model_runner.kv_cache_config) |
| 426 | + |
| 427 | + # Shape is unchanged, but layout may differ |
| 428 | + kv_cache_shape = model_runner.kv_caches[0].shape |
| 429 | + assert list(kv_cache_shape) == expected_kv_cache_shape |
| 430 | + if default_stride == test_stride: |
| 431 | + assert all(kv.is_contiguous() for kv in model_runner.kv_caches) |
| 432 | + else: |
| 433 | + assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) |
435 | 434 |
|
436 | 435 |
|
437 | 436 | def test_update_config(model_runner):
|
|
0 commit comments