diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index d2df244eb9363..5241f9a6f2b43 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp : + FixedVectorOfLengthAndType<[length], [elem]>, + BuildableType< + "::mlir::VectorType::get({" # length # "} ," + # elem.builderCall # ")">; + +def ROCDL_V2I16Type : ROCDL_ConcreteVector; +def ROCDL_V2F16Type : ROCDL_ConcreteVector; +def ROCDL_V2I32Type : ROCDL_ConcreteVector; +def ROCDL_V2BF16Type : ROCDL_ConcreteVector; +def ROCDL_V2F32Type : ROCDL_ConcreteVector; +def ROCDL_V3I32Type : ROCDL_ConcreteVector; +def ROCDL_V4I32Type : ROCDL_ConcreteVector; +def ROCDL_V6I32Type : ROCDL_ConcreteVector; +def ROCDL_V8I32Type : ROCDL_ConcreteVector; +def ROCDL_V8BF16Type : ROCDL_ConcreteVector; +def ROCDL_V8F16Type : ROCDL_ConcreteVector; +def ROCDL_V8F32Type : ROCDL_ConcreteVector; +def ROCDL_V16BF16Type : ROCDL_ConcreteVector; +def ROCDL_V16F16Type : ROCDL_ConcreteVector; +def ROCDL_V16F32Type : ROCDL_ConcreteVector; +def ROCDL_V32F16Type : ROCDL_ConcreteVector; +def ROCDL_V32BF16Type : ROCDL_ConcreteVector; +def ROCDL_V32F32Type : ROCDL_ConcreteVector; + //===----------------------------------------------------------------------===// // Wave-level primitives //===----------------------------------------------------------------------===// @@ -663,6 +692,68 @@ def ROCDL_GlobalLoadLDSOp : }]; } +//===---------------------------------------------------------------------===// +// Tensor load/store intrinsics (available in GFX1250) +//===---------------------------------------------------------------------===// + +// Base class for tensor load/store operations with 4 descriptor groups. +class ROCDL_TensorLDSIntrOp : + ROCDL_IntrOp { + dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1, + ROCDL_V4I32Type:$dgroup2, ROCDL_V4I32Type:$dgroup3, + I32Attr:$cachePolicy); + let arguments = !con(args, baseArgs); + let summary = "Base class for ROCDL tensor load/store to/from LDS."; + let description = [{ + Moves tiles of tensor data between global memory and LDS. The tile is + described by the $dgroup descriptors. 4 $dgroup descriptors allows for + movement of up to 5D tensors. $cachePolicy describes the memory scope and an + indicator of expected data re-use. + + This op is for gfx1250+ architectures. + }]; + let assemblyFormat = [{ + attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1) + }]; + let extraClassDefinition = [{ + SmallVector $cppClass::getAccessedOperands() { + return {getDgroup0(), getDgroup1(), getDgroup2(), getDgroup3()}; + } + }]; +} + +// Base class for tensor load/store operations with 2 descriptor groups +// (D2 variant). +class ROCDL_TensorLDSIntrD2Op : + ROCDL_IntrOp { + dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1, + I32Attr:$cachePolicy); + let arguments = !con(args, baseArgs); + let summary = "Base class for ROCDL tensor load/store to/from LDS (D2 variant)."; + let description = [{ + Moves tiles of tensor data between global memory and LDS. The tile is + described by the $dgroup descriptors. 2 $dgroup descriptors allows for + movement of up to 2D tensors. $cachePolicy describes the memory scope and an + indicator of expected data re-use. + + This op is for gfx1250+ architectures. + }]; + let assemblyFormat = [{ + attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1) + }]; + let extraClassDefinition = [{ + SmallVector $cppClass::getAccessedOperands() { + return {getDgroup0(), getDgroup1()}; + } + }]; +} + +// Tensor load and store operations +def ROCDL_TensorLoadToLDSOp : ROCDL_TensorLDSIntrOp<"tensor.load.to.lds">; +def ROCDL_TensorStoreFromLDSOp : ROCDL_TensorLDSIntrOp<"tensor.store.from.lds">; +def ROCDL_TensorLoadToLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.load.to.lds.d2">; +def ROCDL_TensorStoreFromLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.store.from.lds.d2">; + //===---------------------------------------------------------------------===// // Operations on raw buffer resources (stride of 0, bounds checks either off or in // raw buffer mode). @@ -932,30 +1023,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [], }]; } -class ROCDL_ConcreteVector : - FixedVectorOfLengthAndType<[length], [elem]>, - BuildableType< - "::mlir::VectorType::get({" # length # "} ," - # elem.builderCall # ")">; - -def ROCDL_V2I16Type : ROCDL_ConcreteVector; -def ROCDL_V2F16Type : ROCDL_ConcreteVector; -def ROCDL_V2I32Type : ROCDL_ConcreteVector; -def ROCDL_V2BF16Type : ROCDL_ConcreteVector; -def ROCDL_V2F32Type : ROCDL_ConcreteVector; -def ROCDL_V3I32Type : ROCDL_ConcreteVector; -def ROCDL_V6I32Type : ROCDL_ConcreteVector; -def ROCDL_V8I32Type : ROCDL_ConcreteVector; -def ROCDL_V8BF16Type : ROCDL_ConcreteVector; -def ROCDL_V8F16Type : ROCDL_ConcreteVector; -def ROCDL_V8F32Type : ROCDL_ConcreteVector; -def ROCDL_V16BF16Type : ROCDL_ConcreteVector; -def ROCDL_V16F16Type : ROCDL_ConcreteVector; -def ROCDL_V16F32Type : ROCDL_ConcreteVector; -def ROCDL_V32F16Type : ROCDL_ConcreteVector; -def ROCDL_V32BF16Type : ROCDL_ConcreteVector; -def ROCDL_V32F32Type : ROCDL_ConcreteVector; - //===---------------------------------------------------------------------===// // 16-bit float intrinsics //===---------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index d270ee8b089aa..e703600c71c8e 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -664,6 +664,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL @rocdl.tensor.load.to.lds +llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, + %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { + // CHECK: rocdl.tensor.load.to.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32> + rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL @rocdl.tensor.store.from.lds +llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, + %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { + // CHECK: rocdl.tensor.store.from.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32> + rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL @rocdl.tensor.load.to.lds.d2 +llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) { + // CHECK: rocdl.tensor.load.to.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32> + rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL @rocdl.tensor.store.from.lds.d2 +llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) { + // CHECK: rocdl.tensor.store.from.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32> + rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, %stride : i16, %numRecords : i64, diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 30126f6bff05a..8a848221a50dd 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1040,6 +1040,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL: rocdl.tensor.load.to.lds +llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, + %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { + // CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0) + rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL: rocdl.tensor.store.from.lds +llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, + %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { + // CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0) + rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL: rocdl.tensor.load.to.lds.d2 +llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) { + // CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0) + rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + +// CHECK-LABEL: rocdl.tensor.store.from.lds.d2 +llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) { + // CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0) + rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32> + llvm.return +} + llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, %stride : i16, %numRecords : i64,