Commit 501c933

Add gemm test program
1 parent 744383a commit 501c933

File tree

1 file changed: +298 -0 lines changed

test.py

Lines changed: 298 additions & 0 deletions
@@ -0,0 +1,298 @@
import torch
import triton
import triton.language as tl
from functools import partial

# Target device: 'xpu' (Intel GPU). `backend` resolves to the matching torch
# device module (torch.xpu here), which provides the Event and synchronize
# APIs used for timing below.
device = 'xpu'
backend = getattr(torch, device)


def compute_time(
    fn,
    warmup=1,
    rep=5,
    grad_to_none=None,
    quantiles=None,
    fast_flush=True,
    return_mode="mean",
):
    """
    Benchmark the runtime of the provided function. By default, return the mean runtime of
    :code:`fn`; if :code:`quantiles` is given, return those performance percentiles instead.

    :param fn: Function to benchmark
    :type fn: Callable
    :param warmup: Number of warmup iterations
    :type warmup: int
    :param rep: Number of timed repetitions
    :type rep: int
    :param grad_to_none: Reset the gradients of the provided tensors to None
    :type grad_to_none: list[torch.Tensor], optional
    :param quantiles: Performance percentiles to return instead of the mean
    :type quantiles: list[float], optional
    :param fast_flush: Use faster kernel to flush L2 between measurements
    :type fast_flush: bool
    :param return_mode: Statistic to return: "min", "max", "mean" or "median"
    :type return_mode: str
    """
    assert return_mode in ["min", "max", "mean", "median"]

    backend.synchronize()

    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    if fast_flush:
        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device)
    else:
        cache = torch.empty(int(256e6), dtype=torch.int8, device=device)

    # One start/end event pair per timed repetition
    start_event = [backend.Event(enable_timing=True) for _ in range(rep)]
    end_event = [backend.Event(enable_timing=True) for _ in range(rep)]
    # Warm-up
    for _ in range(warmup):
        fn()
    # Benchmark
    for i in range(rep):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass, so we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                if hasattr(x, 'grad'):
                    x.grad = None
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    backend.synchronize()
    times = torch.tensor(
        [s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float
    )
    if quantiles is not None:
        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
        if len(ret) == 1:
            ret = ret[0]
        return ret
    return getattr(torch, return_mode)(times).item()

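# Usage sketch (added illustration, not part of the original test): `warmup` and
# `rep` are iteration counts, and passing `quantiles` returns percentiles instead
# of the mean, e.g. with a hypothetical tensor `x_demo`:
#
#   x_demo = torch.randn(4096, 4096, device=device, dtype=torch.float16)
#   t_mean_ms = compute_time(lambda: x_demo @ x_demo, warmup=5, rep=20)
#   p20, p50, p80 = compute_time(lambda: x_demo @ x_demo, warmup=5, rep=20,
#                                quantiles=[0.2, 0.5, 0.8])

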
@triton.autotune(
    configs=[
        triton.Config(kwargs={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
        # triton.Config(kwargs={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=3, num_warps=32),
        # triton.Config(kwargs={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, num_warps=32),
        # triton.Config(kwargs={'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, num_warps=32),
        # triton.Config(kwargs={'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}, num_stages=2, num_warps=32),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_with_block_pointers(
        # Pointers to matrices
        a_ptr, b_ptr, bias_ptr, c_ptr,
        # Matrix dimensions
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
        stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
        stride_cm: tl.constexpr, stride_cn: tl.constexpr,
        BIAS_REQD: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N).
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the matrix multiplication tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
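    # Illustration (added): with, say, num_pid_m = 8, num_pid_n = 8 and
    # GROUP_SIZE_M = 4, num_pid_in_group = 32, so pids 0..31 map to
    # (pid_m, pid_n) = (0,0), (1,0), (2,0), (3,0), (0,1), (1,1), ...
    # i.e. a group of 4 row-blocks of C is walked column by column, keeping a
    # few blocks of A and B hot in L2 instead of streaming a whole row of C.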
    #tl.device_print("pid", pid_m)

    # ----------------------------------------------------------
    # Create block pointers for the first blocks of A and B.
    # We will advance these pointers as we move in the K direction and accumulate.
    # See above `Make a Block Pointer` section for details.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(1, 0))

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Load with boundary checks, no need to calculate the mask manually.
        # For better performance, you may remove some axis from the boundary
        # check, if you can guarantee that the access is always in-bound in
        # that axis.
        # See above `Load/Store a Block Pointer` section for details.
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the block pointers to the next K block.
        # See above `Advance a Block Pointer` section for details.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
    c = accumulator.to(tl.float32)
    # add bias to the accumulated block
    if BIAS_REQD:
        offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        bias = tl.load(bias_ptr + offs_yn, mask=offs_yn < N, other=0.0).to(tl.float32)
        c += bias[None, :]
    # ----------------------------------------------------------------
    # Write back the block of the output matrix C with boundary checks.
    # See above `Load/Store a Block Pointer` section for details.
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
    tl.store(c_block_ptr, c.to(tl.float16), boundary_check=(0, 1))

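# Note (added): each program writes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of
# C = A x B, accumulating in fp32 and storing fp16; when BIAS_REQD is set, a
# length-N bias row is broadcast-added to every row of the tile, i.e. roughly
#
#   C_ref = (A.float() @ B.float() + bias.float()[None, :]).to(torch.float16)

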
def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    if transpose_y:
        N, _ = Y.shape
        Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
    else:
        _, N = Y.shape
        Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
    # Allocates output.
    Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )

    matmul_kernel_with_block_pointers[grid](
        X, Y, b, Z,
        M, N, K,
        Xstride0, Xstride1,
        Wstride0, Wstride1,
        Z.stride(0), Z.stride(1),
        BIAS_REQD=b is not None,
    )

    return Z

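# Usage note (added): the transpose flags avoid materializing transposed copies;
# swapping the strides passed to the kernel makes it read the operand as if it
# were transposed, e.g.:
#
#   triton_mm(X, Y, transpose_x=True)            # ~ X.T @ Y
#   triton_mm(X, Y, b=bias, transpose_y=True)    # ~ X @ Y.T + bias  (bias: length-N tensor)

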
M = 1024
K = 5120
N = 4096
dtype = torch.float16
torch.manual_seed(0)

# Each flag enables one operand-layout combination to validate against torch.mm and benchmark.
AxB = True
AxBT = True
ATxB = True
ATxBT = True

if AxB:
    print('Compute A x B')
    X = torch.randn((M, K), device=device, dtype=dtype, requires_grad=False)
    Y = torch.randn((K, N), device=device, dtype=dtype, requires_grad=False)

    fn_tor = partial(torch.mm, X, Y)
    fn_tri = partial(triton_mm, X, Y)

    rtol = 1e-3
    result_tor = fn_tor()
    result_tri = fn_tri()
    if torch.allclose(result_tri, result_tor, atol=1e-2, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        exit("❌ Triton and Torch differ")

    t_tor = compute_time(fn_tor, warmup=5, rep=100)
    t_tri = compute_time(fn_tri, warmup=5, rep=100)
    print(f"Time for torch: {t_tor} ms")
    print(f"Time for triton: {t_tri} ms")


if AxBT:
    torch.manual_seed(0)
    print('Compute A x B.T')
    X = torch.randn((M, K), device=device, dtype=dtype, requires_grad=False)
    Y = torch.randn((N, K), device=device, dtype=dtype, requires_grad=False)

    fn_tor = partial(torch.mm, X, Y.T)
    fn_tri = partial(triton_mm, X, Y, transpose_y=True)

    rtol = 1e-3
    result_tor = fn_tor()
    result_tri = fn_tri()
    if torch.allclose(result_tri, result_tor, atol=1e-2, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        exit("❌ Triton and Torch differ")

    t_tor = compute_time(fn_tor, warmup=5, rep=100)
    t_tri = compute_time(fn_tri, warmup=5, rep=100)
    print(f"Time for torch: {t_tor} ms")
    print(f"Time for triton: {t_tri} ms")

if ATxB:
    torch.manual_seed(0)
    print('Compute A.T x B')
    X = torch.randn((K, M), device=device, dtype=dtype, requires_grad=False)
    Y = torch.randn((K, N), device=device, dtype=dtype, requires_grad=False)

    fn_tor = partial(torch.mm, X.T, Y)
    fn_tri = partial(triton_mm, X, Y, transpose_x=True)

    rtol = 1e-3
    result_tor = fn_tor()
    result_tri = fn_tri()
    if torch.allclose(result_tri, result_tor, atol=1e-2, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        exit("❌ Triton and Torch differ")

    t_tor = compute_time(fn_tor, warmup=5, rep=100)
    t_tri = compute_time(fn_tri, warmup=5, rep=100)
    print(f"Time for torch: {t_tor} ms")
    print(f"Time for triton: {t_tri} ms")

if ATxBT:
    torch.manual_seed(0)
    print('Compute A.T x B.T')
    X = torch.randn((K, M), device=device, dtype=dtype, requires_grad=False)
    Y = torch.randn((N, K), device=device, dtype=dtype, requires_grad=False)

    fn_tor = partial(torch.mm, X.T, Y.T)
    fn_tri = partial(triton_mm, X, Y, transpose_x=True, transpose_y=True)

    rtol = 1e-3
    result_tor = fn_tor()
    result_tri = fn_tri()
    if torch.allclose(result_tri, result_tor, atol=1e-2, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        exit("❌ Triton and Torch differ")

    t_tor = compute_time(fn_tor, warmup=5, rep=100)
    t_tri = compute_time(fn_tri, warmup=5, rep=100)
    print(f"Time for torch: {t_tor} ms")
    print(f"Time for triton: {t_tri} ms")
