Skip to content

Commit 82e7a32

Browse files
authored
[IR][BACKEND] Add the trans attribute to LoadMatrix and StoreMatrix ops (#5467)
Following @ThomasRaoux's suggestion, this change ensures that the `trans` attribute is not discarded and keeps the IR clear.
1 parent 4828693 commit 82e7a32

File tree

4 files changed

+19
-9
lines changed

4 files changed

+19
-9
lines changed

test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ llvm.func @cluster_id() -> i32 {
4141
llvm.func @stmatrix(%i: i32, %ptr: !llvm.ptr<3>) {
4242
// CHECK: stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};
4343
nvgpu.stmatrix %ptr, %i, %i, %i, %i : !llvm.ptr<3>, i32, i32, i32, i32
44+
// CHECK: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [$0], {$1, $2, $3, $4};
45+
nvgpu.stmatrix %ptr, %i, %i, %i, %i {trans} : !llvm.ptr<3>, i32, i32, i32, i32
4446
llvm.return
4547
}
4648

@@ -50,7 +52,11 @@ llvm.func @stmatrix(%i: i32, %ptr: !llvm.ptr<3>) {
5052
llvm.func @ldmatrix(%ptr: !llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> {
5153
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];
5254
%0 = nvgpu.ldmatrix %ptr : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
53-
llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)>
55+
// CHECK: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {$0, $1, $2, $3}, [$4];
56+
%1 = nvgpu.ldmatrix %ptr {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
57+
%2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32, i32, i32)>
58+
%3 = llvm.insertvalue %2, %0[0] : !llvm.struct<(i32, i32, i32, i32)>
59+
llvm.return %3 : !llvm.struct<(i32, i32, i32, i32)>
5460
}
5561

5662
// -----

third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,19 @@ def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> {
101101
}
102102

103103
def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> {
104-
let arguments = (ins LLVM_PointerShared:$addr, Variadic<I32>:$vals);
104+
let arguments = (
105+
ins LLVM_PointerShared:$addr,
106+
Variadic<I32>:$vals,
107+
UnitAttr:$trans
108+
);
105109
let assemblyFormat = "operands attr-dict `:` type(operands)";
106110
}
107111

108112
def NVGPU_LoadMatrixOp : NVGPU_Op<"ldmatrix", [MemoryEffects<[MemRead]>]> {
109-
let arguments = (ins LLVM_PointerShared:$addr);
113+
let arguments = (
114+
ins LLVM_PointerShared:$addr,
115+
UnitAttr:$trans
116+
);
110117
let results = (outs LLVM_AnyStruct:$result);
111118
let assemblyFormat = "$addr attr-dict `:` functional-type($addr, $result)";
112119
}

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,7 @@ class MatrixOpPattern : public OpRewritePattern<MatrixOpType> {
243243
LogicalResult matchAndRewrite(MatrixOpType op,
244244
PatternRewriter &rewriter) const override {
245245
unsigned vecSize = getVectorSize(op);
246-
bool trans = op->hasAttr("trans")
247-
? op->template getAttrOfType<BoolAttr>("trans").getValue()
248-
: false;
249-
246+
bool trans = op.getTrans();
250247
// Template method for PTX assembly generation
251248
std::string ptxAsm =
252249
(llvm::Twine(ConcreteMatrixOpPattern::kOpCode) +

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,8 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef<Value> ptrs,
341341
stridedOffset = add(
342342
stridedOffset, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset));
343343
Value readPtr = gep(ptr_ty(ctx, 3), shemTy, ptr, stridedOffset);
344-
auto ldMatrixOp = rewriter.create<nvgpu::LoadMatrixOp>(loc, resTy, readPtr);
345-
ldMatrixOp->setAttr("trans", rewriter.getBoolAttr(needTrans));
344+
auto ldMatrixOp =
345+
rewriter.create<nvgpu::LoadMatrixOp>(loc, resTy, readPtr, needTrans);
346346
auto resV4 = ldMatrixOp.getResult();
347347
return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1),
348348
extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)};

0 commit comments

Comments
 (0)