
Commit ebf5409

Add block ptr test for dot product with transpose
1 parent e19e02f commit ebf5409

File tree

1 file changed: +142 -1 lines changed

python/test/unit/intel/test_block_load.py

Lines changed: 142 additions & 1 deletion
@@ -1,8 +1,10 @@
 import pytest
 import torch
 import pathlib
+from functools import partial
 
 import triton
+import triton.language as tl
 from triton._internal_testing import is_xpu
 
 
@@ -74,5 +76,144 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
     kernel = triton.compile(str(temp_file))
 
     kernel[(1, 1, 1)](a, x, b, y)
-    #import pdb; pdb.set_trace()
     assert torch.equal(a, x) and torch.equal(b.T if transpose else b, y)
+
+
+@pytest.mark.parametrize("BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K",
+                         [[256, 256, 32], [256, 64, 32], [64, 256, 32], [64, 128, 32], [64, 64, 32], [32, 32, 32],
+                          [32, 32, 16], [16, 16, 16], [8, 32, 16], [8, 512, 64]])
+@pytest.mark.parametrize("GROUP_SIZE_M", [4, 1])
+@pytest.mark.parametrize("TRANSPOSE_A", [True, False])
+@pytest.mark.parametrize("TRANSPOSE_B", [True, False])
+@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
+@pytest.mark.xfail(
+    not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
+         and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
+    reason="Block loads not supported on this architecture")
+def test_block_load_dot_product(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, TRANSPOSE_A, TRANSPOSE_B,
+                                device):
+    if GROUP_SIZE_M == 1 and (BLOCK_SIZE_M > 64 or BLOCK_SIZE_N > 64):
+        # Skip large block sizes without grouping, as they are too slow.
+        pytest.skip("Skipping slow combinations")
+
+    @triton.jit
+    def matmul_kernel_with_block_pointers(
+            # Pointers to matrices
+            a_ptr, b_ptr,  # bias_ptr (bias support is currently disabled below)
+            c_ptr,
+            # Matrix dimensions
+            M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+            # The stride variables represent how much to increase the pointer by when moving by 1
+            # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+            # by to get to the element one row down (A has M rows).
+            stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
+            stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
+            stride_cm: tl.constexpr, stride_cn: tl.constexpr, BIAS_REQD: tl.constexpr,
+            # Meta-parameters
+            BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+            GROUP_SIZE_M: tl.constexpr):
+        """Kernel for computing the matmul C = A x B.
+        A has shape (M, K), B has shape (K, N) and C has shape (M, N).
+        """
+        # -----------------------------------------------------------
+        # Map program ids `pid` to the block of C it should compute.
+        # This is done in a grouped ordering to promote L2 data reuse.
+        # See the matrix multiplication tutorial for details.
+        pid = tl.program_id(axis=0)
+        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + (pid % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+        # ----------------------------------------------------------
+        # Create block pointers for the first blocks of A and B.
+        # We will advance these pointers as we move in the K direction and accumulate.
+        # See the `Make a Block Pointer` section of the tutorial for details.
+        a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
+                                        offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
+                                        order=(1, 0))
+        b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
+                                        offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
+                                        order=(1, 0))
+
+        # -----------------------------------------------------------
+        # Iterate to compute a block of the C matrix.
+        # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+        # of fp32 values for higher accuracy.
+        # `accumulator` will be converted back to fp16 after the loop.
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, K, BLOCK_SIZE_K):
+            # Load with boundary checks; no need to calculate the mask manually.
+            # For better performance, you may remove an axis from the boundary
+            # check if you can guarantee that the access is always in-bounds
+            # along that axis.
+            # See the `Load/Store a Block Pointer` section of the tutorial for details.
+            a = tl.load(a_block_ptr, boundary_check=(0, 1))
+            b = tl.load(b_block_ptr, boundary_check=(0, 1))
+            # We accumulate along the K dimension.
+            accumulator += tl.dot(a, b)
+            # Advance the block pointers to the next K block.
+            # See the `Advance a Block Pointer` section of the tutorial for details.
+            a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
+            b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        c = accumulator.to(tl.float32)
+
+        # Bias support is stubbed out: the kernel takes no bias pointer yet, so
+        # the addition below stays disabled even when BIAS_REQD is set.
+        # if BIAS_REQD:
+        #     offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        #     bias = tl.load(bias_ptr + offs_yn, mask=offs_yn < N, other=0.0).to(tl.float32)
+        #     c += bias[None, :]
+
+        # ----------------------------------------------------------------
+        # Write back the block of the output matrix C with boundary checks.
+        # See the `Load/Store a Block Pointer` section of the tutorial for details.
+        c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
+                                        offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
+                                        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
+        tl.store(c_block_ptr, c.to(tl.float16), boundary_check=(0, 1))
+
+    def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
+        # Transposition is expressed by swapping the strides passed to the kernel,
+        # so no transposed copy of X or Y is materialized.
+        if transpose_x:
+            K, M = X.shape
+            Xstride0, Xstride1 = X.stride(1), X.stride(0)
+        else:
+            M, K = X.shape
+            Xstride0, Xstride1 = X.stride(0), X.stride(1)
+        if transpose_y:
+            N, _ = Y.shape
+            Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
+        else:
+            _, N = Y.shape
+            Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
+        # Allocate output.
+        Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
+        # 1D launch kernel where each block gets its own program.
+        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+
+        matmul_kernel_with_block_pointers[grid](X, Y, Z, M, N, K, Xstride0, Xstride1, Wstride0, Wstride1, Z.stride(0),
+                                                Z.stride(1), BIAS_REQD=b is not None, BLOCK_SIZE_M=BLOCK_SIZE_M,
+                                                BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
+                                                GROUP_SIZE_M=GROUP_SIZE_M)
+
+        return Z
+
+    M = 512
+    K = 64
+    N = 512
+    dtype = torch.float16
+    torch.manual_seed(0)
+
+    X = torch.randn((M, K) if not TRANSPOSE_A else (K, M), device=device, dtype=dtype, requires_grad=False)
+    Y = torch.randn((K, N) if not TRANSPOSE_B else (N, K), device=device, dtype=dtype, requires_grad=False)
+
+    fn_tor = partial(torch.mm, X if not TRANSPOSE_A else X.T, Y if not TRANSPOSE_B else Y.T)
+    fn_tri = partial(triton_mm, X, Y, transpose_x=TRANSPOSE_A, transpose_y=TRANSPOSE_B)
+
+    rtol = 1e-3
+    result_tor = fn_tor()
+    result_tri = fn_tri()
+    torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=rtol)
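
Note on the kernel above: the grouped program-id mapping is the only non-obvious index math; the rest is a standard block-pointer matmul. The standalone sketch below (not part of the commit) reproduces that mapping in plain Python so the L2-friendly launch order can be printed and inspected. The helper name grouped_tile and the small 4x4 tile grid are illustrative choices, not anything taken from the test file.

import math

def grouped_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    # Same arithmetic as matmul_kernel_with_block_pointers, with tl.cdiv
    # replaced by math.ceil on plain Python ints.
    num_pid_m = math.ceil(M / BLOCK_SIZE_M)
    num_pid_n = math.ceil(N / BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# 4x4 grid of C tiles, groups of 2 tile-rows: consecutive pids walk a
# 2-tile-tall column of C before moving right, keeping the same A rows hot.
for pid in range(16):
    print(pid, grouped_tile(pid, M=256, N=256, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, GROUP_SIZE_M=2))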

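On the transpose handling in triton_mm: rather than materializing X.T or Y.T, the wrapper swaps the row and column strides it passes down to the kernel's block pointers. The short torch-only sketch below (again not part of the commit) shows why that is equivalent: a 2D transpose in PyTorch is a view with the two strides exchanged, so the kernel can read the transposed operand straight from the original storage.

import torch

X = torch.randn(8, 4)
# X.T shares X's storage; only the strides are exchanged.
assert X.T.data_ptr() == X.data_ptr()
assert X.T.stride() == (X.stride(1), X.stride(0))
# This is exactly the (Xstride0, Xstride1) swap performed in triton_mm
# when transpose_x=True.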