type_infer_annot #88

Open · wants to merge 1 commit into base: main

examples/cuda_matmul_opt.py (30 changes: 10 additions & 20 deletions)
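
The change is mechanical and repeated in every kernel below: the explicit one = arith.constant(1.0, type=dtype) is deleted and the epilogue adds the bare literal 1 instead, with the constant's element type presumably inferred from the surrounding annotations (hence the branch name type_infer_annot). As a generic illustration of the kind of literal promotion that makes an expression like tmp + 1 work, here is a small self-contained Python sketch; the TypedValue class and every name in it are hypothetical and are not the project's actual value wrapper:

from dataclasses import dataclass

# Hypothetical sketch: a value wrapper that knows its element type and promotes
# bare Python literals in arithmetic, mirroring what an explicit
# arith.constant(1.0, type=dtype) would otherwise have to spell out by hand.
@dataclass
class TypedValue:
    value: float
    dtype: str  # e.g. "f32"

    def _promote(self, other):
        # A plain Python int/float becomes a constant of self.dtype.
        if isinstance(other, (int, float)):
            return TypedValue(float(other), self.dtype)
        return other

    def __add__(self, other):
        other = self._promote(other)
        assert other.dtype == self.dtype, "mixed element types"
        return TypedValue(self.value + other.value, self.dtype)


tmp = TypedValue(41.0, "f32")
print((tmp + 1).value)  # 42.0, the literal having been promoted to f32
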
@@ -140,9 +140,7 @@ def sgemm_naive[
B_t: T.memref(K, N, dtype),
C_t: T.memref(M, N, dtype),
](A: A_t, B: B_t, C: C_t):
- one = arith.constant(1.0, type=dtype)
tmp = arith.constant(0, type=dtype)

# this is from the example and it's basically a mistake
# it increments the row for each adjacent thread id
# uncomment the print to see
@@ -154,7 +152,7 @@ def sgemm_naive[
for k, tmp in range_(K, iter_args=[tmp]):
tmp += A[r, k] * B[k, c]
tmp = yield tmp
- C[r, c] = tmp + one
+ C[r, c] = tmp + 1


@gpu.func
@@ -168,7 +166,7 @@ def sgemm_naive_row_order[
B_t: T.memref(K, N, dtype),
C_t: T.memref(M, N, dtype),
](A: A_t, B: B_t, C: C_t):
- one = arith.constant(1.0, type=dtype)

tmp = arith.constant(0, type=dtype)

# increment along the cols (ie preserve row-order access)
@@ -180,7 +178,7 @@ def sgemm_naive_row_order[
for k, tmp in range_(K, iter_args=[tmp]):
tmp += A[r, k] * B[k, c]
tmp = yield tmp
- C[r, c] = tmp + one
+ C[r, c] = tmp + 1


@gpu.func
@@ -202,7 +200,6 @@ def sgemm_coalesce[
c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE)
# gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)

- one = arith.constant(1.0, type=dtype)
tmp = arith.constant(0, type=dtype)

for k, tmp in range_(K, iter_args=[tmp]):
@@ -211,7 +208,7 @@ def sgemm_coalesce[
# because there's enough scratch per SM to prefetch all the data each thread needs?
tmp += A[r, k] * B[k, c]
tmp = yield tmp
- C[r, c] = tmp + one
+ C[r, c] = tmp + 1


# So if you try to load something like:
@@ -266,7 +263,6 @@ def sgemm_coalesce_transpose_B[
r = block_idx.x * BLOCK_SIZE + (tid / BLOCK_SIZE)
c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE)

- one = arith.constant(1.0, type=dtype)
tmp = arith.constant(0, type=dtype)

for k, tmp in range_(K, iter_args=[tmp]):
@@ -275,7 +271,7 @@ def sgemm_coalesce_transpose_B[
# but k now being on the row order dim doesn't help?
tmp += A[r, k] * B[c, k]
tmp = yield tmp
- C[r, c] = tmp + one
+ C[r, c] = tmp + 1


@gpu.func
@@ -333,9 +329,8 @@ def sgemm_shared_mem_block[

tmp = yield tmp

- one = arith.constant(1.0, type=dtype)
C_ = C[c_row : c_row + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE]
- C_[thread_row, thread_col] = tmp + one
+ C_[thread_row, thread_col] = tmp + 1


def prepare_non_tiled_kernel(ctx: MLIRContext, kernel, M, K, N, BLOCK_SIZE=32):
@@ -429,10 +424,9 @@ def sgemm_shared_mem_1d_block_tiling[

gpu.barrier()

- one = arith.constant(1.0, type=dtype)
C_ = C[c_row : c_row + BM, c_col : c_col + BN]
for res_idx in range_(TM):
- C_[thread_row * TM + res_idx, thread_col] = thread_results[res_idx] + one
+ C_[thread_row * TM + res_idx, thread_col] = thread_results[res_idx] + 1


@gpu.func
Expand Down Expand Up @@ -512,13 +506,12 @@ def sgemm_shared_mem_2d_block_tiling[

gpu.barrier()

- one = arith.constant(1.0, type=dtype)
C_ = C[c_row : c_row + BM, c_col : c_col + BN]

for res_idx_m in range_(TM):
for res_idx_n in range_(TN):
C_[thread_row * TM + res_idx_m, thread_col * TN + res_idx_n] = (
- thread_results[res_idx_m, res_idx_n] + one
+ thread_results[res_idx_m, res_idx_n] + 1
)


@@ -612,7 +605,6 @@ def sgemm_shared_mem_2d_block_tiling_vectorize[

gpu.barrier()

- one = arith.constant(1.0, type=dtype)
C_ = C[c_row : c_row + BM, c_col : c_col + BN]

for res_idx_m in range_(TM):
@@ -623,7 +615,7 @@ def sgemm_shared_mem_2d_block_tiling_vectorize[
[thread_row * TM + res_idx_m, thread_col * TN + res_idx_n],
)
for j in range(VECTOR_WIDTH):
- tmp[j] = thread_results[res_idx_m, res_idx_n + j] + one
+ tmp[j] = thread_results[res_idx_m, res_idx_n + j] + 1
vector.store(
tmp, C_, [thread_row * TM + res_idx_m, thread_col * TN + res_idx_n]
)
@@ -771,8 +763,6 @@ def sgemm_warp_tiling[

gpu.barrier()

- one = arith.constant(1.0, type=dtype)

for w_sub_row_idx in range_(WMITER):
for w_sub_col_idx in range_(WNITER):
r = c_row + warp_row * WM + w_sub_row_idx * WSUBM
@@ -794,7 +784,7 @@ def sgemm_warp_tiling[
w_sub_row_idx * TM + res_idx_m,
w_sub_col_idx * TN + res_idx_n + j,
]
- + one
+ + 1
)
vector.store(
tmp,
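
Every kernel in this file computes the same result: C[r, c] is the sum over k of A[r, k] * B[k, c], plus 1 added in the epilogue (the transpose_B variant indexes B as B[c, k], i.e. it expects B stored transposed, but the product is the same). So a buffer copied back from the device can be checked against one NumPy expression regardless of which variant ran. A minimal host-side sketch follows; the helper names and the tolerance are assumptions of mine, not code from this repo:

import numpy as np

def reference_sgemm_plus_one(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    # What every kernel in cuda_matmul_opt.py computes: C = A @ B, then +1 in the epilogue.
    return A @ B + 1.0

def check_result(C_gpu: np.ndarray, A: np.ndarray, B: np.ndarray, atol: float = 1e-3) -> bool:
    # Compare a result copied back from the device against the NumPy reference.
    return np.allclose(C_gpu, reference_sgemm_plus_one(A, B), atol=atol)
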