diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 884e715f..14e5df33 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -77,12 +77,20 @@ from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import (
-    FileCacheManager,
-    default_cache_dir,
-    default_dump_dir,
-    default_override_dir,
-)
+try:
+    from triton.runtime.cache import (
+        FileCacheManager,
+        default_cache_dir,
+        default_dump_dir,
+        default_override_dir,
+    )
+except ImportError:
+    from triton.runtime.cache import FileCacheManager
+    from triton.knobs import cache as tt_cache
+
+    default_cache_dir = lambda: tt_cache.dir
+    default_dump_dir = lambda: tt_cache.dump_dir
+    default_override_dir = lambda: tt_cache.override_dir
 
 logger = logging.getLogger(__name__)
 
@@ -156,6 +164,18 @@ def is_xpu() -> bool:
 def is_npu() -> bool:
     return hasattr(torch, "npu") and torch.npu.is_available()
 
+def infer_device():
+    """
+    Infer the device type based on the current environment.
+    """
+    if is_cuda_alike():
+        return "cuda"
+    elif is_xpu():
+        return "xpu"
+    elif is_hpu():
+        return "hpu"
+    else:
+        return "cpu"
 
 def is_flashinfer_available():
     """
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index 47eb16a9..9d6a0af0 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -16,8 +16,11 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
 from sglang.test.test_utils import CustomTestCase
+from sglang.srt.utils import infer_device
 
 
+device = infer_device()
+
 class TestTritonAttention(CustomTestCase):
 
     def _set_all_seeds(self, seed):
@@ -37,24 +40,24 @@ class TestTritonAttention(CustomTestCase):
         dtype = torch.bfloat16
 
         b_seq_len_prefix = torch.randint(
-            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+            1, N_CTX // 2, (B,), dtype=torch.int32, device=device
         )
         b_seq_len_extend = torch.randint(
-            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+            1, N_CTX // 2, (B,), dtype=torch.int32, device=device
         )
         b_seq_len = b_seq_len_prefix + b_seq_len_extend
         max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
 
-        b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
-        b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_req_idx = torch.arange(B, dtype=torch.int32, device=device)
+        b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
         b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
-        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
         b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
 
-        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
         kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
         kv_indices = torch.zeros(
-            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
+            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device=device
         )
 
         for i in range(B):
@@ -65,15 +68,15 @@ class TestTritonAttention(CustomTestCase):
         total_token_num = torch.sum(b_seq_len).item()
         extend_token_num = torch.sum(b_seq_len_extend).item()
         k_buffer = torch.empty(
-            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+            (total_token_num, H_KV, D), dtype=dtype, device=device
         ).normal_(mean=0.1, std=0.2)
         v_buffer = torch.empty(
-            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+            (total_token_num, H_KV, D), dtype=dtype, device=device
         ).normal_(mean=0.1, std=0.2)
 
-        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
-        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
-        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
+        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
+        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
+        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
         for i in range(B):
             extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
             extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
@@ -86,20 +89,20 @@ class TestTritonAttention(CustomTestCase):
                 extend_start_in_buffer:extend_end_in_buffer
             ]
             q_extend[extend_start:extend_end] = torch.empty(
-                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
+                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device=device
             ).normal_(mean=0.1, std=0.2)
 
-        o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
+        o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
         o_extend_mask = torch.empty(
-            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+            (extend_token_num, H_Q, D), dtype=dtype, device=device
         )
         o_redundant = torch.empty(
-            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+            (extend_token_num, H_Q, D), dtype=dtype, device=device
         )
 
         b_seq_len_extend = b_seq_len - b_seq_len_prefix
         max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
-        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
         qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
 
         custom_mask = None
@@ -123,9 +126,9 @@ class TestTritonAttention(CustomTestCase):
 
         b_seq_mask_len = b_seq_len_extend * b_seq_len
         custom_mask = torch.ones(
-            (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
+            (b_seq_mask_len.sum().item(),), dtype=torch.bool, device=device
         )
-        mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
+        mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device=device)
         mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
         for i in range(B):
             causal_mask = (
@@ -187,14 +190,14 @@ class TestTritonAttention(CustomTestCase):
         max_seq_len = max(seq_lens)
 
         # Create random input tensors
-        q = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda")
-        k = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda")
-        v = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda")
-        o = torch.zeros(sum(seq_lens), num_heads, head_dim, device="cuda")
+        q = torch.randn(sum(seq_lens), num_heads, head_dim, device=device)
+        k = torch.randn(sum(seq_lens), num_heads, head_dim, device=device)
+        v = torch.randn(sum(seq_lens), num_heads, head_dim, device=device)
+        o = torch.zeros(sum(seq_lens), num_heads, head_dim, device=device)
 
         # Create b_start_loc and b_seq_len tensors
-        b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
-        b_seq_len = torch.tensor(seq_lens, device="cuda")
+        b_start_loc = torch.tensor([0, seq_lens[0]], device=device)
+        b_seq_len = torch.tensor(seq_lens, device=device)
 
         context_attention_fwd(
             q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
@@ -232,33 +235,33 @@ class TestTritonAttention(CustomTestCase):
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         max_kv_splits = 8
-        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=device)
 
         # q represents the new token being generated, one per batch
-        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
+        q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
 
         # k_buffer and v_buffer represent all previous tokens
-        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
-        v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
+        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
+        v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
 
         # o will have the same shape as q
-        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+        o = torch.zeros(B, H_Q, D, dtype=dtype, device=device)
 
-        b_seq_len = torch.full((B,), seq_len, device="cuda")
+        b_seq_len = torch.full((B,), seq_len, device=device)
 
-        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
         kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
-        kv_indices = torch.arange(total_tokens, device="cuda")
+        kv_indices = torch.arange(total_tokens, device=device)
 
         attn_logits = torch.empty(
             (B, H_Q, max_kv_splits, D),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
         attn_lse = torch.empty(
             (B, H_Q, max_kv_splits),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
 
         decode_attention_fwd(
@@ -296,34 +299,34 @@ class TestTritonAttention(CustomTestCase):
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         max_kv_splits = 8
-        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=device)
 
         # q represents the new token being generated, one per batch
-        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
+        q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
 
         # k_buffer and v_buffer represent all previous tokens
-        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
-        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
+        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
+        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device)
 
         # o will have the same shape as q
-        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
-        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
+        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
+        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
 
-        b_seq_len = torch.full((B,), seq_len, device="cuda")
+        b_seq_len = torch.full((B,), seq_len, device=device)
 
-        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
         kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
-        kv_indices = torch.arange(total_tokens, device="cuda")
+        kv_indices = torch.arange(total_tokens, device=device)
 
         attn_logits = torch.empty(
             (B, H_Q, max_kv_splits, D_V),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
         attn_lse = torch.empty(
             (B, H_Q, max_kv_splits),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
 
         decode_attention_fwd_normal(
@@ -343,12 +346,12 @@ class TestTritonAttention(CustomTestCase):
         attn_logits1 = torch.empty(
             (B, H_Q, max_kv_splits, D_V),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
         attn_lse1 = torch.empty(
             (B, H_Q, max_kv_splits, D_V),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
 
         decode_attention_fwd_grouped(