From 28305ec0268a8a9919a8e15ed981c32b28d0f45c Mon Sep 17 00:00:00 2001 From: Anchu Rajendran Date: Mon, 9 Sep 2024 17:05:41 -0500 Subject: [PATCH] Adding MLIR Op definition for scan --- .../Dialect/OpenMP/OpenMPClauseOperands.h | 9 +++ .../mlir/Dialect/OpenMP/OpenMPClauses.td | 55 +++++++++++++++++++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 14 +++++ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 9 +++ mlir/test/Dialect/OpenMP/ops.mlir | 8 +++ 5 files changed, 95 insertions(+) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h index 38e4d8f245e4f..af96f24c8fe2c 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h @@ -105,6 +105,13 @@ struct IfClauseOps { Value ifVar; }; +struct InclusiveClauseOps { + llvm::SmallVector inclusiveVars; +}; + +struct ExclusiveClauseOps { + llvm::SmallVector exclusiveVars; +}; struct InReductionClauseOps { llvm::SmallVector inReductionVars; llvm::SmallVector inReductionByref; @@ -261,6 +268,8 @@ using LoopNestOperands = detail::Clauses; using MaskedOperands = detail::Clauses; +using ScanOperands = detail::Clauses; + using OrderedOperands = detail::Clauses; using OrderedRegionOperands = detail::Clauses; diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index e703c323edbc8..79b54545ef533 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -503,6 +503,61 @@ class OpenMP_IsDevicePtrClauseSkip< def OpenMP_IsDevicePtrClause : OpenMP_IsDevicePtrClauseSkip<>; +//===----------------------------------------------------------------------===// +// V5.2: [5.4.7] `inclusive` clause +//===----------------------------------------------------------------------===// + +class OpenMP_InclusiveClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let arguments = (ins + Variadic:$inclusive_vars + ); + + let assemblyFormat = [{ + `inclusive` `(` $inclusive_vars `:` type($inclusive_vars) `)` + }]; + + let description = [{ + The inclusive clause is used on a separating directive that separates a + structured block into two structured block sequences. If the inclusive + clause is specified, the input phase includes the preceding structured block + sequence and the scan phase includes the following structured block sequence. + }]; +} + +def OpenMP_InclusiveClause : OpenMP_InclusiveClauseSkip<>; + +//===----------------------------------------------------------------------===// +// V5.2: [5.4.7] `exclusive` clause +//===----------------------------------------------------------------------===// + +class OpenMP_ExclusiveClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let arguments = (ins + Variadic:$exclusive_vars + ); + + let assemblyFormat = [{ + `exclusive` `(` $exclusive_vars `:` type($exclusive_vars) `)` + }]; + + let description = [{ + The exclusive clause is used on a separating directive that separates a + structured block into two structured block sequences. If the exclusive clause + is specified, the input phase excludes the preceding structured block + sequence and instead includes the following structured block sequence, + while the scan phase includes the preceding structured block sequence. + }]; +} + +def OpenMP_ExclusiveClause : OpenMP_ExclusiveClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [5.4.6] `linear` clause //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 1aa4e771cd4de..2a81ac2f09072 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1202,6 +1202,20 @@ def OrderedRegionOp : OpenMP_Op<"ordered.region", clauses = [ let hasVerifier = 1; } +def ScanOp : OpenMP_Op<"scan", traits = [ + AttrSizedOperandSegments + ], clauses = [OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> { + let summary = "scan construct with the region"; + let description = [{ + The scan without region is a stand-alone directive that + }] # clausesDescription; + + let builders = [ + OpBuilder<(ins CArg<"const ScanOperands &">:$clauses)> + ]; + +} + //===----------------------------------------------------------------------===// // 2.17.5 taskwait Construct //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 1a9b87f0d68c9..04b4f9b8e8b25 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2619,6 +2619,15 @@ LogicalResult PrivateClauseOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// Spec 5.2: Scan construct (5.6) +//===----------------------------------------------------------------------===// + +void ScanOp::build(OpBuilder &builder, OperationState &state, + const ScanOperands &clauses) { + ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars); +} + //===----------------------------------------------------------------------===// // Spec 5.2: Masked construct (10.5) //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index dce5b3950def4..09db285a0274e 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -30,6 +30,14 @@ func.func @omp_masked(%filtered_thread_id : i32) -> () { return } +func.func @omp_scan(%arg0 : memref) -> () { + // CHECK: omp.scan inclusive(%{{.*}} : memref) + omp.scan inclusive(%arg0 : memref) + // CHECK: omp.scan exclusive(%{{.*}} : memref) + omp.scan exclusive(%arg0 : memref) + return +} + func.func @omp_taskwait() -> () { // CHECK: omp.taskwait omp.taskwait