
Commit 92e34af

[CI] Initial integrate sglang kernels to triton-test (#3901)
Initial integration of the SGLang attention kernels test into `triton-test`. The sglang patch will be removed after upstream PR sgl-project/sglang#5278 lands.
1 parent: 5edbad5

File tree

3 files changed: +380 -1 lines changed

.github/workflows/third-party-tests.yml renamed to .github/workflows/ligerkernels-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-name: Third party tests
+name: Third party Liger Kernels tests
 
 on:
   workflow_dispatch:

.github/workflows/sglang-tests.yml

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
name: Third party SGLang tests

on:
  workflow_dispatch:
    inputs:
      runner_label:
        description: Runner label, keep empty for default
        type: string
        default: ""
      use_pyenv_python:
        description: Use Python built with pyenv
        type: boolean
        default: false
  schedule:
    # About midnight PST Sunday (UTC-8)
    - cron: "5 10 * * SUN"


# Cancels in-progress PR runs when the PR is updated. Manual runs are never cancelled.
concurrency:
  group: ${{ github.workflow }}-${{ github.event_name == 'workflow_dispatch' && github.run_id || github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

permissions: read-all

env:
  PYTHON_VERSION: "3.10"
  TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }}

jobs:
  build:
    name: Triton benchmarks
    runs-on:
      - linux
      - ${{ inputs.runner_label || 'rolling' }}
    timeout-minutes: 720
    defaults:
      run:
        shell: bash -noprofile --norc -eo pipefail -c "source /opt/intel/oneapi/setvars.sh > /dev/null; source {0}"
    steps:
      - name: Print inputs
        run: |
          cat <<EOF
          ${{ toJSON(inputs) }}
          EOF

      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install Python
        if: ${{ !(inputs.use_pyenv_python || false) }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install Python (from pyenv) ${{ inputs.python_version }}
        if: ${{ inputs.use_pyenv_python }}
        uses: ./.github/actions/setup-pyenv-python
        with:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Identify Python version
        run: |
          PYTHON_VERSION="$(python -c 'import sys; print(f"{sys.version_info[0]}.{ sys.version_info[1]}")')"
          echo "PYTHON_VERSION=$PYTHON_VERSION" | tee -a $GITHUB_ENV

      - name: Install Python build dependencies
        run: |
          pip install wheel cmake

      - name: Create reports dir
        run: |
          mkdir reports
          echo "REPORTS=$PWD/reports" >> $GITHUB_ENV

      - name: Install SGLang
        id: install-sglang
        run: |
          git clone https://github.com/sgl-project/sglang.git
          cd sglang
          git apply ../benchmarks/third_party/sglang/sglang.patch
          pip install ./python[dev_xpu]

      # Install Pytorch and Triton after SGLANG to ensure that the correct versions are used
      - name: Setup PyTorch
        uses: ./.github/actions/setup-pytorch

      - name: Setup Triton
        uses: ./.github/actions/setup-triton

      - name: Run SGLANG tests
        if: ${{ steps.install.outcome == 'success' && steps.install-sglang.outcome == 'success' && !cancelled() }}
        run: |
          pip install pytest pytest-xdist
          cd sglang
          pytest -vvv -n 4 test/srt/test_triton_attention_kernels.py
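
For reference, the SGLang install and test steps above can be approximated locally with a short shell session. This is a minimal sketch, assuming the repository root as the working directory, a sourced oneAPI environment, and PyTorch/Triton already installed the way the setup-pytorch and setup-triton actions would install them.

# Local sketch of the "Install SGLang" and "Run SGLANG tests" steps above
# (assumptions: repo root as CWD, oneAPI env sourced, PyTorch/Triton already installed).
pip install wheel cmake pytest pytest-xdist
git clone https://github.com/sgl-project/sglang.git
cd sglang
# Temporary patch shipped with this commit; to be dropped once sgl-project/sglang#5278 lands upstream.
git apply ../benchmarks/third_party/sglang/sglang.patch
pip install ./python[dev_xpu]
pytest -vvv -n 4 test/srt/test_triton_attention_kernels.py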
benchmarks/third_party/sglang/sglang.patch

Lines changed: 283 additions & 0 deletions
@@ -0,0 +1,283 @@
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 884e715f..14e5df33 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -77,12 +77,20 @@ from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import (
-    FileCacheManager,
-    default_cache_dir,
-    default_dump_dir,
-    default_override_dir,
-)
+try:
+    from triton.runtime.cache import (
+        FileCacheManager,
+        default_cache_dir,
+        default_dump_dir,
+        default_override_dir,
+    )
+except ImportError:
+    from triton.runtime.cache import FileCacheManager
+    from triton.knobs import cache as tt_cache
+
+    default_cache_dir = lambda: tt_cache.dir
+    default_dump_dir = lambda: tt_cache.dump_dir
+    default_override_dir = lambda: tt_cache.override_dir
 
 logger = logging.getLogger(__name__)
 
@@ -156,6 +164,18 @@ def is_xpu() -> bool:
 def is_npu() -> bool:
     return hasattr(torch, "npu") and torch.npu.is_available()
 
+def infer_device():
+    """
+    Infer the device type based on the current environment.
+    """
+    if is_cuda_alike():
+        return "cuda"
+    elif is_xpu():
+        return "xpu"
+    elif is_hpu():
+        return "hpu"
+    else:
+        return "cpu"
 
 def is_flashinfer_available():
     """
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index 47eb16a9..9d6a0af0 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -16,8 +16,11 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
 from sglang.test.test_utils import CustomTestCase
+from sglang.srt.utils import infer_device
 
 
+device = infer_device()
+
 class TestTritonAttention(CustomTestCase):
 
     def _set_all_seeds(self, seed):
@@ -37,24 +40,24 @@ class TestTritonAttention(CustomTestCase):
         dtype = torch.bfloat16
 
         b_seq_len_prefix = torch.randint(
-            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+            1, N_CTX // 2, (B,), dtype=torch.int32, device=device
         )
         b_seq_len_extend = torch.randint(
-            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+            1, N_CTX // 2, (B,), dtype=torch.int32, device=device
         )
         b_seq_len = b_seq_len_prefix + b_seq_len_extend
         max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
 
-        b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
-        b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_req_idx = torch.arange(B, dtype=torch.int32, device=device)
+        b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
         b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
-        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
         b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
 
-        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
         kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
         kv_indices = torch.zeros(
-            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
+            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device=device
         )
 
         for i in range(B):
@@ -65,15 +68,15 @@ class TestTritonAttention(CustomTestCase):
         total_token_num = torch.sum(b_seq_len).item()
         extend_token_num = torch.sum(b_seq_len_extend).item()
         k_buffer = torch.empty(
-            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+            (total_token_num, H_KV, D), dtype=dtype, device=device
         ).normal_(mean=0.1, std=0.2)
         v_buffer = torch.empty(
-            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+            (total_token_num, H_KV, D), dtype=dtype, device=device
         ).normal_(mean=0.1, std=0.2)
 
-        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
-        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
-        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
+        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
+        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
+        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
         for i in range(B):
             extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
             extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
@@ -86,20 +89,20 @@ class TestTritonAttention(CustomTestCase):
                 extend_start_in_buffer:extend_end_in_buffer
             ]
             q_extend[extend_start:extend_end] = torch.empty(
-                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
+                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device=device
             ).normal_(mean=0.1, std=0.2)
 
-        o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
+        o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
         o_extend_mask = torch.empty(
-            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+            (extend_token_num, H_Q, D), dtype=dtype, device=device
         )
         o_redundant = torch.empty(
-            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+            (extend_token_num, H_Q, D), dtype=dtype, device=device
         )
 
         b_seq_len_extend = b_seq_len - b_seq_len_prefix
         max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
-        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
         qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
 
         custom_mask = None
@@ -123,9 +126,9 @@ class TestTritonAttention(CustomTestCase):
 
             b_seq_mask_len = b_seq_len_extend * b_seq_len
             custom_mask = torch.ones(
-                (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
+                (b_seq_mask_len.sum().item(),), dtype=torch.bool, device=device
             )
-            mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
+            mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device=device)
             mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
             for i in range(B):
                 causal_mask = (
@@ -187,14 +190,14 @@ class TestTritonAttention(CustomTestCase):
         max_seq_len = max(seq_lens)
 
         # Create random input tensors
-        q = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda")
-        k = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda")
-        v = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda")
-        o = torch.zeros(sum(seq_lens), num_heads, head_dim, device="cuda")
+        q = torch.randn(sum(seq_lens), num_heads, head_dim, device=device)
+        k = torch.randn(sum(seq_lens), num_heads, head_dim, device=device)
+        v = torch.randn(sum(seq_lens), num_heads, head_dim, device=device)
+        o = torch.zeros(sum(seq_lens), num_heads, head_dim, device=device)
 
         # Create b_start_loc and b_seq_len tensors
-        b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
-        b_seq_len = torch.tensor(seq_lens, device="cuda")
+        b_start_loc = torch.tensor([0, seq_lens[0]], device=device)
+        b_seq_len = torch.tensor(seq_lens, device=device)
 
         context_attention_fwd(
             q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
@@ -232,33 +235,33 @@ class TestTritonAttention(CustomTestCase):
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         max_kv_splits = 8
-        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=device)
 
         # q represents the new token being generated, one per batch
-        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
+        q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
 
         # k_buffer and v_buffer represent all previous tokens
-        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
-        v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
+        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
+        v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
 
         # o will have the same shape as q
-        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+        o = torch.zeros(B, H_Q, D, dtype=dtype, device=device)
 
-        b_seq_len = torch.full((B,), seq_len, device="cuda")
+        b_seq_len = torch.full((B,), seq_len, device=device)
 
-        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
         kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
-        kv_indices = torch.arange(total_tokens, device="cuda")
+        kv_indices = torch.arange(total_tokens, device=device)
 
         attn_logits = torch.empty(
             (B, H_Q, max_kv_splits, D),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
         attn_lse = torch.empty(
             (B, H_Q, max_kv_splits),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
 
         decode_attention_fwd(
@@ -296,34 +299,34 @@ class TestTritonAttention(CustomTestCase):
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         max_kv_splits = 8
-        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=device)
 
         # q represents the new token being generated, one per batch
-        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
+        q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
 
         # k_buffer and v_buffer represent all previous tokens
-        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
-        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
+        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
+        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device)
 
         # o will have the same shape as q
-        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
-        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
+        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
+        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
 
-        b_seq_len = torch.full((B,), seq_len, device="cuda")
+        b_seq_len = torch.full((B,), seq_len, device=device)
 
-        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
         kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
-        kv_indices = torch.arange(total_tokens, device="cuda")
+        kv_indices = torch.arange(total_tokens, device=device)
 
         attn_logits = torch.empty(
             (B, H_Q, max_kv_splits, D_V),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
        )
         attn_lse = torch.empty(
             (B, H_Q, max_kv_splits),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
        )
 
         decode_attention_fwd_normal(
@@ -343,12 +346,12 @@ class TestTritonAttention(CustomTestCase):
         attn_logits1 = torch.empty(
             (B, H_Q, max_kv_splits, D_V),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
        )
         attn_lse1 = torch.empty(
             (B, H_Q, max_kv_splits, D_V),
             dtype=torch.float32,
-            device="cuda",
+            device=device,
        )
 
         decode_attention_fwd_grouped(