Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
let assemblyFormat = "attr-dict";
}

def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;

def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>,
Expand Down Expand Up @@ -631,8 +632,6 @@ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]
//===---------------------------------------------------------------------===//
// LDS transpose intrinsics (available in GFX950)

def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;

class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> {
dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr);
Expand All @@ -650,6 +649,58 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;



//===---------------------------------------------------------------------===//
// Glb/DS load-transpose intrinsics (available in GFX1250+)

class AddrKind<string n, int s> {
string name = n;
int space = s;
}
def GlobalAddrKind : AddrKind<"global", 1>;
def DSAddrKind : AddrKind<"ds", 3>;
Comment on lines +657 to +662
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!


class ROCDL_TrLoadOpMeta<AddrKind kind, int inElemBits, int outElemBits> {
AddrKind addrKind = kind;
string inBits = !cast<string>(inElemBits);
string outBits = !cast<string>(outElemBits);
string inBitsEnc = !if(!eq(addrKind.space, 1),
!if(!or(!eq(inElemBits, 8), !eq(inElemBits, 16)), "", inBits), inBits);
string mnemonic = addrKind.name # ".load.tr" # inBitsEnc # ".b" # outBits;
}

class ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta meta> :
ROCDL_IntrOp<meta.mnemonic, [1], [], [], 1, 0, 1> {

dag args = (ins Arg<LLVM_PointerInAddressSpace<meta.addrKind.space>, "", [MemRead]>:$ptr);
let arguments = !con(args, baseArgs);
let summary = "Loads and transposes a matrix from " # meta.addrKind.name # " memory to registers (available in gfx1250+).";
let description = [{
Load a matrix of }] # meta.inBits # [{-bit data from the }] # meta.addrKind.name # [{ memory,
transpose data between row-major and column-major order,
and store the result into a }] # meta.outBits # [{-bit vector register.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll note these instructions seem generally underdocumented. While this PR may not be the right place for them, can we make a plan for surfacing their exact semantics?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@krzysz00, do you mean in the future, right?


Available in gfx1250+.
}];
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
return {getPtr()};
}
}];
}

def ROCDL_GlobalLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 4, 64>>;
def ROCDL_GlobalLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 64>>;
def ROCDL_GlobalLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96>>;
def ROCDL_GlobalLoadTr8_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128>>;

def ROCDL_DsLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 4, 64>>;
def ROCDL_DsLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 8, 64>>;
def ROCDL_DsLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 6, 96>>;
def ROCDL_DsLoadTr16_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128>>;

//===---------------------------------------------------------------------===//
// Load to LDS intrinsic (available in GFX9 and GFX10)
//===---------------------------------------------------------------------===//
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Dialect/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
llvm.return %r3 : vector<4xf16>
}

llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
// CHECK-LABEL: @rocdl.load.tr.ops
// CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_OTR:.+]]: !llvm.ptr<3>)
// CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32>
// CHECK: rocdl.global.load.tr.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32>
// CHECK: rocdl.global.load.tr6.b96 %[[GL_PTR]] : !llvm.ptr<1> -> vector<3xi32>
// CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xi16>
// CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xf16>
// CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xbf16>
// CHECK: rocdl.ds.load.tr4.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32>
// CHECK: rocdl.ds.load.tr8.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32>
// CHECK: rocdl.ds.load.tr6.b96 %[[DS_OTR]] : !llvm.ptr<3> -> vector<3xi32>
// CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xi16>
// CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xf16>
// CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xbf16>
// CHECK: llvm.return

rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16>

rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16>
llvm.return
}

llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
// CHECK-LABEL @rocdl.load.to.lds
//CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7>
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Target/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
llvm.return %r3 : vector<4xf16>
}

llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
// CHECK-LABEL: rocdl.load.tr.ops
// CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:.+]], ptr addrspace(3) %[[DS_PTR:.+]])
// CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
// CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
// CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]])
// CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]])
// CHECK: call <8 x half> @llvm.amdgcn.global.load.tr.b128.v8f16(ptr addrspace(1) %[[GL_PTR]])
// CHECK: call <8 x bfloat> @llvm.amdgcn.global.load.tr.b128.v8bf16(ptr addrspace(1) %[[GL_PTR]])

// CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
// CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
// CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]])
// CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]])
// CHECK: call <8 x half> @llvm.amdgcn.ds.load.tr16.b128.v8f16(ptr addrspace(3) %[[DS_PTR]])
// CHECK: call <8 x bfloat> @llvm.amdgcn.ds.load.tr16.b128.v8bf16(ptr addrspace(3) %[[DS_PTR]])

rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32>
rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16>
rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16>

rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32>
rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16>
rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16>
llvm.return
}

llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
//CHECK: call void @llvm.amdgcn.load.to.lds.p7
rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7>
Expand Down