Commit 9357902
[TEST] Reenable mixed precision dot tests (#4965)
And remove the outdated performance tests. We can also add various float8 types and move `scaled_dot` tests here.
1 parent 6a4be78 commit 9357902
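
As a rough illustration of the float8 follow-up mentioned above, the sketch below shows how float8 inputs could be folded into a parametrized test. The dtype names and the upcast-for-reference step are assumptions about possible future work, not code from this commit.

# Hypothetical follow-up (not in this commit): parametrize over float8 inputs.
# Dtype names follow recent torch builds (>= 2.1); skip if they are absent.
import pytest
import torch

float8_dtypes = ["float8_e4m3fn", "float8_e5m2"]


@pytest.mark.parametrize("x_dtype", float8_dtypes)
def test_float8_upcast_reference(x_dtype):
    if not hasattr(torch, x_dtype):
        pytest.skip(f"this torch build has no {x_dtype}")
    dtype = getattr(torch, x_dtype)
    # Plain torch.matmul generally has no float8 kernels, so a reference
    # result would upcast the float8 operands first.
    x = torch.randn(64, 64).to(dtype)
    ref = x.to(torch.float32) @ x.to(torch.float32).T
    assert ref.shape == (64, 64)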

5 files changed: +106 -336 lines changed

.github/workflows/integration-tests.yml

Lines changed: 14 additions & 7 deletions
@@ -239,14 +239,14 @@ jobs:
           cd python
           LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
           if [ ! -d "${LIT_TEST_DIR}" ]; then
-            echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
+            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
           fi
           lit -v "${LIT_TEST_DIR}"
       - name: Run python tests on CUDA
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
           python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -268,14 +268,16 @@ jobs:
             language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \
             runtime/test_autotuner.py::test_kwargs[False]\
             ../../tutorials/06-fused-attention.py::test_op --device cpu
+      - name: Run regression tests
+        run: |
+          cd python/test/regression
+          python3 -m pytest -s -n 8 .
       - name: Run C++ unittests
         run: |
           cd python
           cd "build/$(ls build | grep -i cmake)"
           ctest -j32
       - name: Run Proton tests
-        env:
-          LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
         run: |
           cd third_party/proton
           python3 -m pytest -s test
@@ -395,14 +397,14 @@ jobs:
           cd python
           LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
           if [ ! -d "${LIT_TEST_DIR}" ]; then
-            echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
+            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
           fi
           lit -v "${LIT_TEST_DIR}"
       - name: Run python tests on HIP
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
@@ -416,10 +418,15 @@ jobs:

           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
+      - name: Run regression tests
+        run: |
+          # Reenable test_functional_regression.py once it's fixed
+          cd python/test/regression
+          python3 -m pytest -s -n 8 ./test_cast_matmul.py
       - name: Run Proton tests
         run: |
           cd third_party/proton
-          python3 -m pytest test
+          python3 -m pytest -s test
       - name: Run C++ unittests
         run: |
           cd python
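
The added CI step can be reproduced locally. A minimal sketch, assuming pytest-xdist is installed (it provides the -n flag) and the script is run from the repository root:

# Run the new regression step the way CI does, but from Python.
# Assumes pytest-xdist is installed and the working directory is the repo root.
import sys

import pytest

exit_code = pytest.main(["-s", "-n", "8", "python/test/regression"])
sys.exit(exit_code)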

.github/workflows/integration-tests.yml.in

Lines changed: 15 additions & 9 deletions
@@ -272,15 +272,15 @@ jobs:
           cd python
           LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
           if [ ! -d "${LIT_TEST_DIR}" ]; then
-            echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
+            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
           fi
           lit -v "${LIT_TEST_DIR}"

       - name: Run python tests on CUDA
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
           python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -304,16 +304,20 @@ jobs:
             runtime/test_autotuner.py::test_kwargs[False]\
             ../../tutorials/06-fused-attention.py::test_op --device cpu

+      - name: Run regression tests
+        run: |
+          cd python/test/regression
+          python3 -m pytest -s -n 8 .
+
       - &run-cpp-unittests-step
         name: Run C++ unittests
         run: |
           cd python
           cd "build/$(ls build | grep -i cmake)"
           ctest -j32

-      - name: Run Proton tests
-        env:
-          LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
+      - &run-proton-tests-step
+        name: Run Proton tests
         run: |
           cd third_party/proton
           python3 -m pytest -s test
@@ -398,7 +402,7 @@
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
@@ -413,11 +417,13 @@
           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py

-      - name: Run Proton tests
+      - name: Run regression tests
         run: |
-          cd third_party/proton
-          python3 -m pytest test
+          # Reenable test_functional_regression.py once it's fixed
+          cd python/test/regression
+          python3 -m pytest -s -n 8 ./test_cast_matmul.py

+      - *run-proton-tests-step
       - *run-cpp-unittests-step
       - *save-build-artifacts-step
       - *inspect-cache-directories-step

python/test/regression/conftest.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import os
+import pytest
+import tempfile
+
+
+def pytest_addoption(parser):
+    parser.addoption("--device", action="store", default="cuda")
+
+
+@pytest.fixture
+def device(request):
+    return request.config.getoption("--device")
+
+
+@pytest.fixture
+def fresh_triton_cache():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        try:
+            os.environ["TRITON_CACHE_DIR"] = tmpdir
+            yield tmpdir
+        finally:
+            os.environ.pop("TRITON_CACHE_DIR", None)
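
For context, a sketch of how a regression test might consume these fixtures; the test and the tiny kernel below are hypothetical, not part of the commit.

# Hypothetical regression test using the `device` and `fresh_triton_cache` fixtures.
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(src, dst, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    tl.store(dst + offs, tl.load(src + offs, mask=mask), mask=mask)


def test_copy_with_fresh_cache(device, fresh_triton_cache):
    # `device` comes from the --device option (default "cuda");
    # `fresh_triton_cache` points TRITON_CACHE_DIR at a temporary directory,
    # so the kernel below is compiled into an empty cache.
    x = torch.arange(1024, dtype=torch.float32, device=device)
    y = torch.empty_like(x)
    _copy_kernel[(triton.cdiv(x.numel(), 128), )](x, y, x.numel(), BLOCK=128)
    torch.testing.assert_close(x, y)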

python/test/regression/test_cast_matmul.py

Lines changed: 55 additions & 53 deletions
@@ -1,20 +1,68 @@
 """
+Mixed precision tests for matmul (tl.dot) with cast (tl.to)
+
 issue: https://github.com/triton-lang/triton/issues/2523
-fused type convert and matmul, base on triton matmul, the different with matmul:
-1. force C's dtype=dot_out_dtype to ["float16", "float32"]
-2. accept A and B with dtype=["float32", "float64"]

+TODO: float8 types
 """
+
 import pytest
 import torch

+import triton
 import triton.language as tl
-from triton import cdiv, jit

-input_dtypes = ["float32", "float64"]
+input_dtypes = ["float16", "float32", "float64"]
 out_dtypes = ["float16", "float32"]


+@triton.jit
+def matmul_kernel(A, B, C, M, N, K,  #
+                  stride_am, stride_ak,  #
+                  stride_bk, stride_bn,  #
+                  stride_cm, stride_cn,  #
+                  dot_out_dtype: tl.constexpr,  #
+                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,  #
+                  BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
+    # matrix multiplication
+    pid = tl.program_id(0)
+    grid_m = tl.cdiv(M, BLOCK_M)
+    grid_n = tl.cdiv(N, BLOCK_N)
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+    # do matrix multiplication
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
+    # pointers
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        k_remaining = K - k * BLOCK_K
+        _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
+        a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
+        b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
+        a = a.to(C.dtype.element_ty)
+        b = b.to(C.dtype.element_ty)
+        acc += tl.dot(a, b, out_dtype=dot_out_dtype)
+        A += BLOCK_K * stride_ak
+        B += BLOCK_K * stride_bk
+    acc = acc.to(C.dtype.element_ty)
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    tl.store(C, acc, mask=mask)
+
+
 @pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype",
                          [(M, K, N, w, x, o)  #
                           for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)]  #
@@ -23,7 +71,7 @@
                           for o in out_dtypes])
 def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
     if x_dtype == w_dtype:
-        pytest.skip("skip same dtype")
+        pytest.skip("skip the same input dtype")
     device = torch.cuda.current_device()
     x_dtype = getattr(torch, x_dtype)
     w_dtype = getattr(torch, w_dtype)
@@ -36,53 +84,7 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):

     # launch kernel
     BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
-    grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1)
-
-    @jit
-    def matmul_kernel(A, B, C, M, N, K,  #
-                      stride_am, stride_ak,  #
-                      stride_bk, stride_bn,  #
-                      stride_cm, stride_cn,  #
-                      dot_out_dtype: tl.constexpr,  #
-                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,  #
-                      BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
-        # matrix multiplication
-        pid = tl.program_id(0)
-        grid_m = tl.cdiv(M, BLOCK_M)
-        grid_n = tl.cdiv(N, BLOCK_N)
-        # re-order program ID for better L2 performance
-        width = GROUP_M * grid_n
-        group_id = pid // width
-        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
-        pid_m = group_id * GROUP_M + (pid % group_size)
-        pid_n = (pid % width) // (group_size)
-        # do matrix multiplication
-        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
-        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
-        rk = tl.arange(0, BLOCK_K)
-        # pointers
-        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
-        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
-        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
-        for k in range(0, tl.cdiv(K, BLOCK_K)):
-            k_remaining = K - k * BLOCK_K
-            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
-            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
-            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
-            a = a.to(C.dtype.element_ty)
-            b = b.to(C.dtype.element_ty)
-            acc += tl.dot(a, b, out_dtype=dot_out_dtype)
-            A += BLOCK_K * stride_ak
-            B += BLOCK_K * stride_bk
-        acc = acc.to(C.dtype.element_ty)
-        # rematerialize rm and rn to save registers
-        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
-        mask = (rm < M)[:, None] & (rn < N)[None, :]
-        tl.store(C, acc, mask=mask)
+    grid = ((triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), 1)

     matmul_kernel[grid](
         a, b, out_triton, M, N, K,  #