Skip to content

Commit 752d83c

Browse files
authored
Added compiler hints to enable buffer loads (#729)
1 parent 6acc10c commit 752d83c

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

python/perf-kernels/gemm.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ def matmul_kernel(
7171
"""Kernel for computing the matmul C = A x B.
7272
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
7373
"""
74+
75+
tl.assume(stride_am > 0)
76+
tl.assume(stride_ak > 0)
77+
tl.assume(stride_bk > 0)
78+
tl.assume(stride_bn > 0)
79+
tl.assume(stride_cm > 0)
80+
tl.assume(stride_cn > 0)
81+
7482
# -----------------------------------------------------------
7583
# Map program ids `pid` to the block of C it should compute.
7684
# This is done in a grouped ordering to promote L2 data reuse.
@@ -89,6 +97,9 @@ def matmul_kernel(
8997
pid_m = first_pid_m + (pid % group_size_m)
9098
pid_n = (pid % num_pid_in_group) // group_size_m
9199

100+
tl.assume(pid_m > 0)
101+
tl.assume(pid_n > 0)
102+
92103
# Create pointers for first block of A and B input matrices
93104
offs_k = tl.arange(0, BLOCK_SIZE_K)
94105
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M

0 commit comments

Comments
 (0)