Description
I find that matmul(X, Y) is ~4X slower when either X or Y needs to be transposed.
I have a matmul kernel that is similar to the one in the Triton tutorial here.
That kernel is launched from this code:
def fused_mul_add(X, Y, b, transpose_x, transpose_y):
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    if transpose_y:
        N, _ = Y.shape
        Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
    else:
        _, N = Y.shape
        Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
    # Allocates output.
    Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel_with_block_pointers[grid](
        X, Y, b, Z,
        M, N, K,
        Xstride0, Xstride1,
        Wstride0, Wstride1,
        Z.stride(0), Z.stride(1),
        BIAS_REQD=b is not None,
    )
    return Z
Note that the strides of X or Y are switched (e.g. Xstride0, Xstride1 = X.stride(1), X.stride(0)) if that input needs to be transposed, so no data is actually copied.
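For reference, the kernel follows the same block-pointer structure as the tutorial; this is only a simplified sketch of its shape, not the exact code (the autotune config, block sizes, and bias handling here are placeholders):

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_warps=8)],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_with_block_pointers(
        x_ptr, y_ptr, b_ptr, z_ptr,
        M, N, K,
        stride_xm, stride_xk,
        stride_yk, stride_yn,
        stride_zm, stride_zn,
        BIAS_REQD: tl.constexpr,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    # Map the 1D program id onto a (pid_m, pid_n) output tile.
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    # The block pointers take the strides at face value, so a transposed input
    # is described purely by swapped strides. `order` is only a hint about
    # which axis is contiguous; a fixed (1, 0) may not match a transposed input.
    x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(M, K),
                                    strides=(stride_xm, stride_xk),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    y_block_ptr = tl.make_block_ptr(base=y_ptr, shape=(K, N),
                                    strides=(stride_yk, stride_yn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(1, 0))

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        x = tl.load(x_block_ptr, boundary_check=(0, 1))
        y = tl.load(y_block_ptr, boundary_check=(0, 1))
        acc += tl.dot(x, y)
        x_block_ptr = tl.advance(x_block_ptr, (0, BLOCK_SIZE_K))
        y_block_ptr = tl.advance(y_block_ptr, (BLOCK_SIZE_K, 0))

    if BIAS_REQD:
        # Add a per-column bias, assumed to have shape (N,).
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0)
        acc += bias[None, :].to(tl.float32)

    z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N),
                                    strides=(stride_zm, stride_zn),
                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
                                    order=(1, 0))
    tl.store(z_block_ptr, acc.to(z_ptr.dtype.element_ty), boundary_check=(0, 1))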
I notice that if neither input needs to be transposed, performance is similar to PyTorch's matmul, but when either one needs to be transposed (so that its strides are switched), the kernel is ~4X slower.
This does not happen on CUDA devices, so could you please look into making this case efficient on XPU devices as well?
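A minimal benchmark along these lines can be used to compare the two cases; the shapes, dtype, and device string are just example values, not the exact configuration I measured:

import torch
import triton


def bench(transpose_x, M=4096, N=4096, K=4096, device='xpu', dtype=torch.float16):
    # When transpose_x is True, X is allocated as (K, M) so that only its
    # strides are swapped inside fused_mul_add; no copy is made.
    X = torch.randn((K, M) if transpose_x else (M, K), device=device, dtype=dtype)
    Y = torch.randn((K, N), device=device, dtype=dtype)
    ms = triton.testing.do_bench(lambda: fused_mul_add(X, Y, None, transpose_x, False))
    ref = triton.testing.do_bench(lambda: torch.matmul(X.t() if transpose_x else X, Y))
    print(f"transpose_x={transpose_x}: triton {ms:.3f} ms, torch.matmul {ref:.3f} ms")


bench(transpose_x=False)
bench(transpose_x=True)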