Skip to content

Commit b904705

Browse files
committed
Remove the other=0.0 argument from the tl.load calls in the triton_matmul_kernel function
1 parent e2e0112 commit b904705

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def triton_matmul_kernel(
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
-        lhs = tl.load(lhs_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K, other=0.0)
-        rhs = tl.load(rhs_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K, other=0.0)
+        lhs = tl.load(lhs_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K)
+        rhs = tl.load(rhs_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K)
         accumulator = tl.dot(lhs, rhs, accumulator)
         lhs_ptrs += BLOCK_SIZE_K * lhs_stride_k
         rhs_ptrs += BLOCK_SIZE_K * rhs_stride_k

0 commit comments

Comments
 (0)