
Commit 43f1ad4

[IR][BACKEND] Introduce nvgpu.ldmatrix IR (#5442)
The purpose is to replace the legacy approach of emitting `ldmatrix` as inline PTX through PTXBuilder.
1 parent 5132916 commit 43f1ad4
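
The change at call sites is small; the sketch below is a minimal before/after comparison based on the SharedToDotOperandMMAv2OrV3.cpp diff in this commit. The names `rewriter`, `loc`, `resTy`, `readPtr`, and `needTrans` come from that surrounding function and are assumed here, not new API.

```cpp
// Before (legacy): ldmatrix emitted as inline PTX assembled through PTXBuilder.
//   PTXBuilder builder;
//   auto resArgs = builder.newListOperand(4, "=r");
//   auto addrArg = builder.newAddrOperand(readPtr, "r");
//   auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
//                       ->o("trans", needTrans /*predicate*/)
//                       .o("shared.b16");
//   ldmatrix(resArgs, addrArg);
//   Value resV4 = builder.launch(rewriter, loc, resTy);

// After: the same load expressed as a first-class nvgpu.ldmatrix op, which
// NVGPUToLLVMPass lowers to the equivalent PTX later in the pipeline.
auto ldMatrixOp = rewriter.create<nvgpu::LoadMatrixOp>(loc, resTy, readPtr);
ldMatrixOp->setAttr("trans", rewriter.getBoolAttr(needTrans));
auto resV4 = ldMatrixOp.getResult();
```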

File tree: 5 files changed (+178 −91 lines)


test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 11 additions & 2 deletions
@@ -37,15 +37,24 @@ llvm.func @cluster_id() -> i32 {

 // -----

-// CHECK-LABEL: @st_matrix
-llvm.func @st_matrix(%i: i32, %ptr: !llvm.ptr<3>) {
+// CHECK-LABEL: @stmatrix
+llvm.func @stmatrix(%i: i32, %ptr: !llvm.ptr<3>) {
   // CHECK: stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};
   nvgpu.stmatrix %ptr, %i, %i, %i, %i : !llvm.ptr<3>, i32, i32, i32, i32
   llvm.return
 }

 // -----

+// CHECK-LABEL: @ldmatrix
+llvm.func @ldmatrix(%ptr: !llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> {
+  // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];
+  %0 = nvgpu.ldmatrix %ptr : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// -----
+
 !struct_128xf32 = !llvm.struct<(
   f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
   f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 8 additions & 15 deletions
@@ -845,10 +845,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
-    // CHECK: llvm.inline_asm
-    // CHECK: ldmatrix.sync.aligned.m8n8.x4
-    // CHECK: llvm.inline_asm
-    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
+    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -876,8 +874,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
-    // CHECK: llvm.inline_asm
-    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
+    // CHECK: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -1177,7 +1174,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
     %a:!ttg.memdesc<128x32xf16, #shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
-    // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    // CHECK: nvgpu.ldmatrix
     %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
     %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>

@@ -1227,11 +1224,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
     %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
-    // CHECK: llvm.inline_asm
-    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    // CHECK: nvgpu.ldmatrix
     // CHECK-SAME: (i32, i32, i32, i32)
-    // CHECK: llvm.inline_asm
-    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    // CHECK: nvgpu.ldmatrix
     // CHECK-SAME: (i32, i32, i32, i32)
     %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
     %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
@@ -1720,10 +1715,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     %f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem>

-    // CHECK: llvm.inline_asm
-    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
-    // CHECK: llvm.inline_asm
-    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
+    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix

     %f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b>

third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td

Lines changed: 9 additions & 5 deletions
@@ -43,12 +43,10 @@ def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
   let assemblyFormat = "attr-dict";
 }

-def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group",
-    [DeclareOpInterfaceMethods<InferTypeOpInterface>,
-     AllTypesMatch<["input", "output"]>]> {
+def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
+                                                           AllTypesMatch<["input", "output"]>]> {
   let arguments = (ins LLVM_AnyStruct:$input, I32Attr:$pendings);
   let results = (outs LLVM_AnyStruct:$output);
-  let assemblyFormat = "attr-dict";
   let assemblyFormat = "$input attr-dict `:` type($input)";
 }

@@ -103,10 +101,16 @@ def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> {
 }

 def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> {
-  let arguments = (ins LLVM_PointerShared:$addr, Variadic<I32>:$datas);
+  let arguments = (ins LLVM_PointerShared:$addr, Variadic<I32>:$vals);
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }

+def NVGPU_LoadMatrixOp : NVGPU_Op<"ldmatrix", [MemoryEffects<[MemRead]>]> {
+  let arguments = (ins LLVM_PointerShared:$addr);
+  let results = (outs LLVM_AnyStruct:$result);
+  let assemblyFormat = "$addr attr-dict `:` functional-type($addr, $result)";
+}
+
 def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
   let results = (outs I32:$result);
   let assemblyFormat = "attr-dict";

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 146 additions & 53 deletions
@@ -23,21 +23,20 @@ using ttn::OperandsAndConstraints;

 namespace {

-const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;";
-const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;";
-const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;";
-const std::string Fence_Mbarrier_Init_Op =
-    "fence.mbarrier_init.release.cluster;";
-const std::string Cluster_Cta_Id_Op = "{\n"
-                                      ".reg .u32 a<5>; \n"
-                                      "mov.u32 a0, %cluster_ctaid.x;\n"  // x
-                                      "mov.u32 a1, %cluster_ctaid.y;\n"  // y
-                                      "mov.u32 a2, %cluster_ctaid.z;\n"  // z
-                                      "mov.u32 a3, %cluster_nctaid.x;\n" // nx
-                                      "mov.u32 a4, %cluster_nctaid.y;\n" // ny
-                                      "mad.lo.u32 a1, a2, a4, a1; \n"
-                                      "mad.lo.u32 $0, a1, a3, a0; \n"
-                                      "}";
+const std::string kWgmmaFenceOp = "wgmma.fence.sync.aligned;";
+const std::string kWgmmaCommitGroupOp = "wgmma.commit_group.sync.aligned;";
+const std::string kClusterWaitOp = "barrier.cluster.wait.aligned;";
+const std::string kFenceMbarrierInitOp = "fence.mbarrier_init.release.cluster;";
+const std::string kClusterCtaIdOp = "{\n"
+                                    ".reg .u32 a<5>; \n"
+                                    "mov.u32 a0, %cluster_ctaid.x;\n"  // x
+                                    "mov.u32 a1, %cluster_ctaid.y;\n"  // y
+                                    "mov.u32 a2, %cluster_ctaid.z;\n"  // z
+                                    "mov.u32 a3, %cluster_nctaid.x;\n" // nx
+                                    "mov.u32 a4, %cluster_nctaid.y;\n" // ny
+                                    "mad.lo.u32 a1, a2, a4, a1; \n"
+                                    "mad.lo.u32 $0, a1, a3, a0; \n"
+                                    "}";

 bool isNumber(const std::string &s) {
   return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) {
@@ -235,46 +234,141 @@ class ClusterArriveOpPattern : public OpRewritePattern<ttn::ClusterArriveOp> {
   }
 };

-class StoreMatrixOpPattern : public OpRewritePattern<ttn::StoreMatrixOp> {
+// Base class for Matrix Operation Patterns
+template <typename MatrixOpType, typename ConcreteMatrixOpPattern>
+class MatrixOpPattern : public OpRewritePattern<MatrixOpType> {
 public:
-  using OpRewritePattern<ttn::StoreMatrixOp>::OpRewritePattern;
+  using OpRewritePattern<MatrixOpType>::OpRewritePattern;

-  LogicalResult matchAndRewrite(ttn::StoreMatrixOp op,
+  LogicalResult matchAndRewrite(MatrixOpType op,
                                 PatternRewriter &rewriter) const override {
-    return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op),
-                           getOperandsAndConstraints(op));
-  }
-
-  OperandsAndConstraints
-  getOperandsAndConstraints(ttn::StoreMatrixOp op) const {
-    OperandsAndConstraints operandsAndTypes;
-    auto addr = op.getAddr();
-    auto datas = op.getDatas();
-    operandsAndTypes.push_back({addr, "r"});
-    for (unsigned i = 0; i < datas.size(); i++) {
-      operandsAndTypes.push_back({datas[i], "r"});
-    }
-    return operandsAndTypes;
+    unsigned vecSize = getVectorSize(op);
+    bool trans = op->hasAttr("trans")
+                     ? op->template getAttrOfType<BoolAttr>("trans").getValue()
+                     : false;
+
+    // Template method for PTX assembly generation
+    std::string ptxAsm =
+        (llvm::Twine(ConcreteMatrixOpPattern::kOpCode) +
+         getPtxModifiers(vecSize, trans) + " " + getOperands(op, vecSize) + ";")
+            .str();
+
+    OperandsAndConstraints operandAndConstraints =
+        getOperandsAndConstraints(op, vecSize);
+    Constraints outputConstraints = getOutputConstraints(op, vecSize);
+
+    return rewriteAsPtxAsm(op, rewriter, ptxAsm, operandAndConstraints,
+                           outputConstraints);
   }

-  std::string getPtxAsm(ttn::StoreMatrixOp op) const {
-    auto datas = op.getDatas();
-    std::string ptxAsm;
-    switch (datas.size()) {
+protected:
+  // Shared helper methods
+  std::string getPtxModifiers(unsigned vecSize, bool trans) const {
+    auto ptxAsmBase = llvm::Twine(".sync.aligned.m8n8");
+    const std::string suffix = trans ? ".trans.shared.b16" : ".shared.b16";
+    switch (vecSize) {
     case 1:
-      ptxAsm = "stmatrix.sync.aligned.m8n8.x1.shared.b16 [$0], {$1};";
-      break;
+      return (ptxAsmBase + ".x1" + suffix).str();
     case 2:
-      ptxAsm = "stmatrix.sync.aligned.m8n8.x2.shared.b16 [$0], {$1, $2};";
-      break;
+      return (ptxAsmBase + ".x2" + suffix).str();
     case 4:
-      ptxAsm =
-          "stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};";
-      break;
+      return (ptxAsmBase + ".x4" + suffix).str();
     default:
-      assert(false && "Invalid size");
+      assert(false && "Invalid vector size");
     }
-    return ptxAsm;
+  }
+
+  std::string getPtxRegOperands(unsigned startIdx, unsigned count) const {
+    llvm::SmallString<20> regOperands;
+    llvm::raw_svector_ostream stream(regOperands);
+    stream << "{";
+    for (unsigned i = 0; i < count; i++) {
+      stream << "$" + llvm::utostr(startIdx + i);
+      if (i != count - 1)
+        stream << ", ";
+    }
+    stream << "}";
+    return std::string(regOperands.str());
+  }
+
+  std::string getPtxAddrOperand(unsigned idx) const {
+    return (llvm::Twine("[$") + llvm::utostr(idx) + "]").str();
+  }
+
+  virtual std::string getOperands(MatrixOpType op, unsigned vecSize) const = 0;
+  virtual OperandsAndConstraints
+  getOperandsAndConstraints(MatrixOpType op, unsigned vecSize) const = 0;
+  virtual Constraints getOutputConstraints(MatrixOpType op,
+                                           unsigned vecSize) const = 0;
+  virtual unsigned getVectorSize(MatrixOpType op) const = 0;
+};
+
+// StoreMatrixOp Pattern
+class StoreMatrixOpPattern
+    : public MatrixOpPattern<ttn::StoreMatrixOp, StoreMatrixOpPattern> {
+public:
+  using MatrixOpPattern<ttn::StoreMatrixOp,
+                        StoreMatrixOpPattern>::MatrixOpPattern;
+  static constexpr const char *kOpCode = "stmatrix";
+
+protected:
+  unsigned getVectorSize(ttn::StoreMatrixOp op) const override {
+    return op.getVals().size();
+  }
+
+  std::string getOperands(ttn::StoreMatrixOp op,
+                          unsigned vecSize) const override {
+    return (llvm::Twine(getPtxAddrOperand(0)) + ", " +
+            getPtxRegOperands(1, vecSize))
+        .str();
+  }
+
+  OperandsAndConstraints
+  getOperandsAndConstraints(ttn::StoreMatrixOp op,
+                            unsigned vecSize) const override {
+    OperandsAndConstraints constraints = {{op.getAddr(), "r"}};
+    for (unsigned i = 0; i < vecSize; i++) {
+      constraints.push_back({op.getVals()[i], "r"});
+    }
+    return constraints;
+  }
+
+  Constraints getOutputConstraints(ttn::StoreMatrixOp op,
+                                   unsigned vecSize) const override {
+    return {}; // No output constraints for StoreMatrixOp
+  }
+};
+
+// LoadMatrixOp Pattern
+class LoadMatrixOpPattern
+    : public MatrixOpPattern<ttn::LoadMatrixOp, LoadMatrixOpPattern> {
+public:
+  using MatrixOpPattern<ttn::LoadMatrixOp,
+                        LoadMatrixOpPattern>::MatrixOpPattern;
+  static constexpr const char *kOpCode = "ldmatrix";
+
+protected:
+  unsigned getVectorSize(ttn::LoadMatrixOp op) const override {
+    auto resultType = cast<LLVM::LLVMStructType>(op.getType());
+    return resultType.getBody().size();
+  }
+
+  std::string getOperands(ttn::LoadMatrixOp op,
+                          unsigned vecSize) const override {
+    return (llvm::Twine(getPtxRegOperands(0, vecSize)) + ", " +
+            getPtxAddrOperand(vecSize))
+        .str();
+  }
+
+  OperandsAndConstraints
+  getOperandsAndConstraints(ttn::LoadMatrixOp op,
+                            unsigned vecSize) const override {
+    return {{op.getAddr(), "r"}};
+  }
+
+  Constraints getOutputConstraints(ttn::LoadMatrixOp op,
+                                   unsigned vecSize) const override {
+    return Constraints(vecSize, "=r");
   }
 };

@@ -507,17 +601,16 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase<ConvertNVGPUToLLVM> {
 #define POPULATE_NVGPU_OP(SRC_OP, ASM)                                        \
   patterns.add<NVGPUOpGenericPattern<SRC_OP>>(context, ASM, Constraints(),    \
                                               Constraints());
-    POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op)
-    POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op)
-    POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op)
+    POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, kWgmmaFenceOp)
+    POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, kWgmmaCommitGroupOp)
+    POPULATE_NVGPU_OP(ttn::ClusterWaitOp, kClusterWaitOp)
 #undef POPULATE_NVGPU_OP
     patterns.add<NVGPUOpGenericPattern<ttn::ClusterCTAIdOp>>(
-        context, Cluster_Cta_Id_Op, Constraints({"=r"}), Constraints());
+        context, kClusterCtaIdOp, Constraints({"=r"}), Constraints());

-    patterns
-        .add<FenceAsyncSharedOpPattern, StoreMatrixOpPattern,
-             ClusterArriveOpPattern, WGMMAOpPattern, WGMMAWaitGroupOpPattern>(
-            context);
+    patterns.add<FenceAsyncSharedOpPattern, LoadMatrixOpPattern,
+                 StoreMatrixOpPattern, ClusterArriveOpPattern, WGMMAOpPattern,
+                 WGMMAWaitGroupOpPattern>(context);

     if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed())
       signalPassFailure();
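
As a rough sanity check on the shared helpers above, the standalone sketch below reimplements the same string assembly with plain std::string (instead of llvm::Twine/SmallString) so it can run outside the pass; it reproduces the PTX that the nvgpu_to_llvm.mlir test checks for. It is an illustration of the logic only, not the code the pass actually runs.

```cpp
#include <cassert>
#include <iostream>
#include <string>

// Mirrors MatrixOpPattern::getPtxModifiers: vector width plus optional transpose.
std::string ptxModifiers(unsigned vecSize, bool trans) {
  const std::string suffix = trans ? ".trans.shared.b16" : ".shared.b16";
  assert(vecSize == 1 || vecSize == 2 || vecSize == 4);
  return ".sync.aligned.m8n8.x" + std::to_string(vecSize) + suffix;
}

// Mirrors getPtxRegOperands: "{$start, $start+1, ...}".
std::string regOperands(unsigned startIdx, unsigned count) {
  std::string out = "{";
  for (unsigned i = 0; i < count; ++i) {
    out += "$" + std::to_string(startIdx + i);
    if (i + 1 != count)
      out += ", ";
  }
  return out + "}";
}

// Mirrors getPtxAddrOperand: "[$idx]".
std::string addrOperand(unsigned idx) {
  return "[$" + std::to_string(idx) + "]";
}

int main() {
  // ldmatrix: results come first ($0..$3), then the address ($4).
  std::cout << "ldmatrix" << ptxModifiers(4, /*trans=*/false) << " "
            << regOperands(0, 4) << ", " << addrOperand(4) << ";\n";
  // -> ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];

  // stmatrix: the address comes first ($0), then the values ($1..$4).
  std::cout << "stmatrix" << ptxModifiers(4, /*trans=*/false) << " "
            << addrOperand(0) << ", " << regOperands(1, 4) << ";\n";
  // -> stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};
  return 0;
}
```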

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp

Lines changed: 4 additions & 16 deletions
@@ -1,6 +1,7 @@
 #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
 #include "Utility.h"
 #include "mlir/Support/LLVM.h"
+#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"

 using namespace mlir;

@@ -339,23 +340,10 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef<Value> ptrs,
   if (batch != 0)
     stridedOffset = add(
         stridedOffset, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset));
-
   Value readPtr = gep(ptr_ty(ctx, 3), shemTy, ptr, stridedOffset);
-
-  PTXBuilder builder;
-  // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
-  // thread.
-  auto resArgs = builder.newListOperand(4, "=r");
-  auto addrArg = builder.newAddrOperand(readPtr, "r");
-
-  auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
-                      ->o("trans", needTrans /*predicate*/)
-                      .o("shared.b16");
-  ldmatrix(resArgs, addrArg);
-
-  // The result type is 4xi32, each i32 is composed of 2xf16
-  // elements (adjacent two columns in a row) or a single f32 element.
-  Value resV4 = builder.launch(rewriter, loc, resTy);
+  auto ldMatrixOp = rewriter.create<nvgpu::LoadMatrixOp>(loc, resTy, readPtr);
+  ldMatrixOp->setAttr("trans", rewriter.getBoolAttr(needTrans));
+  auto resV4 = ldMatrixOp.getResult();
   return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1),
           extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)};
 } else {
