
[GEMM-perf] matmul is slower when one input needs to be transposed #1795

@mgrabban

Description

I find that matmul(X, Y) is ~4X slower when either X or Y needs to be transposed.

So I have a matmul kernel that is similar to the one in the Triton tutorial here.

That kernel is launched from this code:

import torch
import triton

def fused_mul_add(X, Y, b, transpose_x, transpose_y):
    # If an input is logically transposed, swap its strides so the kernel
    # reads the same memory in transposed order instead of copying it.
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    if transpose_y:
        N, _ = Y.shape
        Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
    else:
        _, N = Y.shape
        Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
    # Allocate the output.
    Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
    # 1D launch grid where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel_with_block_pointers[grid](
        X, Y, b, Z,
        M, N, K,
        Xstride0, Xstride1,
        Wstride0, Wstride1,
        Z.stride(0), Z.stride(1),
        BIAS_REQD=b is not None,
    )

    return Z

Note that the strides of X or Y are swapped (e.g. Xstride0, Xstride1 = X.stride(1), X.stride(0)) when that input needs to be transposed.
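For reference, this mirrors how a transposed view behaves in PyTorch: X.t() shares storage with X and simply swaps the strides, so no data is copied. A minimal illustration:

import torch

X = torch.randn(64, 32)
print(X.stride())      # (32, 1): row-major, innermost dim is contiguous
print(X.t().stride())  # (1, 32): the transpose is a view; strides swap, no copy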

I notice that if neither input needs to be transposed, performance is similar to PyTorch's matmul, but when either one needs to be transposed (so that the strides are swapped for that input), performance is 4X slower.
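For context on where the strides enter the picture: a block-pointer kernel like the tutorial one builds its input tiles from the passed strides, so swapping them makes the innermost dimension of the load non-contiguous in memory. A minimal sketch of just that part (the signature and names are illustrative, not the exact kernel from this report):

import triton
import triton.language as tl

@triton.jit
def load_x_tile_sketch(x_ptr, M, K, stride_xm, stride_xk,
                       BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    # The tile's memory layout is determined entirely by the strides:
    # with (stride_xm, stride_xk) = (K, 1) rows are contiguous; after the
    # swap, (1, M), the innermost dimension strides through memory.
    x_block = tl.make_block_ptr(
        base=x_ptr, shape=(M, K), strides=(stride_xm, stride_xk),
        offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    x = tl.load(x_block)  # a real kernel would accumulate with this tile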

This does not happen on CUDA devices, so can you please look into making it efficient for XPU devices as well?
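A minimal timing sketch of the comparison using triton.testing.do_bench (the sizes, dtype, and the 'xpu' device string are assumptions for illustration; fused_mul_add is the wrapper above):

import torch
import triton.testing

M, K, N = 4096, 4096, 4096
X = torch.randn(M, K, device='xpu', dtype=torch.float16)
Xt = torch.randn(K, M, device='xpu', dtype=torch.float16)  # stored transposed
Y = torch.randn(K, N, device='xpu', dtype=torch.float16)

# Baseline: both inputs contiguous, no stride swap.
ms_plain = triton.testing.do_bench(lambda: fused_mul_add(X, Y, None, False, False))
# Same logical product, but X arrives transposed so its strides are swapped.
ms_trans = triton.testing.do_bench(lambda: fused_mul_add(Xt, Y, None, True, False))
print(f"no transpose: {ms_plain:.3f} ms, transposed X: {ms_trans:.3f} ms")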
