diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 2f87975ebaa04..a18c18af8a753 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -2116,6 +2116,56 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// acc.kernel_environment +//===----------------------------------------------------------------------===// + +def OpenACC_KernelEnvironmentOp : OpenACC_Op<"kernel_environment", + [AttrSizedOperandSegments, RecursiveMemoryEffects, SingleBlock, + NoTerminator, + MemoryEffects<[MemWrite, + MemRead]>]> { + let summary = "Decomposition of compute constructs to capture data mapping " + "and asynchronous behavior information"; + let description = [{ + The `acc.kernel_environment` operation represents a decomposition of + any OpenACC compute construct (acc.kernels, acc.parallel, or + acc.serial) that captures data mapping and asynchronous behavior: + - data clause operands + - async clause operands + - wait clause operands + + This allows kernel execution parallelism and privatization to be + handled separately, facilitating eventual lowering to GPU dialect where + kernel launching and compute offloading are handled separately. + }]; + + let arguments = (ins + Variadic:$dataClauseOperands, + Variadic:$asyncOperands, + OptionalAttr:$asyncOperandsDeviceType, + OptionalAttr:$asyncOnly, + Variadic:$waitOperands, + OptionalAttr:$waitOperandsSegments, + OptionalAttr:$waitOperandsDeviceType, + OptionalAttr:$hasWaitDevnum, + OptionalAttr:$waitOnly); + + let regions = (region SizedRegion<1>:$region); + + let assemblyFormat = [{ + oilist( + `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` + | `async` `` custom($asyncOperands, + type($asyncOperands), $asyncOperandsDeviceType, $asyncOnly) + | `wait` `` custom($waitOperands, type($waitOperands), + $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum, + $waitOnly) + ) + $region attr-dict + }]; +} + //===----------------------------------------------------------------------===// // 2.6.5 data Construct //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 77d18da49276a..042ee2503cb95 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -2243,3 +2243,76 @@ func.func @test_firstprivate_map(%arg0: memref<10xf32>) { // CHECK-NEXT: acc.yield // CHECK-NEXT: } // CHECK-NEXT: return + +// ----- + +func.func @test_kernel_environment(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + // Create data clause operands for the kernel environment + %copyin = acc.copyin varPtr(%arg0 : memref<1024xf32>) -> memref<1024xf32> + %create = acc.create varPtr(%arg1 : memref<1024xf32>) -> memref<1024xf32> + + // Kernel environment wraps gpu.launch and captures data mapping + acc.kernel_environment dataOperands(%copyin, %create : memref<1024xf32>, memref<1024xf32>) { + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c1024, %block_y = %c1, %block_z = %c1) { + // Kernel body uses the mapped data + %val = memref.load %copyin[%tx] : memref<1024xf32> + %result = arith.mulf %val, %val : f32 + memref.store %result, %create[%tx] : memref<1024xf32> + gpu.terminator + } + } + + // Copy results back to host and deallocate device memory + acc.copyout accPtr(%create : memref<1024xf32>) to varPtr(%arg1 : memref<1024xf32>) + acc.delete accPtr(%copyin : memref<1024xf32>) + + return +} + +// CHECK-LABEL: func @test_kernel_environment +// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%{{.*}} : memref<1024xf32>) -> memref<1024xf32> +// CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}} : memref<1024xf32>) -> memref<1024xf32> +// CHECK: acc.kernel_environment dataOperands(%[[COPYIN]], %[[CREATE]] : memref<1024xf32>, memref<1024xf32>) { +// CHECK: gpu.launch +// CHECK: memref.load %[[COPYIN]] +// CHECK: memref.store %{{.*}}, %[[CREATE]] +// CHECK: } +// CHECK: } +// CHECK: acc.copyout accPtr(%[[CREATE]] : memref<1024xf32>) to varPtr(%{{.*}} : memref<1024xf32>) +// CHECK: acc.delete accPtr(%[[COPYIN]] : memref<1024xf32>) + +// ----- + +func.func @test_kernel_environment_with_async(%arg0: memref<1024xf32>) { + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %async_val = arith.constant 1 : i32 + + %create = acc.create varPtr(%arg0 : memref<1024xf32>) async(%async_val : i32) -> memref<1024xf32> + + // Kernel environment with async clause + acc.kernel_environment dataOperands(%create : memref<1024xf32>) async(%async_val : i32) { + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c1024, %block_y = %c1, %block_z = %c1) { + %f0 = arith.constant 0.0 : f32 + memref.store %f0, %create[%tx] : memref<1024xf32> + gpu.terminator + } + } + + acc.copyout accPtr(%create : memref<1024xf32>) async(%async_val : i32) to varPtr(%arg0 : memref<1024xf32>) + + return +} + +// CHECK-LABEL: func @test_kernel_environment_with_async +// CHECK: %[[ASYNC:.*]] = arith.constant 1 : i32 +// CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}} : memref<1024xf32>) async(%[[ASYNC]] : i32) -> memref<1024xf32> +// CHECK: acc.kernel_environment dataOperands(%[[CREATE]] : memref<1024xf32>) async(%[[ASYNC]] : i32) +// CHECK: gpu.launch +// CHECK: memref.store %{{.*}}, %[[CREATE]] +// CHECK: acc.copyout accPtr(%[[CREATE]] : memref<1024xf32>) async(%[[ASYNC]] : i32) to varPtr(%{{.*}} : memref<1024xf32>)