45 changes: 31 additions & 14 deletions tritonbench/operators/gemm/partition_k.py
@@ -144,7 +144,7 @@ def _matmul_partition_k(
# See above `Pointer Arithmetic` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
- offs_k = (pid_pk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
+ offs_k = (pid_pk * PK_SIZE + tl.arange(0, BLOCK_SIZE_K)) % K
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

@@ -157,13 +157,12 @@ def _matmul_partition_k(
for k in range(0, tl.cdiv(PK_SIZE, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
- # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
- # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
- a = tl.load(a_ptrs)
- b = tl.load(b_ptrs)
+ k_mask = (pid_pk * PK_SIZE + k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) < K
+ a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0)
+ b = tl.load(b_ptrs, mask=k_mask[:, None], other=0.0)
accumulator += tl.dot(a, b)
- a_ptrs += PK_SIZE * stride_ak
- b_ptrs += PK_SIZE * stride_bk
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -195,7 +194,7 @@ def _reduce(
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
- pid_m = pid // num_pid_m
+ pid_m = pid // num_pid_n
pid_n = pid % num_pid_n

offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
@@ -220,17 +219,26 @@ def torch_reduction(c_buf, a):
compiled_reduction = torch.compile(torch_reduction)


- def _matmul_partition_k_impl(a, b, triton_reduce=False):
+ def _matmul_partition_k_impl(a, b, triton_reduce=False, partition_k=None):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
assert b.is_contiguous(), "Matrix B must be contiguous"

- # TODO: Tune on this parameter, currently 32 is best performing
- partitionK = 32

M, K = a.shape
K, N = b.shape

+ # Choose partition size
+ if partition_k is not None:
+ partitionK = partition_k
+ else:
+ # Use 32 partitions by default, only reduce for small K to maintain accuracy
+ partitionK = 32 if K >= 1024 else 8
+
+ # Ensure K is divisible by partitionK
+ while K % partitionK != 0 and partitionK > 1:
+ partitionK -= 1

# Allocates output.
partitionK_SIZE = K // partitionK

@@ -312,5 +320,14 @@ def backward(ctx, grad_output):
return grad_a, grad_b, None


- def matmul_partition_k(a, b, triton_reduce=False):
- return _PartitionKMatmul.apply(a, b, triton_reduce)
+ def matmul_partition_k(a, b, triton_reduce=False, partition_k=None):
+ """Matrix multiplication with partition-K parallelization.
+
+ Args:
+ a: Left input tensor (M, K)
+ b: Right input tensor (K, N)
+ triton_reduce: If True, use Triton kernel for reduction, else use PyTorch
+ partition_k: Number of partitions to split K dimension into.
+ If None, automatically choose based on K dimension.
+ """
+ return _matmul_partition_k_impl(a, b, triton_reduce, partition_k)
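
Taken together, the hunks above fix the K-offset and masking math in the partition-K kernel, correct the reduce kernel's pid_m computation, and expose a partition_k argument on the public entry point. Below is a minimal usage sketch of the updated matmul_partition_k API; the import path is inferred from the file location in this PR, and the shapes, dtypes, and tolerances are illustrative assumptions rather than values taken from the change:

    import torch
    from tritonbench.operators.gemm.partition_k import matmul_partition_k

    # Contiguous fp16 inputs on GPU; K = 4096 is divisible by the default 32 partitions.
    a = torch.randn(256, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 512, device="cuda", dtype=torch.float16)

    # Default behavior: 32 partitions when K >= 1024 (else 8), lowered until K % partitionK == 0.
    c = matmul_partition_k(a, b)

    # Explicit partition count, reducing the partial results with the Triton kernel.
    c_triton = matmul_partition_k(a, b, triton_reduce=True, partition_k=16)

    # Rough sanity check against an fp32 reference (tolerances are illustrative).
    ref = (a.float() @ b.float()).to(c.dtype)
    torch.testing.assert_close(c, ref, rtol=2e-2, atol=2e-2)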