Skip to content

Commit 8a83700

Browse files
[ROCDL] Added matrix load-transpose ops for gfx1250+ (llvm#165564)
This patch adds load-transpose instructions for gfx1250+ arch to ROCDL. Note, this is work in progress but I'd like to share the ideas here and hope to get some comments. Co-authored-by: Krzysztof Drewniak <[email protected]>
1 parent 5a1acf7 commit 8a83700

File tree

3 files changed

+119
-2
lines changed

3 files changed

+119
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
321321
let assemblyFormat = "attr-dict";
322322
}
323323

324+
def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
324325
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
325326

326327
def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>,
@@ -631,8 +632,6 @@ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]
631632
//===---------------------------------------------------------------------===//
632633
// LDS transpose intrinsics (available in GFX950)
633634

634-
def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
635-
636635
class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
637636
ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> {
638637
dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr);
@@ -650,6 +649,58 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
650649
def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
651650
def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
652651

652+
653+
654+
//===---------------------------------------------------------------------===//
655+
// Glb/DS load-transpose intrinsics (available in GFX1250+)
656+
657+
class AddrKind<string n, int s> {
658+
string name = n;
659+
int space = s;
660+
}
661+
def GlobalAddrKind : AddrKind<"global", 1>;
662+
def DSAddrKind : AddrKind<"ds", 3>;
663+
664+
class ROCDL_TrLoadOpMeta<AddrKind kind, int inElemBits, int outElemBits> {
665+
AddrKind addrKind = kind;
666+
string inBits = !cast<string>(inElemBits);
667+
string outBits = !cast<string>(outElemBits);
668+
string inBitsEnc = !if(!eq(addrKind.space, 1),
669+
!if(!or(!eq(inElemBits, 8), !eq(inElemBits, 16)), "", inBits), inBits);
670+
string mnemonic = addrKind.name # ".load.tr" # inBitsEnc # ".b" # outBits;
671+
}
672+
673+
class ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta meta> :
674+
ROCDL_IntrOp<meta.mnemonic, [1], [], [], 1, 0, 1> {
675+
676+
dag args = (ins Arg<LLVM_PointerInAddressSpace<meta.addrKind.space>, "", [MemRead]>:$ptr);
677+
let arguments = !con(args, baseArgs);
678+
let summary = "Loads and transposes a matrix from " # meta.addrKind.name # " memory to registers (available in gfx1250+).";
679+
let description = [{
680+
Load a matrix of }] # meta.inBits # [{-bit data from the }] # meta.addrKind.name # [{ memory,
681+
transpose data between row-major and column-major order,
682+
and store the result into a }] # meta.outBits # [{-bit vector register.
683+
684+
Available in gfx1250+.
685+
}];
686+
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)";
687+
let extraClassDefinition = [{
688+
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
689+
return {getPtr()};
690+
}
691+
}];
692+
}
693+
694+
def ROCDL_GlobalLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 4, 64>>;
695+
def ROCDL_GlobalLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 64>>;
696+
def ROCDL_GlobalLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96>>;
697+
def ROCDL_GlobalLoadTr8_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128>>;
698+
699+
def ROCDL_DsLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 4, 64>>;
700+
def ROCDL_DsLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 8, 64>>;
701+
def ROCDL_DsLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 6, 96>>;
702+
def ROCDL_DsLoadTr16_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128>>;
703+
653704
//===---------------------------------------------------------------------===//
654705
// Load to LDS intrinsic (available in GFX9 and GFX10)
655706
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
650650
llvm.return %r3 : vector<4xf16>
651651
}
652652

653+
llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
654+
// CHECK-LABEL: @rocdl.load.tr.ops
655+
// CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_OTR:.+]]: !llvm.ptr<3>)
656+
// CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32>
657+
// CHECK: rocdl.global.load.tr.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32>
658+
// CHECK: rocdl.global.load.tr6.b96 %[[GL_PTR]] : !llvm.ptr<1> -> vector<3xi32>
659+
// CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xi16>
660+
// CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xf16>
661+
// CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xbf16>
662+
// CHECK: rocdl.ds.load.tr4.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32>
663+
// CHECK: rocdl.ds.load.tr8.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32>
664+
// CHECK: rocdl.ds.load.tr6.b96 %[[DS_OTR]] : !llvm.ptr<3> -> vector<3xi32>
665+
// CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xi16>
666+
// CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xf16>
667+
// CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xbf16>
668+
// CHECK: llvm.return
669+
670+
rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
671+
rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
672+
rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32>
673+
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16>
674+
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16>
675+
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16>
676+
677+
rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
678+
rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
679+
rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32>
680+
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16>
681+
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16>
682+
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16>
683+
llvm.return
684+
}
685+
653686
llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
654687
// CHECK-LABEL @rocdl.load.to.lds
655688
//CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7>

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
10281028
llvm.return %r3 : vector<4xf16>
10291029
}
10301030

1031+
llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
1032+
// CHECK-LABEL: rocdl.load.tr.ops
1033+
// CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:.+]], ptr addrspace(3) %[[DS_PTR:.+]])
1034+
// CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
1035+
// CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
1036+
// CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]])
1037+
// CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]])
1038+
// CHECK: call <8 x half> @llvm.amdgcn.global.load.tr.b128.v8f16(ptr addrspace(1) %[[GL_PTR]])
1039+
// CHECK: call <8 x bfloat> @llvm.amdgcn.global.load.tr.b128.v8bf16(ptr addrspace(1) %[[GL_PTR]])
1040+
1041+
// CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
1042+
// CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
1043+
// CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]])
1044+
// CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]])
1045+
// CHECK: call <8 x half> @llvm.amdgcn.ds.load.tr16.b128.v8f16(ptr addrspace(3) %[[DS_PTR]])
1046+
// CHECK: call <8 x bfloat> @llvm.amdgcn.ds.load.tr16.b128.v8bf16(ptr addrspace(3) %[[DS_PTR]])
1047+
1048+
rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
1049+
rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
1050+
rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32>
1051+
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16>
1052+
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16>
1053+
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16>
1054+
1055+
rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
1056+
rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
1057+
rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32>
1058+
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16>
1059+
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16>
1060+
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16>
1061+
llvm.return
1062+
}
1063+
10311064
llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
10321065
//CHECK: call void @llvm.amdgcn.load.to.lds.p7
10331066
rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7>

0 commit comments

Comments
 (0)