Skip to content

Commit 6e1dafa

Browse files
[release/3.4] "[BACKEND] Workaround for ptxas bug in matrix descriptor arithmetic (triton-lang#7197)" (triton-lang#7389)
The previous code sequence was hitting a bug in ptxas. Emitting this new code sequence should be cheaper and saves us from hitting the ptxas bug. Co-authored-by: Thomas Raoux <[email protected]>
1 parent ae84826 commit 6e1dafa

File tree

1 file changed

+4
-8
lines changed
  • third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM

1 file changed

+4
-8
lines changed

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,10 @@ Value mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::smemLoad(
151151
} else {
152152
off1 = tb.mul(tb.i32_val(elemBits / 8), offset);
153153
}
154-
Value off_ = tb.zext(i64_ty, tb.udiv(off1, tb.i32_val(16)));
155-
156-
Value loadDesc = tb.add(descriptor, off_);
157-
// Add the base at the end to make it easier to do loop invariant code
158-
// motion.
159-
loadDesc = tb.add(
160-
loadDesc, tb.lshr(tb.shl(tb.ptrtoint(i64_ty, base), tb.int_val(64, 46)),
161-
tb.int_val(64, 50)));
154+
Value smemBase = tb.ptrtoint(i32_ty, base);
155+
smemBase = tb.add(smemBase, off1);
156+
smemBase = tb.lshr(tb.and_(smemBase, tb.i32_val(0x3FFFF)), tb.i32_val(4));
157+
Value loadDesc = tb.add(descriptor, tb.zext(i64_ty, smemBase));
162158
return loadDesc;
163159
}
164160

0 commit comments

Comments
 (0)