Skip to content

Commit 7c9476e

Browse files
authored
Fix SGLang third party test (#4652)
Resolve the conflicts when applying sglang patch.
1 parent ed71f09 commit 7c9476e

File tree

3 files changed

+59
-45
lines changed

3 files changed

+59
-45
lines changed

.github/workflows/sglang-tests.yml

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,27 @@ jobs:
7474
echo "REPORTS=$PWD/reports" >> $GITHUB_ENV
7575
7676
- name: Install SGLang
77-
id: install-sglang
77+
id: install
7878
run: |
7979
git clone https://github.com/sgl-project/sglang.git
8080
cd sglang
81-
git apply ../benchmarks/third_party/sglang/sglang.patch
82-
pip install ./python[dev_xpu]
81+
git apply ../benchmarks/third_party/sglang/sglang-fix.patch
82+
pip install "./python[dev_xpu]"
8383
84-
# Install Pytorch and Triton after SGLANG to ensure that the correct versions are used
8584
- name: Setup PyTorch
8685
uses: ./.github/actions/setup-pytorch
8786

8887
- name: Setup Triton
8988
uses: ./.github/actions/setup-triton
9089

9190
- name: Run SGLANG tests
92-
if: ${{ steps.install.outcome == 'success' && steps.install-sglang.outcome == 'success' && !cancelled() }}
91+
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
9392
run: |
94-
pip install pytest pytest-xdist
95-
cd sglang
96-
pytest -vvv -n 4 test/srt/test_triton_attention_kernels.py
93+
./scripts/test-triton.sh --sglang --skip-pip-install --skip-pytorch-install
94+
95+
- name: Upload test report
96+
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
97+
uses: actions/upload-artifact@v4
98+
with:
99+
name: test-reports
100+
path: reports

benchmarks/third_party/sglang/sglang.patch renamed to benchmarks/third_party/sglang/sglang-fix.patch

Lines changed: 14 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,12 @@
11
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
2-
index 884e715f..14e5df33 100644
2+
index bc2affa1..8ef91e66 100644
33
--- a/python/sglang/srt/utils.py
44
+++ b/python/sglang/srt/utils.py
5-
@@ -77,12 +77,20 @@ from torch.func import functional_call
6-
from torch.library import Library
7-
from torch.profiler import ProfilerActivity, profile, record_function
8-
from torch.utils._contextlib import _DecoratorContextManager
9-
-from triton.runtime.cache import (
10-
- FileCacheManager,
11-
- default_cache_dir,
12-
- default_dump_dir,
13-
- default_override_dir,
14-
-)
15-
+try:
16-
+ from triton.runtime.cache import (
17-
+ FileCacheManager,
18-
+ default_cache_dir,
19-
+ default_dump_dir,
20-
+ default_override_dir,
21-
+ )
22-
+except ImportError:
23-
+ from triton.runtime.cache import FileCacheManager
24-
+ from triton.knobs import cache as tt_cache
25-
+
26-
+ default_cache_dir = lambda: tt_cache.dir
27-
+ default_dump_dir = lambda: tt_cache.dump_dir
28-
+ default_override_dir = lambda: tt_cache.override_dir
29-
30-
logger = logging.getLogger(__name__)
5+
@@ -228,6 +228,22 @@ def is_flashinfer_available():
6+
return importlib.util.find_spec("flashinfer") is not None and is_cuda()
317

32-
@@ -156,6 +164,18 @@ def is_xpu() -> bool:
33-
def is_npu() -> bool:
34-
return hasattr(torch, "npu") and torch.npu.is_available()
358

36-
+def infer_device():
9+
+def auto_detect_device():
3710
+ """
3811
+ Infer the device type based on the current environment.
3912
+ """
@@ -43,23 +16,27 @@ index 884e715f..14e5df33 100644
4316
+ return "xpu"
4417
+ elif is_hpu():
4518
+ return "hpu"
19+
+ elif is_npu():
20+
+ return "npu"
4621
+ else:
4722
+ return "cpu"
48-
49-
def is_flashinfer_available():
50-
"""
23+
+
24+
+
25+
_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
26+
"SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
27+
)
5128
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
52-
index 47eb16a9..9d6a0af0 100644
29+
index 47eb16a9..cce70fb9 100644
5330
--- a/test/srt/test_triton_attention_kernels.py
5431
+++ b/test/srt/test_triton_attention_kernels.py
5532
@@ -16,8 +16,11 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
5633
context_attention_fwd,
5734
)
5835
from sglang.test.test_utils import CustomTestCase
59-
+from sglang.srt.utils import infer_device
36+
+from sglang.srt.utils import auto_detect_device
6037

6138

62-
+device = infer_device()
39+
+device = auto_detect_device()
6340
+
6441
class TestTritonAttention(CustomTestCase):
6542

scripts/test-triton.sh

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ TEST:
2020
--attention
2121
--instrumentation
2222
--inductor
23+
--sglang
2324
2425
OPTION:
2526
--unskip
@@ -57,6 +58,7 @@ TEST_BENCHMARK_GEMM=false
5758
TEST_BENCHMARK_ATTENTION=false
5859
TEST_INSTRUMENTATION=false
5960
TEST_INDUCTOR=false
61+
TEST_SGLANG=false
6062
VENV=false
6163
TRITON_TEST_REPORTS=false
6264
TRITON_TEST_WARNING_REPORTS=false
@@ -141,6 +143,11 @@ while (( $# != 0 )); do
141143
TEST_DEFAULT=false
142144
shift
143145
;;
146+
--sglang)
147+
TEST_SGLANG=true
148+
TEST_DEFAULT=false
149+
shift
150+
;;
144151
--venv)
145152
VENV=true
146153
shift
@@ -470,6 +477,29 @@ run_inductor_tests() {
470477
grep AlbertForMaskedLM inductor_log.csv | grep -q ,pass,
471478
}
472479

480+
#######################################
# Run the SGLang Triton attention kernel tests.
#
# Clones SGLang into ./sglang when that directory does not exist, changes
# into it and, when the sglang package is not installed yet, applies the
# local fix patch, installs SGLang with the dev_xpu extras and then
# reinstalls PyTorch and Triton. Finally installs pytest/pytest-xdist and
# runs test/srt/test_triton_attention_kernels.py via run_pytest_command.
#
# Globals:
#   TRITON_PROJ          (read) root that contains
#                        benchmarks/third_party/sglang/sglang-fix.patch
#   SCRIPTS_DIR          (read) directory holding install-pytorch.sh and
#                        compile-triton.sh
#   PYTEST_MAX_PROCESSES (read) value passed to pytest -n, defaults to 4
# Side effects:
#   Leaves the current working directory set to ./sglang.
# Returns:
#   The status of the first failing setup step, otherwise the status of
#   run_pytest_command.
#######################################
run_sglang_tests() {
  echo "***************************************************"
  echo "****** Running SGLang Triton tests ******"
  echo "***************************************************"

  local patch_file="$TRITON_PROJ/benchmarks/third_party/sglang/sglang-fix.patch"

  if ! [[ -d "./sglang" ]]; then
    git clone https://github.com/sgl-project/sglang.git || return
  fi
  cd sglang || return

  # Ask pip about the exact package: grepping `pip list` for "sglang" also
  # matches unrelated packages whose name merely contains that substring,
  # which would wrongly skip the installation.
  if ! pip show sglang > /dev/null 2>&1; then
    # A previous, interrupted run may have patched the checkout already;
    # applying the patch a second time would fail, so detect that case by
    # checking whether the patch can be cleanly reversed.
    if git apply --reverse --check "$patch_file" > /dev/null 2>&1; then
      echo "SGLang patch is already applied, skipping git apply"
    else
      git apply "$patch_file" || return
    fi
    pip install "./python[dev_xpu]" || return

    # SGLang installation breaks the default PyTorch and Triton versions, so we need to reinstall them.
    "$SCRIPTS_DIR/install-pytorch.sh" --force-reinstall || return
    "$SCRIPTS_DIR/compile-triton.sh" --triton || return
  fi

  pip install pytest pytest-xdist || return
  run_pytest_command -vvv -n "${PYTEST_MAX_PROCESSES:-4}" test/srt/test_triton_attention_kernels.py
}
502+
473503
test_triton() {
474504
if [ "$TEST_UNIT" = true ]; then
475505
run_unit_tests
@@ -517,6 +547,9 @@ test_triton() {
517547
if [ "$TEST_INDUCTOR" == true ]; then
518548
run_inductor_tests
519549
fi
550+
if [ "$TEST_SGLANG" == true ]; then
551+
run_sglang_tests
552+
fi
520553
}
521554

522555
install_deps

0 commit comments

Comments
 (0)