Skip to content

Commit 9164c06

Browse files
authored
[BACKEND] Support ldmatrix m16n16.b8.trans (#7525)
1 parent 8cdced6 commit 9164c06

File tree

5 files changed

+176
-19
lines changed

5 files changed

+176
-19
lines changed

test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,23 @@ llvm.func @cluster_id() -> i32 {
1313

1414
// -----
1515

16+
// CHECK-LABEL: @ldmatrix
llvm.func @ldmatrix(%ptr: !llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> {
  // Non-transposed m8n8 load of 16-bit elements.
  // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];
  %plain = nvgpu.ldmatrix %ptr, m8n8, 16 : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>

  // Same shape and element width, with the `trans` unit attribute set.
  // CHECK: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {$0, $1, $2, $3}, [$4];
  %trans16 = nvgpu.ldmatrix %ptr, m8n8, 16 {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>

  // Transposed m16n16 load of 8-bit elements.
  // CHECK: ldmatrix.sync.aligned.m16n16.x4.trans.shared.b8 {$0, $1, $2, $3}, [$4];
  %trans8 = nvgpu.ldmatrix %ptr, m16n16, 8 {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>

  // Fold one element of every load into the returned struct so that each
  // load has a user (presumably to keep the loads from being erased as dead).
  %trans16_e0 = llvm.extractvalue %trans16[0] : !llvm.struct<(i32, i32, i32, i32)>
  %with_trans16 = llvm.insertvalue %trans16_e0, %plain[0] : !llvm.struct<(i32, i32, i32, i32)>
  %trans8_e0 = llvm.extractvalue %trans8[0] : !llvm.struct<(i32, i32, i32, i32)>
  %combined = llvm.insertvalue %trans8_e0, %with_trans16[1] : !llvm.struct<(i32, i32, i32, i32)>
  llvm.return %combined : !llvm.struct<(i32, i32, i32, i32)>
}
30+
31+
// -----
32+
1633
!struct_128xf32 = !llvm.struct<(
1734
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
1835
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -880,9 +880,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
880880
tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
881881
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
882882
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
883-
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
884-
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<col>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
885-
// CHECK-NOT: nvvm.ldmatrix
883+
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
884+
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
885+
// CHECK-NOT: nvgpu.ldmatrix
886886
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
887887
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
888888
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -910,9 +910,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
910910
tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
911911
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
912912
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
913-
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
914-
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<col>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
915-
// CHECK-NOT: nvvm.ldmatrix
913+
// CHECK: nvgpu.ldmatrix %{{.*}}, m8n8, 16 : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
914+
// CHECK: nvgpu.ldmatrix %{{.*}}, m8n8, 16 {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
915+
// CHECK-NOT: nvgpu.ldmatrix
916916
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
917917
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
918918
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -940,7 +940,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
940940
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
941941
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
942942
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
943-
// CHECK-NOT: nvvm.ldmatrix
943+
// CHECK-NOT: nvgpu.ldmatrix
944944
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
945945
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
946946
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -968,8 +968,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
968968
tt.func @convert_dot_mmav3_shared(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) {
969969
%AA = ttg.local_alloc %A : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
970970
%BB = ttg.local_alloc %B : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
971-
// CHECK-COUNT-16: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
972-
// CHECK-COUNT-16: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<col>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
971+
// CHECK-COUNT-32: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
973972
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_a>
974973
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_b>
975974
%cst0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0>
@@ -993,8 +992,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
993992
tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) {
994993
%AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
995994
%BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
996-
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
997-
// CHECK-NOT: nvvm.ldmatrix
995+
// CHECK: nvgpu.ldmatrix %{{.*}}, m8n8, 16 : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
996+
// CHECK-NOT: nvgpu.ldmatrix
998997
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a>
999998
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b>
1000999
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -1325,7 +1324,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
13251324
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
13261325
%a:!ttg.memdesc<128x32xf16, #shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) {
13271326
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
1328-
// CHECK: nvvm.ldmatrix
1327+
// CHECK: nvgpu.ldmatrix
13291328
%a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
13301329
%b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>
13311330

@@ -1401,9 +1400,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
14011400
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
14021401
%a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) {
14031402
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
1404-
// CHECK: nvvm.ldmatrix
1403+
// CHECK: nvgpu.ldmatrix
14051404
// CHECK-SAME: (i32, i32, i32, i32)
1406-
// CHECK: nvvm.ldmatrix
1405+
// CHECK: nvgpu.ldmatrix
14071406
// CHECK-SAME: (i32, i32, i32, i32)
14081407
%a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
14091408
%b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
@@ -1876,8 +1875,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
18761875
%f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
18771876
%i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem>
18781877

1879-
// CHECK: nvvm.ldmatrix
1880-
// CHECK: nvvm.ldmatrix
1878+
// CHECK: nvgpu.ldmatrix
1879+
// CHECK: nvgpu.ldmatrix
18811880

18821881
%f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
18831882
%i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b>

third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,26 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
105105
let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)";
106106
}
107107

108+
// Matrix shapes accepted by `nvgpu.ldmatrix`. The case names double as the
// keywords printed and parsed in the op's assembly.
def NVGPU_LoadMatrixShapeAttr : I32EnumAttr<
    "LoadMatrixShape", "matrix shape of an ldmatrix load",
    [
      I32EnumAttrCase<"m8n8", 0, "m8n8">,
      I32EnumAttrCase<"m16n16", 1, "m16n16">
    ]> {
  let cppNamespace = "::mlir::triton::nvgpu";
}

def NVGPU_LoadMatrixOp : NVGPU_Op<"ldmatrix", [MemoryEffects<[MemRead]>]> {
  let summary = "load matrix fragments from shared memory with PTX ldmatrix";
  let description = [{
    Reads from the shared-memory pointer `addr` and returns either a single
    `i32` or an LLVM struct.

    The NVGPU-to-LLVM lowering turns this op into the inline PTX instruction
    `ldmatrix.sync.aligned.<shape>.x<N>[.trans].shared.b<bit_width>`, where
    `N` is the number of fields of the result struct (1 for a plain `i32`
    result) and `.trans` is emitted when the `trans` unit attribute is set.

    Example:

    ```mlir
    %r = nvgpu.ldmatrix %ptr, m8n8, 16 {trans}
        : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
    ```
  }];
  let arguments = (
    ins LLVM_PointerShared:$addr,
    NVGPU_LoadMatrixShapeAttr:$shape,
    I32Attr:$bit_width,
    UnitAttr:$trans
  );
  let results = (outs AnyTypeOf<[LLVM_AnyStruct, I32]>:$result);
  let assemblyFormat = "$addr `,` $shape `,` $bit_width attr-dict `:` functional-type($addr, $result)";
}
127+
108128
def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
109129
let results = (outs I32:$result);
110130
let assemblyFormat = "attr-dict";

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,122 @@ class WarpIdOpPattern : public OpRewritePattern<ttn::WarpIdOp> {
238238
}
239239
};
240240

241+
// Base class for Matrix Operation Patterns
242+
template <typename MatrixOpType, typename ConcreteMatrixOpPattern>
243+
class MatrixOpPattern : public OpRewritePattern<MatrixOpType> {
244+
public:
245+
using OpRewritePattern<MatrixOpType>::OpRewritePattern;
246+
247+
LogicalResult matchAndRewrite(MatrixOpType op,
248+
PatternRewriter &rewriter) const override {
249+
unsigned vecSize = getVectorSize(op);
250+
bool trans = op.getTrans();
251+
// Template method for PTX assembly generation
252+
std::string ptxAsm =
253+
(llvm::Twine(ConcreteMatrixOpPattern::kOpCode) +
254+
getPtxModifiers(vecSize, trans, op.getShape(), op.getBitWidth()) +
255+
" " + getOperands(op, vecSize) + ";")
256+
.str();
257+
258+
OperandsAndConstraints operandAndConstraints =
259+
getOperandsAndConstraints(op, vecSize);
260+
Constraints outputConstraints = getOutputConstraints(op, vecSize);
261+
262+
return rewriteAsPtxAsm(op, rewriter, ptxAsm, operandAndConstraints,
263+
outputConstraints);
264+
}
265+
266+
protected:
267+
// Shared helper methods
268+
std::string getPtxModifiers(unsigned vecSize, bool trans,
269+
triton::nvgpu::LoadMatrixShape shape,
270+
int bitWidth) const {
271+
std::string ptxAsmBase = std::string(".sync.aligned");
272+
switch (shape) {
273+
case triton::nvgpu::LoadMatrixShape::m8n8:
274+
ptxAsmBase += ".m8n8";
275+
break;
276+
case triton::nvgpu::LoadMatrixShape::m16n16:
277+
ptxAsmBase += ".m16n16";
278+
break;
279+
default:
280+
llvm_unreachable("Invalid load matrix shape");
281+
}
282+
std::string suffix = trans ? ".trans.shared" : ".shared";
283+
suffix += ".b" + std::to_string(bitWidth);
284+
switch (vecSize) {
285+
case 1:
286+
return ptxAsmBase + ".x1" + suffix;
287+
case 2:
288+
return ptxAsmBase + ".x2" + suffix;
289+
case 4:
290+
return ptxAsmBase + ".x4" + suffix;
291+
default:
292+
llvm_unreachable("Invalid vector size");
293+
}
294+
}
295+
296+
std::string getPtxRegOperands(unsigned startIdx, unsigned count) const {
297+
llvm::SmallString<20> regOperands;
298+
llvm::raw_svector_ostream stream(regOperands);
299+
stream << "{";
300+
for (unsigned i = 0; i < count; i++) {
301+
stream << "$" + llvm::utostr(startIdx + i);
302+
if (i != count - 1)
303+
stream << ", ";
304+
}
305+
stream << "}";
306+
return std::string(regOperands.str());
307+
}
308+
309+
std::string getPtxAddrOperand(unsigned idx) const {
310+
return (llvm::Twine("[$") + llvm::utostr(idx) + "]").str();
311+
}
312+
313+
virtual std::string getOperands(MatrixOpType op, unsigned vecSize) const = 0;
314+
virtual OperandsAndConstraints
315+
getOperandsAndConstraints(MatrixOpType op, unsigned vecSize) const = 0;
316+
virtual Constraints getOutputConstraints(MatrixOpType op,
317+
unsigned vecSize) const = 0;
318+
virtual unsigned getVectorSize(MatrixOpType op) const = 0;
319+
};
320+
321+
// LoadMatrixOp Pattern
322+
class LoadMatrixOpPattern
323+
: public MatrixOpPattern<ttn::LoadMatrixOp, LoadMatrixOpPattern> {
324+
public:
325+
using MatrixOpPattern<ttn::LoadMatrixOp,
326+
LoadMatrixOpPattern>::MatrixOpPattern;
327+
static constexpr const char *kOpCode = "ldmatrix";
328+
329+
protected:
330+
unsigned getVectorSize(ttn::LoadMatrixOp op) const override {
331+
auto resultType = op.getType();
332+
if (auto structTy = dyn_cast<LLVM::LLVMStructType>(resultType)) {
333+
return structTy.getBody().size();
334+
}
335+
return 1;
336+
}
337+
338+
std::string getOperands(ttn::LoadMatrixOp op,
339+
unsigned vecSize) const override {
340+
return (llvm::Twine(getPtxRegOperands(0, vecSize)) + ", " +
341+
getPtxAddrOperand(vecSize))
342+
.str();
343+
}
344+
345+
OperandsAndConstraints
346+
getOperandsAndConstraints(ttn::LoadMatrixOp op,
347+
unsigned vecSize) const override {
348+
return {{op.getAddr(), "r"}};
349+
}
350+
351+
Constraints getOutputConstraints(ttn::LoadMatrixOp op,
352+
unsigned vecSize) const override {
353+
return Constraints(vecSize, "=r");
354+
}
355+
};
356+
241357
class LoadAcquireOpPattern : public OpRewritePattern<ttn::LoadAcquireOp> {
242358
public:
243359
using OpRewritePattern<ttn::LoadAcquireOp>::OpRewritePattern;
@@ -623,8 +739,8 @@ class ConvertNVGPUToLLVM
623739
patterns.add<NVGPUOpGenericPattern<ttn::ClusterCTAIdOp>>(
624740
context, kClusterCtaIdOp, Constraints({"=r"}), Constraints());
625741

626-
patterns.add<WGMMAOpPattern, LoadAcquireOpPattern, WGMMAWaitGroupOpPattern,
627-
WarpIdOpPattern>(context);
742+
patterns.add<LoadMatrixOpPattern, WGMMAOpPattern, LoadAcquireOpPattern,
743+
WGMMAWaitGroupOpPattern, WarpIdOpPattern>(context);
628744

629745
if (applyPatternsGreedily(mod, std::move(patterns)).failed())
630746
signalPassFailure();

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "Dialect/NVGPU/IR/Dialect.h"
12
#include "PatternTritonGPUOpToLLVM.h"
23
#include "TargetInfo.h"
34
#include "Utility.h"
@@ -212,7 +213,11 @@ LogicalResult lowerLdStMatrix(
212213
: static_cast<Type>(LLVM::LLVMStructType::getLiteral(
213214
ctx, SmallVector<Type>(nVecs, i32_ty)));
214215
auto res =
215-
rewriter.create<NVVM::LdMatrixOp>(loc, matTy, vecAddr, nVecs, layout)
216+
rewriter
217+
.create<triton::nvgpu::LoadMatrixOp>(
218+
loc, matTy, vecAddr, triton::nvgpu::LoadMatrixShape::m8n8,
219+
/*bitWidth=*/16,
220+
/*needTrans=*/transpose)
216221
.getResult();
217222
// Extract result into srcVals
218223
for (int j = 0; j < nVecs; j++) {

0 commit comments

Comments
 (0)