Skip to content

Commit 895975f

Browse files
committed
[ROCDL][WIP] Added matrix load-transpose ops for gfx1250+
1 parent 5defeed commit 895975f

File tree

3 files changed

+119
-2
lines changed

3 files changed

+119
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
321321
let assemblyFormat = "attr-dict";
322322
}
323323

324+
def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
324325
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
325326

326327
def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>,
@@ -631,8 +632,6 @@ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]
631632
//===---------------------------------------------------------------------===//
632633
// LDS transpose intrinsics (available in GFX950)
633634

634-
def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
635-
636635
class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
637636
ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> {
638637
dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr);
@@ -650,6 +649,76 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
650649
def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
651650
def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
652651

652+
653+
654+
//===---------------------------------------------------------------------===//
655+
// Global/DS load-transpose intrinsics (available in GFX1250+)
656+
657+
// Pairs an element type with its width in bits so that code taking a
// WrapperType can read both `type` and `bitwidth` from one record.
class WrapperType<Type t, int w> {
658+
Type type = t;
659+
int bitwidth = w;
660+
}
661+
// Wraps an `I` (integer) type, taking the width from the type's own `bitwidth` field.
class IType<I t> : WrapperType<t, t.bitwidth> {}
662+
// Wraps an `F` (float) type, taking the width from the type's own `bitwidth` field.
class FType<F t> : WrapperType<t, t.bitwidth> {}
663+
// BF16 is wrapped directly with an explicit width of 16 rather than through FType.
def BF16Type : WrapperType<BF16, 16> {}
664+
665+
666+
// Describes the memory kind an op reads from:
//   n - string used as the leading component of the op mnemonic;
//   s - LLVM address space number of the pointer operand.
class AddrKind<string n, int s> {
667+
string name = n;
668+
int space = s;
669+
// Pointer type constrained to address space `s`.
LLVM_PointerInAddressSpace type = LLVM_PointerInAddressSpace<s>;
670+
}
671+
// Address space 1 (the same space as ROCDLGlobalBuffer in this file).
def GlobalAddrKind : AddrKind<"global", 1>;
672+
// Address space 3 (the same space as ROCDLBufferLDS in this file).
def DSAddrKind : AddrKind<"ds", 3>;
673+
674+
// Compile-time description of one load-transpose op.
//   addKind     - memory kind; supplies the mnemonic prefix and pointer type.
//   inElemBits  - bit width of the source matrix elements (the `trN` part).
//   outElemBits - total bit width of the result vector (the `.bN` part); it is
//                 divided by the element width to obtain the element count.
//   outElemType - element type (and its width) of the result vector.
class ROCDL_TrLoadOpMeta<AddrKind addKind, int inElemBits, int outElemBits, WrapperType outElemType> {
675+
// Decimal string forms of the bit counts, used in the mnemonic and op description.
string inBits = !cast<string>(inElemBits);
676+
string outBits = !cast<string>(outElemBits);
677+
LLVM_PointerInAddressSpace inType = addKind.type;
678+
int outNumElem = !div(outElemBits, outElemType.bitwidth);
679+
ROCDL_ConcreteVector outType = ROCDL_ConcreteVector<outElemType.type, outNumElem>;
680+
// For address space 1 (global) the 8-bit and 16-bit variants carry no width
// after "tr" (e.g. "global.load.tr.b64"); every other combination spells the
// width out (e.g. "global.load.tr4.b64", "ds.load.tr8.b64").
string inBitsEnc = !if(!and(!eq(addKind.space, 1),
681+
!or(!eq(inElemBits, 8), !eq(inElemBits, 16))),
682+
"", inBits);
683+
string mnemonic = addKind.name # ".load.tr" # inBitsEnc # ".b" # outBits;
684+
}
685+
686+
// Op class shared by all load-transpose ops. The mnemonic, pointer operand
// type and result vector type all come from `meta`.
// NOTE(review): `baseArgs` is not defined in this class; presumably it is
// inherited from ROCDL_IntrOp - confirm.
class ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta meta> :
687+
ROCDL_IntrOp<meta.mnemonic, [1], [], [], 1, 0, 1> {
688+
689+
// Single pointer operand, marked as being read from.
dag args = (ins Arg<meta.inType, "", [MemRead]>:$ptr);
690+
let arguments = !con(args, baseArgs);
691+
let results = (outs meta.outType:$res);
692+
let summary = "Loads and transposes a matrix from global memory or ds to registers (available in gfx1250+).";
693+
let description = [{
694+
Load a matrix of }] # meta.inBits # [{-bit data from global memory or LDS
(as selected by the address space of the pointer operand),
695+
transpose data between row-major and column-major order,
696+
and store the result into a }] # meta.outBits # [{-bit vector register.
697+
698+
Available in gfx1250+.
699+
}];
700+
let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
701+
let extraClassDefinition = [{
702+
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
703+
return {getPtr()};
704+
}
705+
}];
706+
}
707+
708+
// Global (address space 1) load-transpose ops. The template arguments are:
// address kind, source element bits, total result bits, result element type.
def ROCDL_GlobalLoadTr4_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 4, 64, IType<I32>>>;
709+
def ROCDL_GlobalLoadTr8_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 64, IType<I32>>>;
710+
def ROCDL_GlobalLoadTr6_3I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96, IType<I32>>>;
711+
// NOTE(review): the name says "Tr8" but inElemBits is 16; the parallel DS def
// below is named DsLoadTr16_8I16 - confirm the intended name.
def ROCDL_GlobalLoadTr8_8I16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128, IType<I16>>>;
712+
// NOTE(review): the mnemonic does not depend on the result element type, so
// the disabled F16/BF16 variants would presumably collide with the I16 op's
// mnemonic if enabled - confirm before uncommenting.
//def ROCDL_GlobalLoadTr8_8F16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 128, FType<F16>>>;
713+
//def ROCDL_GlobalLoadTr8_8BF16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 128, BF16Type>>;
714+
715+
// DS (address space 3) load-transpose ops; same argument order as above.
def ROCDL_DsLoadTr4_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 4, 64, IType<I32>>>;
716+
def ROCDL_DsLoadTr8_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 8, 64, IType<I32>>>;
717+
def ROCDL_DsLoadTr6_3I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 6, 96, IType<I32>>>;
718+
def ROCDL_DsLoadTr16_8I16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128, IType<I16>>>;
719+
//def ROCDL_DsLoadTr16_8F16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128, FType<F16>>>;
720+
//def ROCDL_DsLoadTr16_8BF16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128, BF16Type>>;
721+
653722
//===---------------------------------------------------------------------===//
654723
// Load to LDS intrinsic (available in GFX9 and GFX10)
655724
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,30 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
650650
llvm.return %r3 : vector<4xf16>
651651
}
652652

653+
llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
654+
// Round-trip check for the global/DS load-transpose ops, one op per mnemonic.
// CHECK-LABEL: @rocdl.load.tr.ops
655+
// CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_PTR:.+]]: !llvm.ptr<3>)
656+
// CHECK: %0 = rocdl.global.load.tr4.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
657+
// CHECK: %1 = rocdl.global.load.tr.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
658+
// CHECK: %2 = rocdl.global.load.tr6.b96 %[[GL_PTR]] : <1> -> vector<3xi32>
659+
// CHECK: %3 = rocdl.global.load.tr.b128 %[[GL_PTR]] : <1> -> vector<8xi16>
660+
// CHECK: %4 = rocdl.ds.load.tr4.b64 %[[DS_PTR]] : <3> -> vector<2xi32>
661+
// CHECK: %5 = rocdl.ds.load.tr8.b64 %[[DS_PTR]] : <3> -> vector<2xi32>
662+
// CHECK: %6 = rocdl.ds.load.tr6.b96 %[[DS_PTR]] : <3> -> vector<3xi32>
663+
// CHECK: %7 = rocdl.ds.load.tr16.b128 %[[DS_PTR]] : <3> -> vector<8xi16>
664+
665+
rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
666+
rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
667+
rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
668+
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
669+
670+
rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
671+
rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
672+
rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
673+
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
674+
llvm.return
675+
}
676+
653677
llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
654678
// CHECK-LABEL: @rocdl.load.to.lds
655679
//CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7>

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,30 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
10281028
llvm.return %r3 : vector<4xf16>
10291029
}
10301030

1031+
llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
1032+
// Checks that each load-transpose op is translated to a call to the
// llvm.amdgcn.* intrinsic with the matching name, taking the pointer
// argument in the same address space.
// CHECK-LABEL: rocdl.load.tr.ops
1033+
// CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:.+]], ptr addrspace(3) %[[DS_PTR:.+]])
1034+
// CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
1035+
// CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
1036+
// CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]])
1037+
// CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]])
1038+
// CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
1039+
// CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
1040+
// CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]])
1041+
// CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]])
1042+
1043+
rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
1044+
rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
1045+
rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
1046+
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
1047+
1048+
rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
1049+
rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
1050+
rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
1051+
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
1052+
llvm.return
1053+
}
1054+
10311055
llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
10321056
//CHECK: call void @llvm.amdgcn.load.to.lds.p7
10331057
rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7>

0 commit comments

Comments
 (0)