Skip to content

Commit 3efde92

Browse files
ThomasRaouxloislo
authored andcommitted
[BACKEND] Fix crash in mmav5 lhs comes from tmem (triton-lang#6011)
1 parent 44a64f0 commit 3efde92

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar
387387
tt.return
388388
}
389389
}
390+
391+
// -----
392+
393+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
394+
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
395+
#smem = #ttg.shared_memory
396+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 32, unpacked = false>
397+
#tmem1 = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
398+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
399+
tt.func @tc_gen5_mma_lhs_tmem(%arg0: !ttg.memdesc<128x32xf16, #tmem, #ttng.tensor_memory>, %arg1: !ttg.memdesc<32x128xf16, #shared, #smem>, %arg2: !ttg.memdesc<128x128xf32, #tmem1, #ttng.tensor_memory, mutable>, %arg3: i1, %arg4: i1, %arg5: !ttg.memdesc<1xi64, #shared1, #smem>) {
400+
// CHECK-LABEL: tc_gen5_mma_lhs_tmem
401+
// CHECK: tcgen05.mma.cta_group::1.kind::f16
402+
ttng.tc_gen5_mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : (
403+
!ttg.memdesc<128x32xf16, #tmem, #ttng.tensor_memory>,
404+
!ttg.memdesc<32x128xf16, #shared, #smem>,
405+
!ttg.memdesc<128x128xf32, #tmem1, #ttng.tensor_memory, mutable>,
406+
i1, i1, !ttg.memdesc<1xi64, #shared1, #smem>) -> ()
407+
tt.return
408+
}
409+
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -347,11 +347,16 @@ void convertDot(const LLVMTypeConverter *typeConverter,
347347
}
348348
auto bSharedLayout = cast<NVMMASharedEncodingAttr>(bTensorTy.getEncoding());
349349
bool transB = !bSharedLayout.getTransposed();
350-
Value baseA =
351-
getSharedMemoryObjectFromStruct(
352-
loc, loadedA, typeConverter->convertType(aTensorTy.getElementType()),
353-
rewriter)
354-
.getBase();
350+
Value baseA;
351+
if (aInTmem) {
352+
baseA = loadedA;
353+
} else {
354+
baseA =
355+
getSharedMemoryObjectFromStruct(
356+
loc, loadedA,
357+
typeConverter->convertType(aTensorTy.getElementType()), rewriter)
358+
.getBase();
359+
}
355360
Value baseB =
356361
getSharedMemoryObjectFromStruct(
357362
loc, loadedB, typeConverter->convertType(bTensorTy.getElementType()),

0 commit comments

Comments
 (0)