We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
other=0.0
tl.load
triton_matmul_kernel
1 parent e2e0112 commit b904705 Copy full SHA for b904705
matmul.py
@@ -93,8 +93,8 @@ def triton_matmul_kernel(
93
94
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
95
for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
96
- lhs = tl.load(lhs_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K, other=0.0)
97
- rhs = tl.load(rhs_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K, other=0.0)
+ lhs = tl.load(lhs_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K)
+ rhs = tl.load(rhs_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K)
98
accumulator = tl.dot(lhs, rhs, accumulator)
99
lhs_ptrs += BLOCK_SIZE_K * lhs_stride_k
100
rhs_ptrs += BLOCK_SIZE_K * rhs_stride_k
0 commit comments