import pytest
import torch


pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for Triton kernels")


def _randn_like(shape, dtype=torch.float32, device="cuda"):
    """Build a fresh, contiguous random tensor of ``shape`` with a fixed seed (not a torch.randn_like wrapper)."""
    torch.manual_seed(1234)
    return torch.randn(shape, dtype=dtype, device=device).contiguous()


def _gather_along_size(t, idx):
    """Select the given token positions along the size dimension (dim=1) of a kv buffer."""
    return t.index_select(dim=1, index=idx)

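
# The checks below assume the following page_tensor head layout under tensor
# parallelism (inferred from this test's own expectations, not from the kernel's
# documentation): the first half of the head axis holds K heads, the second half
# holds V heads, and each TP rank owns one contiguous slice of each half. This
# helper is purely illustrative and is not called by the tests.
def _tp_head_slices(page_head_num, tp_index, tp_world_size):
    """Return (k_slice, v_slice) of the page_tensor heads owned by one TP rank."""
    page_k_head_num = page_head_num // 2
    per_rank = page_k_head_num // tp_world_size
    k_start = tp_index * per_rank
    return (
        slice(k_start, k_start + per_rank),
        slice(page_k_head_num + k_start, page_k_head_num + k_start + per_rank),
    )
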

@pytest.mark.parametrize("head_dim", [64, 96])
@pytest.mark.parametrize(
    "kv_head_num,tp_world_size",
    [
        (4, 1),
        (4, 2),
        (8, 1),
        (8, 2),
        (8, 4),
        (16, 2),
    ],
)
def test_page_io_roundtrip_with_tp(head_dim, kv_head_num, tp_world_size):
    """Round-trip kv_buffer -> page_tensor -> kv_buffer through page_io for several TP layouts."""
    from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io

    device = "cuda"
    dtype = torch.bfloat16

    # Shapes
    layer_num = 3
    size = 32
    assert kv_head_num % 2 == 0  # heads must split evenly into K and V halves
    page_size = 32
    page_head_num = kv_head_num * tp_world_size

    tp_indices = list(range(tp_world_size))

    kv_buffer = _randn_like((layer_num, size, kv_head_num, head_dim), dtype=dtype, device=device)
    page_tensor = torch.zeros((page_size, layer_num, page_head_num, head_dim), dtype=dtype, device=device).contiguous()

    # Select a handful of token positions to move (token count <= page_size)
    mem_indexes = torch.tensor([2, 5, 7, 9, 11], dtype=torch.int64, device=device).contiguous()

    # Write: kv_buffer -> page_tensor. Every tp rank fills its own partition; since all ranks
    # share the same kv_buffer here, each partition ends up holding identical data.
    for tp_index in tp_indices:
        page_io(
            mem_indexes=mem_indexes,
            page_tensor=page_tensor,
            kv_buffer=kv_buffer,
            tp_index=tp_index,
            tp_world_size=tp_world_size,
            mode="write",
        )

    # Check the write path against a reference built with plain Python loops
    token_num = mem_indexes.numel()
    k_head_num = kv_head_num // 2
    page_k_head_num = page_head_num // 2
    page_write_head_num = page_k_head_num // tp_world_size

    expected_page = torch.zeros_like(page_tensor)
    for tid in range(token_num):
        mem_index = int(mem_indexes[tid].item())
        for layer_index in range(layer_num):
            for tp_index in tp_indices:
                page_head_start = tp_index * page_write_head_num
                for kv_head_id in range(page_write_head_num):
                    # K half
                    expected_page[tid, layer_index, page_head_start + kv_head_id, :] = kv_buffer[
                        layer_index, mem_index, kv_head_id, :
                    ]
                    # V half
                    expected_page[tid, layer_index, page_k_head_num + page_head_start + kv_head_id, :] = kv_buffer[
                        layer_index, mem_index, k_head_num + kv_head_id, :
                    ]

    assert torch.allclose(page_tensor[:token_num], expected_page[:token_num], atol=1e-3, rtol=1e-3)

    # Read back into a fresh buffer
    kv_buffer_rt = torch.zeros_like(kv_buffer)
    for tp_index in tp_indices:
        page_io(
            mem_indexes=mem_indexes,
            page_tensor=page_tensor,
            kv_buffer=kv_buffer_rt,
            tp_index=tp_index,
            tp_world_size=tp_world_size,
            mode="read",
        )

    # Compare only at the selected positions along the size dimension
    ref = _gather_along_size(kv_buffer, mem_indexes)
    out = _gather_along_size(kv_buffer_rt, mem_indexes)
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)


@pytest.mark.parametrize("head_dim", [32, 80])  # 80 exercises a non-power-of-two head_dim
def test_mla_page_io_roundtrip(head_dim):
    """Round-trip a single-head (MLA-style) kv_buffer through mla_page_io."""
    from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io

    device = "cuda"
    dtype = torch.bfloat16

    # Shapes (single-head)
    layer_num = 2
    size = 20
    kv_head_num = 1
    page_size = 10
    page_head_num = 1

    kv_buffer = _randn_like((layer_num, size, kv_head_num, head_dim), dtype=dtype, device=device)
    page_tensor = torch.zeros((page_size, layer_num, page_head_num, head_dim), dtype=dtype, device=device).contiguous()

    mem_indexes = torch.tensor([0, 3, 6, 7], dtype=torch.int64, device=device).contiguous()

    # Write kv_buffer -> page_tensor
    mla_page_io(mem_indexes=mem_indexes, page_tensor=page_tensor, kv_buffer=kv_buffer, mode="write")

    # Check the write path against a loop-built reference
    token_num = mem_indexes.numel()
    expected_page = torch.zeros_like(page_tensor)
    for tid in range(token_num):
        mem_index = int(mem_indexes[tid].item())
        for layer_index in range(layer_num):
            expected_page[tid, layer_index, 0, :] = kv_buffer[layer_index, mem_index, 0, :]
    assert torch.allclose(page_tensor[:token_num], expected_page[:token_num], atol=1e-3, rtol=1e-3)

    # Read back page_tensor -> kv_buffer and compare at the selected positions
    kv_buffer_rt = torch.zeros_like(kv_buffer)
    mla_page_io(mem_indexes=mem_indexes, page_tensor=page_tensor, kv_buffer=kv_buffer_rt, mode="read")

    ref = kv_buffer.index_select(dim=1, index=mem_indexes)
    out = kv_buffer_rt.index_select(dim=1, index=mem_indexes)
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
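

# Optional convenience entry point, so the file can also be run directly with
# `python <this file>` instead of through pytest discovery.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))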