Skip to content

Commit b3b1b13

Browse files
authored
[Blackwell] Fix codegen for tmem_load of Nx1xf32 (#7234)
LLVM's `inline_asm` is only allowed to return a struct when the asm has more than one result. This change also makes unpacked `Nx2xf16` work, but `Nx1xf16` still crashes; it can be supported if it is ever needed.
1 parent e7cef2e commit b3b1b13

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ tt.func public @tmem_message_maxnreg_80(%desc: !ttg.memdesc<128x64xf32, #tmem, #
495495
tt.return
496496
}
497497

498+
// CHECK-LABEL: @module_constraint_supercedes_local
498499
tt.func public @module_constraint_supercedes_local(%desc: !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory>) {
499500
ttg.warp_specialize(%desc) attributes {actualRegisters = array<i32: 256, 256>}
500501
default {
@@ -611,6 +612,10 @@ tt.func private @reinterpret(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_me
611612

612613
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = false>
613614
#tmem_unpacked = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
615+
#tmem_x1 = #ttng.tensor_memory_encoding<blockM = 128, blockN = 1, unpacked = false>
616+
#tmem_x1_unpacked = #ttng.tensor_memory_encoding<blockM = 128, blockN = 2, unpacked = true>
617+
618+
#blocked_x1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
614619

615620
module attributes {"ttg.num-warps" = 4 : i32} {
616621

@@ -633,4 +638,29 @@ tt.func private @subslice_packed(%arg0: !ttg.memdesc<128x128xf16, #tmem, #ttng.t
633638
tt.return %0 : !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory>
634639
}
635640

641+
// CHECK-LABEL: @load_store_x1
642+
tt.func @load_store_x1(%arg0: !ttg.memdesc<128x1xf32, #tmem_x1, #ttng.tensor_memory, mutable>) {
643+
%true = arith.constant true
644+
// CHECK: [[V:%.*]] = llvm.inline_asm {{.*}}tcgen05.ld.sync{{.*}} (i32) -> i32
645+
// CHECK: [[F:%.*]] = llvm.bitcast [[V]] : i32 to f32
646+
// CHECK: insertvalue [[F]], {{.*}} : !llvm.struct<(f32)>
647+
%0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x1xf32, #tmem_x1, #ttng.tensor_memory, mutable> -> tensor<128x1xf32, #blocked_x1>
648+
ttng.tmem_store %0, %arg0, %true : tensor<128x1xf32, #blocked_x1> -> !ttg.memdesc<128x1xf32, #tmem_x1, #ttng.tensor_memory, mutable>
649+
tt.return
650+
}
651+
652+
// CHECK-LABEL: @load_store_x1_unpacked
653+
tt.func @load_store_x1_unpacked(%arg0: !ttg.memdesc<128x2xf16, #tmem_x1_unpacked, #ttng.tensor_memory, mutable>) {
654+
%true = arith.constant true
655+
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
656+
// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32)
657+
// CHECK: [[V:%.*]] = llvm.inline_asm {{.*}}tcgen05.ld.sync{{.*}} (i32) -> i32
658+
// CHECK: [[F:%.*]] = llvm.bitcast [[V]] : i32 to vector<2xf16>
659+
// CHECK: extractelement [[F]][[[C0]] : i32]
660+
// CHECK: extractelement [[F]][[[C1]] : i32]
661+
%0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x2xf16, #tmem_x1_unpacked, #ttng.tensor_memory, mutable> -> tensor<128x2xf16, #blocked_x1>
662+
ttng.tmem_store %0, %arg0, %true : tensor<128x2xf16, #blocked_x1> -> !ttg.memdesc<128x2xf16, #tmem_x1_unpacked, #ttng.tensor_memory, mutable>
663+
tt.return
664+
}
665+
636666
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ void calculateAddressAndEmitTmemMessage(
355355
// are required to cover the entire set of rows per warp.
356356
int numRowPerWarp =
357357
(info.layoutAtom.rowStored == 16 && info.blockM == 64) ? 16 : 32;
358+
358359
for (int rowStart = 0; rowStart < numRowPerWarp;
359360
rowStart += message.numRows) {
360361
for (int colStart = 0; colStart < numColumns;
@@ -590,10 +591,17 @@ Value createTensorMemoryLoad(Location loc, triton::nvidia_gpu::TMEMLoadOp op,
590591
operands.push_back(ptxBuilder.newOperand(address, "r"));
591592
auto &ld = *ptxBuilder.create<PTXInstr>(opcode);
592593
ld(operands, /*onlyAttachMLIRArgs=*/true);
593-
SmallVector<Type> elemTypes(numRegPerMessage, i32_ty);
594-
MLIRContext *ctx = op.getContext();
595-
Type structTy = struct_ty(elemTypes);
596-
Value ret = ptxBuilder.launch(rewriter, loc, structTy);
594+
595+
// LLVM inline_asm with 1 result cannot return a struct.
596+
Type retTy;
597+
if (numRegPerMessage == 1) {
598+
retTy = i32_ty;
599+
} else {
600+
SmallVector<Type> elemTypes(numRegPerMessage, i32_ty);
601+
MLIRContext *ctx = op.getContext();
602+
retTy = struct_ty(elemTypes);
603+
}
604+
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
597605
return ret;
598606
}
599607

@@ -606,8 +614,8 @@ static SmallVector<Value> unpackResults(Value packedValues, Type elemTy,
606614
Type packedType = elemTy;
607615
if (numElementsPer32B > 1)
608616
packedType = vec_ty(elemTy, numElementsPer32B);
609-
for (int i = 0; i < numCols; i++) {
610-
Value result = b.extract_val(i32_ty, packedValues, i);
617+
618+
auto unpackElement = [&](Value result) {
611619
result = b.bitcast(result, packedType);
612620
if (numElementsPer32B > 1) {
613621
for (int j = 0; j < numElementsPer32B; j++) {
@@ -617,6 +625,15 @@ static SmallVector<Value> unpackResults(Value packedValues, Type elemTy,
617625
} else {
618626
resultVals.push_back(result);
619627
}
628+
};
629+
630+
if (isa<LLVM::LLVMStructType>(packedValues.getType())) {
631+
for (int i = 0; i < numCols; i++) {
632+
Value result = b.extract_val(i32_ty, packedValues, i);
633+
unpackElement(result);
634+
}
635+
} else {
636+
unpackElement(packedValues);
620637
}
621638
return resultVals;
622639
}

0 commit comments

Comments
 (0)