From 68987020e86940a2114e58e2a0682d5fbd034a9e Mon Sep 17 00:00:00 2001 From: Ognjen Plavsic Date: Fri, 17 Jan 2025 15:44:32 +0000 Subject: [PATCH 1/3] Add LDS transpose intrinsics --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 16 ++++++++++++++++ mlir/test/Dialect/LLVMIR/rocdl.mlir | 15 +++++++++++++++ mlir/test/Target/LLVMIR/rocdl.mlir | 15 +++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 71dac3ad39b7b..94be2f1c01dc9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -412,6 +412,22 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1] def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>; def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>; +//===---------------------------------------------------------------------===// +// LDS transpose intrinsics + +def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>; + +class ROCDL_Ds_Read_Tr_IntrOp : + ROCDL_IntrOp, + Arguments<(ins Arg:$ptr)>{ + let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)"; + } + +def ROCDL_ds_read_tr4_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr4.b64">; +def ROCDL_ds_read_tr8_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr8.b64">; +def ROCDL_ds_read_tr6_b96 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr6.b96">; +def ROCDL_ds_read_tr16_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr16.b64">; + //===---------------------------------------------------------------------===// // Operations on raw buffer resources (stride of 0, bounds checks either off or in // raw buffer mode). 
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 92789246edb4f..6676219203ddf 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -227,6 +227,21 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32, llvm.return } +llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { + // CHECK-LABEL: rocdl.ds.read.tr + // CHECK: rocdl.ds.read.tr4.b64 {{.*}} : <3> -> vector<2xi32> + %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32> + // CHECK: rocdl.ds.read.tr6.b96 {{.*}} : <3> -> vector<3xi32> + %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32> + // CHECK: rocdl.ds.read.tr8.b64 {{.*}} : <3> -> vector<2xi32> + %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32> + // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xf16> + %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16> + // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xbf16> + %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16> + llvm.return %r3 : vector<4xf16> +} + llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, %stride : i16, %numRecords : i32, diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 0620c23b5fdad..21c4f8ab55469 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -424,6 +424,21 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v llvm.return %r0 : vector<8xf32> } +llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { + // CHECK-LABEL: rocdl.ds.read.tr + // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr4.b64.v2i32(ptr addrspace(3) %0) + %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32> + // CHECK: call <3 x i32> @llvm.amdgcn.ds.read.tr6.b96.v3i32(ptr addrspace(3) %0) + %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32> + // CHECK: call <2 x i32> 
@llvm.amdgcn.ds.read.tr8.b64.v2i32(ptr addrspace(3) %0) + %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32> + // CHECK: call <4 x half> @llvm.amdgcn.ds.read.tr16.b64.v4f16(ptr addrspace(3) %0) + %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16> + // CHECK: call <4 x bfloat> @llvm.amdgcn.ds.read.tr16.b64.v4bf16(ptr addrspace(3) %0) + %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16> + llvm.return %r3 : vector<4xf16> +} + llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, %stride : i16, %numRecords : i32, From 9dbdf1e35cb02228d96d5dab4f72fb29584ffc23 Mon Sep 17 00:00:00 2001 From: Ognjen Plavsic Date: Fri, 17 Jan 2025 18:47:05 +0000 Subject: [PATCH 2/3] Add global to LDS intrinsic --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 24 ++++++++++++++++---- mlir/test/Dialect/LLVMIR/rocdl.mlir | 11 +++++++++ mlir/test/Target/LLVMIR/rocdl.mlir | 9 ++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 94be2f1c01dc9..0e23c183b4d92 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -415,18 +415,32 @@ def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", //===---------------------------------------------------------------------===// // LDS transpose intrinsics +def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>; def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>; -class ROCDL_Ds_Read_Tr_IntrOp : +class ROCDL_LDS_Read_Tr_IntrOp : ROCDL_IntrOp, Arguments<(ins Arg:$ptr)>{ let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)"; } -def ROCDL_ds_read_tr4_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr4.b64">; -def ROCDL_ds_read_tr8_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr8.b64">; -def ROCDL_ds_read_tr6_b96 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr6.b96">; -def ROCDL_ds_read_tr16_b64 : 
ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr16.b64">; +def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">; +def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">; +def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">; +def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">; + +//===---------------------------------------------------------------------===// +// Global load to LDS intrinsic + +def ROCDL_GlobalLoadLDSOp : + ROCDL_IntrOp<"global.load.lds", [], [], [], 0>, + Arguments<(ins Arg:$globalPtr, + Arg:$ldsPtr, + I32:$size, + I32:$offset, + I32:$aux)> { + let assemblyFormat = "operands attr-dict"; +} //===---------------------------------------------------------------------===// // Operations on raw buffer resources (stride of 0, bounds checks either off or in diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 6676219203ddf..c80ebebaafe3a 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -242,6 +242,17 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { llvm.return %r3 : vector<4xf16> } +llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { + %aux = llvm.mlir.constant(0 : i32) : i32 + %offset = llvm.mlir.constant(0 : i32) : i32 + %size = llvm.mlir.constant(10 : i32) : i32 + + //CHECK: rocdl.global.load.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} + rocdl.global.load.lds %src, %dst, %size, %offset, %aux + + llvm.return +} + llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, %stride : i16, %numRecords : i32, diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 21c4f8ab55469..996e0e34c790c 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -439,6 +439,15 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { llvm.return %r3 : vector<4xf16> } +llvm.func 
@rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { + %aux = llvm.mlir.constant(0 : i32) : i32 + %offset = llvm.mlir.constant(0 : i32) : i32 + %size = llvm.mlir.constant(10 : i32) : i32 + //CHECK: call void @llvm.amdgcn.global.load.lds + rocdl.global.load.lds %src, %dst, %size, %offset, %aux + llvm.return +} + llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, %stride : i16, %numRecords : i32, From 05471984c75ae6186bffe43790c436d7d563b20d Mon Sep 17 00:00:00 2001 From: Ognjen Plavsic Date: Mon, 20 Jan 2025 10:50:31 +0000 Subject: [PATCH 3/3] Address review comments --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 0e23c183b4d92..0b8c0f7f381c4 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -413,7 +413,7 @@ def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>; //===---------------------------------------------------------------------===// -// LDS transpose intrinsics +// LDS transpose intrinsics (available in GFX950) def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>; def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>; @@ -422,7 +422,7 @@ class ROCDL_LDS_Read_Tr_IntrOp : ROCDL_IntrOp, Arguments<(ins Arg:$ptr)>{ let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)"; - } +} def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">; def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">; @@ -430,7 +430,7 @@ def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">; def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">; //===---------------------------------------------------------------------===// -// Global load to LDS intrinsic 
+// Global load to LDS intrinsic (available in GFX950) def ROCDL_GlobalLoadLDSOp : ROCDL_IntrOp<"global.load.lds", [], [], [], 0>,