diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 921fdf36a59b0..3d8bf9c169406 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -321,6 +321,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> { let assemblyFormat = "attr-dict"; } +def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>; def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>; def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>, @@ -631,8 +632,6 @@ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1] //===---------------------------------------------------------------------===// // LDS transpose intrinsics (available in GFX950) -def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>; - class ROCDL_LDS_Read_Tr_IntrOp : ROCDL_IntrOp { dag args = (ins Arg:$ptr); @@ -650,6 +649,58 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">; def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">; def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">; + + +//===---------------------------------------------------------------------===// +// Glb/DS load-transpose intrinsics (available in GFX1250+) + +class AddrKind { + string name = n; + int space = s; +} +def GlobalAddrKind : AddrKind<"global", 1>; +def DSAddrKind : AddrKind<"ds", 3>; + +class ROCDL_TrLoadOpMeta { + AddrKind addrKind = kind; + string inBits = !cast(inElemBits); + string outBits = !cast(outElemBits); + string inBitsEnc = !if(!eq(addrKind.space, 1), + !if(!or(!eq(inElemBits, 8), !eq(inElemBits, 16)), "", inBits), inBits); + string mnemonic = addrKind.name # ".load.tr" # inBitsEnc # ".b" # outBits; +} + +class ROCDL_TrLoadOp : + ROCDL_IntrOp { + + dag args = (ins Arg, "", [MemRead]>:$ptr); + let arguments = !con(args, baseArgs); + let summary = "Loads and transposes a matrix from " # meta.addrKind.name # " memory to registers (available in gfx1250+)."; + let description = [{ + Load a matrix of }] # meta.inBits # [{-bit data from the }] # meta.addrKind.name # [{ memory, + transpose data between row-major and column-major order, + and store the result into a }] # meta.outBits # [{-bit vector register. + + Available in gfx1250+. + }]; + let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)"; + let extraClassDefinition = [{ + ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { + return {getPtr()}; + } + }]; +} + +def ROCDL_GlobalLoadTr4_B64 : ROCDL_TrLoadOp>; +def ROCDL_GlobalLoadTr8_B64 : ROCDL_TrLoadOp>; +def ROCDL_GlobalLoadTr6_B96 : ROCDL_TrLoadOp>; +def ROCDL_GlobalLoadTr8_B128 : ROCDL_TrLoadOp>; + +def ROCDL_DsLoadTr4_B64 : ROCDL_TrLoadOp>; +def ROCDL_DsLoadTr8_B64 : ROCDL_TrLoadOp>; +def ROCDL_DsLoadTr6_B96 : ROCDL_TrLoadOp>; +def ROCDL_DsLoadTr16_B128 : ROCDL_TrLoadOp>; + //===---------------------------------------------------------------------===// // Load to LDS intrinsic (available in GFX9 and GFX10) //===---------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 5e857599b65ea..d50cc41684e3c 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -650,6 +650,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { llvm.return %r3 : vector<4xf16> } +llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) { + // CHECK-LABEL: @rocdl.load.tr.ops + // CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_OTR:.+]]: !llvm.ptr<3>) + // CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32> + // CHECK: rocdl.global.load.tr.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32> + // CHECK: rocdl.global.load.tr6.b96 %[[GL_PTR]] : !llvm.ptr<1> -> vector<3xi32> + // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xi16> + // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xf16> + // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xbf16> + // CHECK: rocdl.ds.load.tr4.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32> + // CHECK: rocdl.ds.load.tr8.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32> + // CHECK: rocdl.ds.load.tr6.b96 %[[DS_OTR]] : !llvm.ptr<3> -> vector<3xi32> + // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xi16> + // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xf16> + // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xbf16> + // CHECK: llvm.return + + rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16> + + rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16> + llvm.return +} + llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) { // CHECK-LABEL @rocdl.load.to.lds //CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7> diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 3fbd9e0567948..db02918d7186c 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1028,6 +1028,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { llvm.return %r3 : vector<4xf16> } +llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) { + // CHECK-LABEL: rocdl.load.tr.ops + // CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:.+]], ptr addrspace(3) %[[DS_PTR:.+]]) + // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <8 x half> @llvm.amdgcn.global.load.tr.b128.v8f16(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <8 x bfloat> @llvm.amdgcn.global.load.tr.b128.v8bf16(ptr addrspace(1) %[[GL_PTR]]) + + // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <8 x half> @llvm.amdgcn.ds.load.tr16.b128.v8f16(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <8 x bfloat> @llvm.amdgcn.ds.load.tr16.b128.v8bf16(ptr addrspace(3) %[[DS_PTR]]) + + rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16> + + rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16> + llvm.return +} + llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) { //CHECK: call void @llvm.amdgcn.load.to.lds.p7 rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7>