Commit b73b3db

[GEMM] Add the tensor of pointer benchmark (#3705)
Closes #3633
1 parent b424431 commit b73b3db

File tree: 3 files changed, +378 -0 lines changed


.github/workflows/triton-benchmarks.yml

Lines changed: 10 additions & 0 deletions

@@ -160,6 +160,16 @@ jobs:
           TAG="${TAG}-adv"
           python ../../scripts/build_report.py $REPORTS/matmul-performance-adv-path.csv $REPORTS/gemm-triton-advanced-report.csv --benchmark gemm --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
 
+      - name: Run Triton GEMM kernel benchmark - with tensor of pointer
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_tensor_of_ptr_benchmark.py') }}
+        run: |
+          cd benchmarks/triton_kernels_benchmark
+          python gemm_tensor_of_ptr_benchmark.py --reports $REPORTS
+          source ../../scripts/capture-hw-details.sh
+          python ../../scripts/build_report.py $REPORTS/matmul-tensor-of-ptr-performance.csv $REPORTS/gemm-tensor-of-ptr-triton-report.csv --benchmark gemm-tensor-of-ptr --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python ../../scripts/build_report.py $REPORTS/matmul-tensor-of-ptr-performance.csv $REPORTS/gemm-tensor-of-ptr-xetla-report.csv --benchmark gemm-tensor-of-ptr --compiler xetla --param_cols "B,M,K,N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
+          python ../../scripts/build_report.py $REPORTS/matmul-tensor-of-ptr-performance.csv $REPORTS/gemm-tensor-of-ptr-onednn-report.csv --benchmark gemm-tensor-of-ptr --compiler onednn --param_cols "B,M,K,N" --tflops_col OneDNN-TFlops --hbm_col "OneDNN-GB/s" --tag $TAG
+
       - name: Run Triton GEMM (A@B^t) kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py_abt') }}
         run: |
benchmarks/triton_kernels_benchmark/gemm_tensor_of_ptr_benchmark.py

Lines changed: 365 additions & 0 deletions

@@ -0,0 +1,365 @@
"""
Gemm benchmark (tensor of pointer)
==================================

This benchmark comes from the Triton tutorial 03-matrix-multiplication.py (commit: 3f4fdd1)
and compares the performance against the XeTLA kernel.

"""
import os

import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
from triton_kernels_benchmark import xetla_kernel

TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1'
TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1'
use_xetla = not (TRANSPOSE_A or TRANSPOSE_B)
SMALL_GRF = os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0'

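# Usage sketch (assuming an Intel XPU device and the triton_kernels_benchmark
# package are available): the transpose variants are selected through the
# environment before this module is imported, e.g.
#     TRANSPOSE_B=1 python gemm_tensor_of_ptr_benchmark.py --reports $REPORTS
# benchmarks A @ B^T. XeTLA numbers are only collected for the plain A @ B case,
# since `use_xetla` is False whenever either transpose flag is set.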
@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [1, 2, 3]
    ] + [
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
                      num_stages=s, num_warps=w)
        for s in [2, 3, 4]
        for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2]
    ] + [
        triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
                      num_stages=s, num_warps=w)
        for s in [2, 3]
        for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        # Stride variables
        stride_am: tl.constexpr, stride_ak: tl.constexpr, #
        stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
        stride_cm: tl.constexpr, stride_cn: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * \
        offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

# pylint: disable=unused-argument
@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2, 3]
    ] + [
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
                      num_stages=s, num_warps=w)
        for s in [2]
        for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2, 3]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
            num_stages=s, num_warps=32) for s in [2]
    ] + [
        triton.Config(
            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
            num_stages=s, num_warps=4) for s in [2]
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_batched(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        B: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        # Stride variables
        stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, #
        stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
        stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    bid = tl.program_id(axis=1)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offset_a = bid.to(tl.int64) * stride_az
    offset_b = bid.to(tl.int64) * stride_bz
    a_ptrs = a_ptr + offset_a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + offset_b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offset_c = bid.to(tl.int64) * stride_cz
    c_ptrs = c_ptr + offset_c + stride_cm * \
        offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# A convenience wrapper that takes the input tensors and the output tensor,
# (1) checks the shape constraints, and (2) launches the appropriate kernel.
def matmul(a, b, c, transpose_a=False, transpose_b=False):
    a_major, a_minor = -2, -1
    if transpose_a:
        a_major, a_minor = a_minor, a_major
    b_minor, b_major = -2, -1
    if transpose_b:
        b_major, b_minor = b_minor, b_major

    assert a.shape[a_minor] == b.shape[b_minor], 'Incompatible dimensions'
    assert a.is_contiguous(), 'Matrix A must be contiguous'
    assert b.is_contiguous(), 'Matrix B must be contiguous'
    M, N, K = a.shape[a_major], b.shape[b_major], a.shape[a_minor]
    # Check constraints.
    if len(a.shape) == 3 and len(b.shape) == 3:
        assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
        B = a.shape[0]
        # 1D launch kernel where each block gets its own program.

        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
            B,
        )
        matmul_kernel_batched[grid](
            a, b, c, #
            B, M, N, K, #
            a.stride(0), a.stride(a_major), a.stride(a_minor), #
            b.stride(0), b.stride(b_minor), b.stride(b_major), #
            c.stride(0), c.stride(1), c.stride(2))
    elif len(a.shape) == 2 and len(b.shape) == 2:
        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
        matmul_kernel[grid](
            a, b, c, #
            M, N, K, #
            a.stride(a_major), a.stride(a_minor), #
            b.stride(b_minor), b.stride(b_major), #
            c.stride(0), c.stride(1))
    else:
        assert False, 'Input matrix dimensions mismatch'
    return c

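# Example (a sketch mirroring how `benchmark` below calls the wrapper):
#     a = torch.rand((4096, 4096), device='xpu', dtype=torch.bfloat16)
#     b = torch.rand((4096, 4096), device='xpu', dtype=torch.bfloat16)
#     c = torch.zeros((4096, 4096), device='xpu', dtype=torch.float32)
#     matmul(a, b, c)                      # computes a @ b into c
#     matmul(a, b, c, transpose_b=True)    # treats b as N x K, i.e. computes a @ b^T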
def get_shapes(B, M, N, K, transpose_a, transpose_b):
    a_shape = (M, K)
    if transpose_a:
        a_shape = (K, M)

    b_shape = (K, N)
    if transpose_b:
        b_shape = (N, K)

    if B != 1:
        a_shape = (B, *a_shape)
        b_shape = (B, *b_shape)
    return a_shape, b_shape


X_VALS = [[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
    [1, 1, 13824, 5120],
    [1, 4, 12288, 4096],
    [1, 512, 8192, 8192],
    [1, 512, 8192, 32768],
    [1, 512, 32768, 8192],
    [1, 1024, 8192, 16384],
    [1, 1024, 8192, 28672],
    [1, 3072, 3072, 4096],  # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
    [1, 4096, 8192, 16384],
    [1, 8192, 1024, 16384],
    [1, 8192, 4096, 16384],
    [1, 16384, 1024, 8192],
    [1, 16384, 4096, 8192],
    [1, 16384, 8192, 1024],
    [1, 16384, 8192, 4096],
    [4, 32768, 128, 4096],
    [4, 32768, 4096, 128],
    [32, 4096, 128, 4096],
    [4096, 8, 128, 16384],
    [4096, 8, 16384, 128],
]

DEVICE_NAME = torch.xpu.get_device_name()
DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory


def is_enough_memory(x_val):
    # x_val: (B, M, N, K)
    B, M, N, K = x_val
    # a: (B, M, K) bfloat16
    # b: (B, N, K) bfloat16
    # c: (B, M, N) float32
    # pytorch reference: (B, M, N) float32
    required_memory = B * M * K * 2 + B * N * K * 2 + 2 * B * M * N * 4
    enough_memory = required_memory < DEVICE_TOTAL_MEMORY
    if not enough_memory:
        print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}")
    return enough_memory


X_VALS = [x_val for x_val in X_VALS if is_enough_memory(x_val)]

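# Worked example of the estimate above: for (B, M, N, K) = (1, 8192, 8192, 8192),
# a and b each take 8192 * 8192 * 2 bytes = 128 MiB in bfloat16, and the float32
# output plus the PyTorch reference take 2 * 8192 * 8192 * 4 bytes = 512 MiB,
# so the shape needs roughly 768 MiB and is kept on any device with more memory.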
# Benchmark Performance
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['B', 'M', 'N', 'K'],
        # different possible values for `x_name`
        x_vals=X_VALS,
        line_arg='provider',
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg`
        line_vals=['triton', 'onednn'] + (['xetla'] if use_xetla else []),
        # label name for the lines
        line_names=['Triton', 'OneDNN'] + (['XeTLA'] if use_xetla else []),
        # line styles
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
        plot_name='matmul-tensor-of-ptr-performance',
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
def benchmark(B, M, N, K, provider):
    a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)

    torch.manual_seed(0)
    a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
    b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)

    quantiles = [0.5, 0.0, 1.0]

    torch_a = a
    if TRANSPOSE_A:
        torch_a = torch.transpose(torch_a, -2, -1)

    torch_b = b
    if TRANSPOSE_B:
        torch_b = torch.transpose(torch_b, -2, -1)

    if provider == 'onednn':
        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10,
                                                                 n_repeat=10, quantiles=quantiles)
    elif provider == 'triton':
        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
        if len(a.shape) == 3:
            c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
        else:
            assert len(a.shape) == 2, 'Expecting shape of length 2'
            c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
        triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
        torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
        rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                 quantiles=quantiles)
    elif provider == 'xetla':
        if B == 1:
            c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
            cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
        else:
            c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
            cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32)
        name = f'gemm_shape_{B}_{M}_{K}_{N}'
        # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
        # better performance.
        if (B, M, N, K) == (1, 3072, 3072, 4096):
            name = 'gemm_streamk_shape_3072_4096_3072'
        func = getattr(xetla_kernel, name)

        def xetla_func_with_acc_allocation():
            # Allocate the `acc` matrix on every function call, to be as similar as
            # possible to the Triton kernel, which also does this on every call.
            if B == 1:
                acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
            else:
                acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
            return func(a, b, c, acc, cnt)

        xetla_fn = xetla_func_with_acc_allocation
        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

        # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
                                                                 quantiles=quantiles)
    else:
        raise NotImplementedError(f'Unsupported provider {provider}')

    tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3)
    gbps = lambda ms: B * (2 * (M * K + K * N) + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)

    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv


if __name__ == '__main__':
    benchmark.run(show_plots=False, print_data=True)
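
As a quick sanity check on the reported metrics, the two formulas at the end of `benchmark` can be evaluated by hand. The sketch below reuses the same expressions with an illustrative (not measured) 1.0 ms timing for a 4096 x 4096 x 4096 GEMM:

    B, M, N, K = 1, 4096, 4096, 4096
    mean_ms = 1.0  # illustrative timing, not a measured value
    tflops = 2 * B * M * N * K * 1e-12 / (mean_ms * 1e-3)  # ~137.4 TFLOPS
    gbps = B * (2 * (M * K + K * N) + 4.0 * (M * N)) * 1e-9 / (mean_ms * 1e-3)  # ~134.2 GB/s
    print(f'{tflops:.1f} TFLOPS, {gbps:.1f} GB/s')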

scripts/test-triton.sh

Lines changed: 3 additions & 0 deletions

@@ -291,6 +291,9 @@ run_benchmark_gemm() {
   IGC_VISAOptions=" -enableBCR -nolocalra" \
   IGC_DisableLoopUnroll=1 \
   python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+
+  echo "GEMM with tensor of pointer:"
+  python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/gemm_tensor_of_ptr_benchmark.py
 }
 
 run_benchmark_attention() {
