diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 71dac3ad39b7b..0b8c0f7f381c4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -412,6 +412,45 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]
 def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
 def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
 
+//===---------------------------------------------------------------------===//
+// LDS transpose intrinsics (available in GFX950)
+
+// Pointer constraints used by the ops below: address space 1 for the
+// global-memory operand and address space 3 for the LDS operand.
+def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
+def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
+
+// Base class for the `ds.read.tr*` ops: a single LDS pointer operand and a
+// single result. The intrinsic is overloaded on result 0, so the emitted
+// call carries the result vector type as a suffix (e.g. `.v2i32`).
+// NOTE(review): the `MemRead` decorator is an assumption -- confirm.
+class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
+  ROCDL_IntrOp<mnemonic, [0], [], [], 1>,
+  Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
+  let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
+}
+
+def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">;
+def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
+def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
+def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+
+//===---------------------------------------------------------------------===//
+// Global load to LDS intrinsic (available in GFX950)
+
+// Operands: global pointer, LDS pointer, then i32 size, offset and aux.
+// The op has no results.
+// NOTE(review): the `MemRead`/`MemWrite` decorators are assumptions -- confirm.
+def ROCDL_GlobalLoadLDSOp :
+  ROCDL_IntrOp<"global.load.lds", [], [], [], 0>,
+  Arguments<(ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr,
+                 Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
+                 I32:$size,
+                 I32:$offset,
+                 I32:$aux)> {
+  let assemblyFormat = "operands attr-dict";
+}
+
 //===---------------------------------------------------------------------===//
 // Operations on raw buffer resources (stride of 0, bounds checks either off or in
 // raw buffer mode).
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 92789246edb4f..c80ebebaafe3a 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -227,6 +227,35 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
   llvm.return
 }
 
+// Checks the printed form of each LDS transpose read op and result type.
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+  // CHECK-LABEL: rocdl.ds.read.tr
+  // CHECK: rocdl.ds.read.tr4.b64 {{.*}} : <3> -> vector<2xi32>
+  %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+  // CHECK: rocdl.ds.read.tr6.b96 {{.*}} : <3> -> vector<3xi32>
+  %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+  // CHECK: rocdl.ds.read.tr8.b64 {{.*}} : <3> -> vector<2xi32>
+  %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+  // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xf16>
+  %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+  // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xbf16>
+  %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+  llvm.return %r3 : vector<4xf16>
+}
+
+// Checks that the global-to-LDS load op prints with its five operands.
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+  // CHECK-LABEL: rocdl.global.load.lds
+  %aux = llvm.mlir.constant(0 : i32) : i32
+  %offset = llvm.mlir.constant(0 : i32) : i32
+  %size = llvm.mlir.constant(10 : i32) : i32
+
+  // CHECK: rocdl.global.load.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+  rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+
+  llvm.return
+}
+
 llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
                                   %stride : i16,
                                   %numRecords : i32,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 0620c23b5fdad..996e0e34c790c 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -424,6 +424,33 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
   llvm.return %r0 : vector<8xf32>
 }
 
+// Checks the LLVM IR call emitted for each LDS transpose read op and type.
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+  // CHECK-LABEL: rocdl.ds.read.tr
+  // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr4.b64.v2i32(ptr addrspace(3) %0)
+  %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+  // CHECK: call <3 x i32> @llvm.amdgcn.ds.read.tr6.b96.v3i32(ptr addrspace(3) %0)
+  %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr8.b64.v2i32(ptr addrspace(3) %0)
+  %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+  // CHECK: call <4 x half> @llvm.amdgcn.ds.read.tr16.b64.v4f16(ptr addrspace(3) %0)
+  %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+  // CHECK: call <4 x bfloat> @llvm.amdgcn.ds.read.tr16.b64.v4bf16(ptr addrspace(3) %0)
+  %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+  llvm.return %r3 : vector<4xf16>
+}
+
+// Checks that the global-to-LDS load op becomes a call to the intrinsic.
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+  // CHECK-LABEL: rocdl.global.load.lds
+  %aux = llvm.mlir.constant(0 : i32) : i32
+  %offset = llvm.mlir.constant(0 : i32) : i32
+  %size = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: call void @llvm.amdgcn.global.load.lds
+  rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+  llvm.return
+}
+
 llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
                                   %stride : i16,
                                   %numRecords : i32,