
Commit dca70ac

[BACKEND] Fix indexing into TMEM for lhs mma operands (#6888)
The indexing math was incorrect when accessing TMEM as the lhs operand. This fixes some random failures in the attention tutorial when BLOCK_M=256 is picked.
1 parent 6b06242 commit dca70ac

File tree

3 files changed: +11 -4 lines changed

python/test/unit/language/test_matmul.py

Lines changed: 2 additions & 1 deletion
@@ -537,7 +537,8 @@ def flatten_scale(scale):
         print(f"SWP failed for M = {M}, N = {N}")


-@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128), (64, 128, 32), (128, 256, 32)])
+@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128), (64, 128, 32), (128, 256, 32),
+                                                        (256, 64, 32)])
 @pytest.mark.parametrize("a_trans", [False, True])
 @pytest.mark.parametrize("dtype_src_str", ["float32", "float16", "float8e5"])
 @pytest.mark.skipif(is_hip() or torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
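
The added (256, 64, 32) shape exercises the BLOCK_M=256 path that previously failed. One way to run just that case locally; the -k expression is illustrative (pytest derives test ids from the parametrize values) and not part of the commit:

pytest python/test/unit/language/test_matmul.py -k "256-64-32"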

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ class DotOpMmaV5TmemLoader : public DotOpMmaMemLoader {
   SmallVector<unsigned int> instrShape;
   int numElementsPer32b;
   int numRepM;
+  int numSlicePerBlockN;
 };

 } // namespace NVIDIA

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 8 additions & 3 deletions
@@ -26,6 +26,10 @@ mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::DotOpMmaV5TmemLoader(
   auto ty = cast<MemDescType>(tensor.getType());
   auto tmemEncoding = cast<ttng::TensorMemoryEncodingAttr>(ty.getEncoding());
   unpacked = tmemEncoding.getUnpacked();
+  // When using TMEM to store MMA operands, the TMEM block size may be
+  // smaller than the MMA K block. Therefore we need to adjust the offset
+  // calculation.
+  numSlicePerBlockN = tmemEncoding.getBlockN() / instrShape[1];
   int elTyWidth = ty.getElementTypeBitWidth();
   numElementsPer32b = unpacked ? 1 : 32 / elTyWidth;
   auto shapePerCTA = triton::gpu::getShapePerCTA(ty);
@@ -38,8 +42,9 @@ MemDescOperand mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::tmemLoad(
   if (interleaved || instrShape[0] >= 128)
     numRows = 128;
   int numColPerBlock =
-      ((instrShape[0] * instrShape[1]) / numRows) / numElementsPer32b;
-  int blockId = a + b * numRepM;
+      ((instrShape[0] * numSlicePerBlockN * instrShape[1]) / numRows) /
+      numElementsPer32b;
+  int blockId = a + (b / numSlicePerBlockN) * numRepM;
   int offset;
   if (!interleaved) {
     offset = numColPerBlock * blockId;
@@ -48,7 +53,7 @@ MemDescOperand mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::tmemLoad(
     int blockIdPrevEven = blockId - blockIdIsOdd;
     offset = numColPerBlock * blockIdPrevEven + ((16 * blockIdIsOdd) << 16);
   }
-
+  offset += (b % numSlicePerBlockN) * (instrShape[1] / numElementsPer32b);
   auto tb = TritonLLVMOpBuilder(loc, rewriter);
   Value address = tb.ptrtoint(i32_ty, base);
   return {address, offset};
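
To make the corrected indexing concrete, here is a minimal standalone Python model of the non-interleaved offset computation above. It is a sketch, not Triton code: the function name tmem_offset, the argument names, and the defaults num_elements_per_32b=1 and num_rows=128 are assumptions; it only mirrors the integer arithmetic of tmemLoad.

# Hypothetical standalone model of the corrected non-interleaved
# TMEM offset math from tmemLoad above.
def tmem_offset(a, b, instr_m, instr_n, block_n, num_rep_m,
                num_elements_per_32b=1, num_rows=128):
    # One TMEM block can hold several MMA K-slices along N.
    num_slice_per_block_n = block_n // instr_n
    # Columns occupied by one full TMEM block.
    num_col_per_block = ((instr_m * num_slice_per_block_n * instr_n)
                         // num_rows) // num_elements_per_32b
    # b counts K-slices, so divide it down to the TMEM block index ...
    block_id = a + (b // num_slice_per_block_n) * num_rep_m
    offset = num_col_per_block * block_id
    # ... and use the remainder to step to the right slice inside the block.
    offset += (b % num_slice_per_block_n) * (instr_n // num_elements_per_32b)
    return offset

# When block_n == instr_n (num_slice_per_block_n == 1) this reduces to the
# old formula. With e.g. block_n = 64 and instr_n = 32, consecutive b values
# now step within one block instead of skipping a whole block per step:
print(tmem_offset(a=0, b=0, instr_m=128, instr_n=32, block_n=64, num_rep_m=2))  # 0
print(tmem_offset(a=0, b=1, instr_m=128, instr_n=32, block_n=64, num_rep_m=2))  # 32
print(tmem_offset(a=0, b=2, instr_m=128, instr_n=32, block_n=64, num_rep_m=2))  # 128

The old code effectively assumed num_slice_per_block_n == 1, so for larger TMEM blocks it advanced a full block per K-step and read the wrong columns, which is what surfaced as the BLOCK_M=256 failures mentioned in the commit message.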
