
Commit 15f2e14

Merge branch 'main' into lesh/conda-oct
2 parents 70549ea + 4355afd · commit 15f2e14

File tree

4 files changed: +29 −14 lines changed


.pre-commit-config.yaml

Lines changed: 7 additions & 7 deletions

```diff
@@ -22,7 +22,7 @@ repos:
   - id: ruff
     files: '^python/.*'
     args: ["--fix", "--line-length", "120"]
-    stages: [commit, push, manual]
+    stages: [pre-commit, pre-push, manual]
     exclude: |
       (?x)(
         ^python/triton/runtime/.*|
@@ -35,14 +35,14 @@ repos:
   hooks:
   - id: yapf
     args: ["-p", "-i"]
-    stages: [commit, push, manual]
+    stages: [pre-commit, pre-push, manual]
     exclude: "python/test/unit/language/test_line_info.py"
 
 - repo: https://github.com/pre-commit/mirrors-clang-format
   rev: v16.0.6
   hooks:
   - id: clang-format
-    stages: [commit, push, manual]
+    stages: [pre-commit, pre-push, manual]
 
 # Expand YAML anchors in files used by github workflows, because github can't
 # do this itself. This lets us use anchors, which avoids code duplication.
@@ -69,15 +69,15 @@ repos:
   - id: bandit
     files: '^(benchmarks|scripts|third_party/intel)/.*\.py$'
     args: ["-c", "bandit.yaml", "-s", "B404,B603,B607"]
-    stages: [commit, push, manual]
+    stages: [pre-commit, pre-push, manual]
 
 - repo: https://github.com/astral-sh/ruff-pre-commit
   rev: v0.1.3
   hooks:
   - id: ruff
     files: '^(benchmarks|third_party/intel|scripts)/.*'
     args: ["--fix", "--line-length", "120"]
-    stages: [commit, push, manual]
+    stages: [pre-commit, pre-push, manual]
 
 - repo: https://github.com/pycqa/pylint
   rev: v3.2.6
@@ -105,7 +105,7 @@ repos:
     - --disable=too-many-locals
     - --disable=too-many-statements
     - --disable=too-many-arguments
-    stages: [commit, push, manual]
+    stages: [pre-commit, pre-push, manual]
 
   - id: pylint
     name: pylint for benchmarks
@@ -136,7 +136,7 @@ repos:
     - --disable=too-many-statements
     - --disable=too-many-arguments
     - --disable=fixme
-    stages: [commit, push, manual]
+    stages: [pre-commit, pre-push, manual]
 
 
 exclude: |
```
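The `stages` rename follows pre-commit itself: since pre-commit 3.x the legacy stage names `commit` and `push` are deprecated in favor of `pre-commit` and `pre-push`, while `manual` is unchanged, so hooks in that stage still run only via `pre-commit run --hook-stage manual`.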

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 17 additions & 2 deletions

```diff
@@ -10,6 +10,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
+import xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
@@ -253,9 +254,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
-        line_vals=['triton'],
+        line_vals=['triton', 'xetla'],
         # label name for the lines
-        line_names=['Triton'],
+        line_names=['Triton', 'XeTLA'],
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
@@ -281,6 +282,20 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles,
                                                                  kernel_name=['first_wave', 'full_tiles'])
+    elif provider == 'xetla':
+        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+
+        name = f'gemm_streamk_shape_{M}_{K}_{N}'
+        func = getattr(xetla_kernel, name)
+        xetla_fn = lambda: func(a, b, c, acc, cnt)
+        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+
+        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
+            xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
+            kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
```
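The new `xetla` provider resolves its kernel by name at run time, so it only works for shapes that have a matching binding. A minimal standalone sketch of that dispatch, assuming an XPU device and bf16 inputs; the buffer dtypes simply mirror the benchmark code above and are not a documented `xetla_kernel` API:

```python
# Sketch of the getattr-based dispatch added above. Device, dtypes, and the
# extra acc/cnt buffers are assumptions taken from the benchmark code.
import torch
import xetla_kernel

M, K, N = 3072, 4096, 3072  # the only stream-k shape currently exported
a = torch.randn((M, K), device='xpu', dtype=torch.bfloat16)
b = torch.randn((K, N), device='xpu', dtype=torch.bfloat16)
c = torch.empty((M, N), device='xpu', dtype=torch.float32)    # result
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)  # partial-sum accumulator
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)    # stream-k tile counters

func = getattr(xetla_kernel, f'gemm_streamk_shape_{M}_{K}_{N}')  # AttributeError for unbound shapes
func(a, b, c, acc, cnt)
```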

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 4 additions & 1 deletion

```diff
@@ -280,7 +280,10 @@ PYBIND11_MODULE(xetla_kernel, m) {
         &bf16_gemm<Test_4096x8x128x16384_row_row>, "bf16_gemm (XeTLA)");
   m.def("gemm_shape_4096_8_16384_128",
         &bf16_gemm<Test_4096x8x16384x128_row_row>, "bf16_gemm (XeTLA)");
-  // flash_attn_fwd
+  // gemm stream k
+  m.def("gemm_streamk_shape_3072_4096_3072", &bf16_stream_k_gemm,
+        "bf16_gemm_streamk (XeTLA)");
+  // flash_attn
   m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
         "flash attn fwd (XeTLA)");
   m.def("flash_attn_causal_true", &flash_attn<false, true, false>,
```

benchmarks/xetla_kernel/stream_k_gemm/stream_k_gemm.h

Lines changed: 1 addition & 4 deletions

```diff
@@ -36,9 +36,6 @@ sycl::event stream_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc,
   using data_type_c = float;
   using data_type_acc = float;
 
-  auto context = queue.get_info<sycl::info::queue::context>();
-  auto device = queue.get_info<sycl::info::queue::device>();
-
   data_type_a *A = static_cast<data_type_a *>(_A);
   data_type_b *B = static_cast<data_type_b *>(_B);
   data_type_c *C = static_cast<data_type_c *>(_C);
@@ -52,7 +49,7 @@
   constexpr uint32_t sg_tile_k = 32;
 
   // StreamK parameters - xecores available for stream_k dispatch
-  uint32_t avail_xecores = 32;
+  uint32_t avail_xecores = 64;
 
   // Org the compute shape for sub-matrix
   using tile_shape =
```
