
Commit 835497a

Merge branch 'main' into gregory/windows-support
2 parents 97d2441 + c05fe4f

File tree

74 files changed, +4502 -914 lines changed


.github/pins/pytorch-upstream.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-487873f7cafeb0fd390eaefe40496b804bceabbd
+0efa590d435d2b4aefcbad9014dd5fa75dcf8405

.github/workflows/auto-update-translator-cid.yml

Lines changed: 0 additions & 1 deletion

@@ -86,7 +86,6 @@ jobs:
       - name: Search the latest valid Translator cid
         if: ${{ env.TARGET_PRID == null }}
         run: |
-          env
           ./scripts/check-update-translator-cid.sh $CID_LATEST $CID_CURRENT
           if git status --porcelain ./lib/Target/SPIRV/spirv-llvm-translator.conf | grep '^ M'; then
             echo "MODIFIED=true" >> $GITHUB_ENV

.github/workflows/integration-tests.yml

Lines changed: 14 additions & 7 deletions

@@ -239,14 +239,14 @@ jobs:
           cd python
           LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
           if [ ! -d "${LIT_TEST_DIR}" ]; then
-            echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
+            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
           fi
           lit -v "${LIT_TEST_DIR}"
       - name: Run python tests on CUDA
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
           python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py

@@ -268,14 +268,16 @@ jobs:
           language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \
           runtime/test_autotuner.py::test_kwargs[False]\
           ../../tutorials/06-fused-attention.py::test_op --device cpu
+      - name: Run regression tests
+        run: |
+          cd python/test/regression
+          python3 -m pytest -s -n 8 .
       - name: Run C++ unittests
         run: |
           cd python
           cd "build/$(ls build | grep -i cmake)"
           ctest -j32
       - name: Run Proton tests
-        env:
-          LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
         run: |
           cd third_party/proton
           python3 -m pytest -s test

@@ -395,14 +397,14 @@ jobs:
           cd python
           LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
           if [ ! -d "${LIT_TEST_DIR}" ]; then
-            echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
+            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
           fi
           lit -v "${LIT_TEST_DIR}"
       - name: Run python tests on HIP
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit

@@ -416,10 +418,15 @@ jobs:

           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
+      - name: Run regression tests
+        run: |
+          # Reenable test_functional_regression.py once it's fixed
+          cd python/test/regression
+          python3 -m pytest -s -n 8 ./test_cast_matmul.py
       - name: Run Proton tests
         run: |
           cd third_party/proton
-          python3 -m pytest test
+          python3 -m pytest -s test
       - name: Run C++ unittests
         run: |
           cd python

.github/workflows/integration-tests.yml.in

Lines changed: 15 additions & 9 deletions

@@ -272,15 +272,15 @@ jobs:
           cd python
           LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
           if [ ! -d "${LIT_TEST_DIR}" ]; then
-            echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
+            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
           fi
           lit -v "${LIT_TEST_DIR}"

       - name: Run python tests on CUDA
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
           python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py

@@ -304,16 +304,20 @@ jobs:
           runtime/test_autotuner.py::test_kwargs[False]\
           ../../tutorials/06-fused-attention.py::test_op --device cpu

+      - name: Run regression tests
+        run: |
+          cd python/test/regression
+          python3 -m pytest -s -n 8 .
+
       - &run-cpp-unittests-step
         name: Run C++ unittests
         run: |
           cd python
           cd "build/$(ls build | grep -i cmake)"
           ctest -j32

-      - name: Run Proton tests
-        env:
-          LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
+      - &run-proton-tests-step
+        name: Run Proton tests
         run: |
           cd third_party/proton
           python3 -m pytest -s test

@@ -398,7 +402,7 @@ jobs:
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit

@@ -413,11 +417,13 @@ jobs:
           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py

-      - name: Run Proton tests
+      - name: Run regression tests
         run: |
-          cd third_party/proton
-          python3 -m pytest test
+          # Reenable test_functional_regression.py once it's fixed
+          cd python/test/regression
+          python3 -m pytest -s -n 8 ./test_cast_matmul.py

+      - *run-proton-tests-step
       - *run-cpp-unittests-step
       - *save-build-artifacts-step
       - *inspect-cache-directories-step

benchmarks/setup.py

Lines changed: 34 additions & 8 deletions

@@ -125,11 +125,37 @@ def run(self):
         super().run()


-setup(name="triton-kernels-benchmark", packages=[
-    "triton_kernels_benchmark",
-], package_dir={
-    "triton_kernels_benchmark": "triton_kernels_benchmark",
-}, package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]}, cmdclass={
-    "build_ext": build_ext,
-    "clean": clean,
-}, ext_modules=[CMakeExtension("triton_kernels_benchmark")])
+def get_git_commit_hash(length=8):
+    try:
+        cmd = ["git", "rev-parse", f"--short={length}", "HEAD"]
+        return f"+git{subprocess.check_output(cmd).strip().decode('utf-8')}"
+    except (
+            FileNotFoundError,
+            subprocess.CalledProcessError,
+            subprocess.TimeoutExpired,
+    ):
+        return ""
+
+
+setup(
+    name="triton-kernels-benchmark",
+    version="3.1.0" + get_git_commit_hash(),
+    packages=["triton_kernels_benchmark"],
+    install_requires=[
+        "torch",
+        "pandas",
+        "tabulate",
+        "matplotlib",
+    ],
+    package_dir={"triton_kernels_benchmark": "triton_kernels_benchmark"},
+    package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]},
+    cmdclass={
+        "build_ext": build_ext,
+        "clean": clean,
+    },
+    ext_modules=[CMakeExtension("triton_kernels_benchmark")],
+    extra_require={
+        "ipex": ["numpy<=2.0", "intel-extension-for-pytorch=2.1.10"],
+        "pytorch": ["torch>=2.6"],
+    },
+)
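
For reference, a minimal sketch (illustrative, not part of the commit) of what the new version logic produces. It assumes git is on PATH and the working directory is inside a checkout; outside one, the helper falls back to an empty suffix.

# Illustrative sketch of the versioning behaviour added above.
import subprocess

def get_git_commit_hash(length=8):
    try:
        cmd = ["git", "rev-parse", f"--short={length}", "HEAD"]
        return f"+git{subprocess.check_output(cmd).strip().decode('utf-8')}"
    except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
        return ""

# In a git checkout this prints something like "3.1.0+git<8-char-sha>";
# in a source export without git metadata it prints plain "3.1.0".
print("3.1.0" + get_git_commit_hash())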

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 1 addition & 0 deletions

@@ -405,6 +405,7 @@ def serialize_kernel_metadata(arg, args_dict):
     args_dict["shared_memory"] = arg.shared
     args_dict["kernel_name"] = arg.name
     args_dict["spv_name"] = f"{arg.name}.spv"
+    args_dict["build_flags"] = arg.build_flags


 def serialize_args(args, constants, signature):
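
A rough sketch of the per-kernel metadata dict this function now emits; the values below are hypothetical placeholders, only the keys come from the code above.

# Hypothetical example of serialized kernel metadata after this change.
args_dict = {
    "shared_memory": 0,                 # arg.shared
    "kernel_name": "example_kernel",    # arg.name
    "spv_name": "example_kernel.spv",   # f"{arg.name}.spv"
    "build_flags": "",                  # arg.build_flags, newly recorded
}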

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 1 addition & 2 deletions

@@ -153,8 +153,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan
     warmup_time = n_warmup * estimate_ms
     rep_time = n_repeat * estimate_ms

-    times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all",
-                            device_type=device)
+    times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all")
     times = torch.tensor(times, dtype=torch.float)
     return _summarize_statistics(times, quantiles, return_mode)

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 21 additions & 7 deletions

@@ -3,6 +3,7 @@
 import triton.language as tl

 import triton_kernels_benchmark as benchmark_suit
+import xetla_kernel

 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401

@@ -131,9 +132,9 @@ def forward(ctx, a, b, c, acc_dtype=None):
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
-        line_vals=['triton'],
+        line_vals=['triton', 'xetla'],
         # label name for the lines
-        line_names=['Triton'],
+        line_names=['Triton', 'XeTLA'],
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis

@@ -148,23 +149,36 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]

     if provider == 'onednn':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
-                                                              quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-                                                              kernel_name='_kernel')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name='_kernel')
+    elif provider == 'xetla':
+        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+
+        name = f'gemm_splitk_shape_{M}_{K}_{N}'
+        func = getattr(xetla_kernel, name)
+        xetla_fn = lambda: func(a, b, c, acc, cnt)
+        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+
+        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name='split_k_gemm_run')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

     tflops = lambda mean: 2 * M * N * K * (1e-12) / (mean * 1e-3)
     gbps = lambda mean: 2 * (M * K + K * N) + 4.0 * (M * N) * (1e-9) / (mean * 1e-3)

-    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv


 if __name__ == '__main__':
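
The new 'xetla' provider resolves a shape-specialized binding by name, so only the (M, K, N) combinations registered in python_main.cpp are callable. A rough usage sketch follows, assuming an XPU device and the xetla_kernel extension built from this commit; the bf16 input dtypes are an assumption based on the bf16_split_k_gemm binding name.

# Illustrative only: look up and run one of the registered split-k GEMM shapes.
import torch
import xetla_kernel  # the extension module built by benchmarks/setup.py

M, K, N = 512, 32768, 8192  # one of the shapes bound in python_main.cpp
a = torch.randn((M, K), device='xpu', dtype=torch.bfloat16)
b = torch.randn((K, N), device='xpu', dtype=torch.bfloat16)
c = torch.empty((M, N), device='xpu', dtype=torch.float32)    # output
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)  # split-k accumulator
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)    # split-k counters

# Unregistered shapes raise AttributeError here, which is why the benchmark
# keys the lookup on the exact (M, K, N) being measured.
func = getattr(xetla_kernel, f'gemm_splitk_shape_{M}_{K}_{N}')
result = func(a, b, c, acc, cnt)  # returns the accumulator, per bf16_split_k_gemm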

benchmarks/xetla_kernel/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ endif()
 add_subdirectory(softmax)
 add_subdirectory(gemm)
 add_subdirectory(stream_k_gemm)
+add_subdirectory(split_k_gemm)
 add_subdirectory(flash_attention)

 install(TARGETS xetla_kernel LIBRARY DESTINATION .)

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 34 additions & 0 deletions

@@ -2,6 +2,7 @@
 #include "flash_attention/fmha_forward_v5.h"
 #include "gemm/gemm.h"
 #include "softmax/softmax.h"
+#include "split_k_gemm/split_k_gemm.h"
 #include "stream_k_gemm/stream_k_gemm.h"
 #include <CL/sycl.hpp>
 #include <c10/core/ScalarType.h>

@@ -95,6 +96,29 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b,
   return acc;
 }

+template <int m, int k, int n,
+          kslicing_impl_t kslicing_type = kslicing_impl_t::none>
+at::Tensor bf16_split_k_gemm(const at::Tensor &a, const at::Tensor &b,
+                             const at::Tensor &c, const at::Tensor &acc,
+                             const at::Tensor &cnt) {
+  CHECK_INPUT(a);
+  CHECK_INPUT(b);
+  CHECK_INPUT(c);
+  CHECK_INPUT(acc);
+#ifdef USE_IPEX
+  RECORD_FUNCTION("xetla split_k_gemm", {});
+#endif
+
+  auto queue = get_current_sycl_queue();
+  auto evt = split_k_gemm_run<m, k, n, kslicing_type>(
+      a.data_ptr(), b.data_ptr(), c.data_ptr(), acc.data_ptr(), cnt.data_ptr(),
+      queue);
+#ifdef USE_IPEX
+  xpu::profiler_record("xetla kernel", evt);
+#endif
+  return acc;
+}
+
 #define CALL_IMPL_ATTENTION_FWD_FUNC(P)                                        \
   fmha::fmha_forward_impl<P, T, use_mask, IsCausal, use_dropout>(              \
       queue, q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(),        \

@@ -283,6 +307,16 @@ PYBIND11_MODULE(xetla_kernel, m) {
   // gemm stream k
   m.def("gemm_streamk_shape_3072_4096_3072", &bf16_stream_k_gemm,
         "bf16_gemm_streamk (XeTLA)");
+  // gemm split k
+  m.def("gemm_splitk_shape_512_32768_8192",
+        &bf16_split_k_gemm<512, 32768, 8192, kslicing_impl_t::global>,
+        "bf16_gemm_splitk (XeTLA)");
+  m.def("gemm_splitk_shape_1024_28672_8192",
+        &bf16_split_k_gemm<1024, 28672, 8192, kslicing_impl_t::global>,
+        "bf16_gemm_splitk (XeTLA)");
+  m.def("gemm_splitk_shape_3072_4096_3072",
+        &bf16_split_k_gemm<3072, 4096, 3072, kslicing_impl_t::global>,
+        "bf16_gemm_splitk (XeTLA)");
   // flash_attn
   m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
         "flash attn fwd (XeTLA)");
