diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md index b651b3c06485c..adde176750437 100644 --- a/mlir/docs/Dialects/OpenMPDialect/_index.md +++ b/mlir/docs/Dialects/OpenMPDialect/_index.md @@ -352,12 +352,29 @@ let assemblyFormat = clausesAssemblyFormat # [{ ``` The `BlockArgOpenMPOpInterface` has been introduced to simplify the addition and -handling of these kinds of clauses. It holds `numBlockArgs()` -functions that by default return 0, to be overriden by each clause through the -`extraClassDeclaration` property. Based on these functions and the expected -alphabetical sorting between entry block argument-defining clauses, it -implements `getBlockArgs()` functions that are the intended method -of accessing clause-defined block arguments. +handling of these kinds of clauses. Adding it to an operation directly, or +indirectly through a clause, results in the addition of overridable +`getVars()` and `numBlockArgs()` public functions for +all entry block argument-generating clauses. By default, the reported number of +block arguments defined by a clause will correspond to the number of operands +taken by the operation for that clause. This list of operands will be empty by +default, and will automatically be overriden by getters of the corresponding +`Variadic<...> $_vars` argument of the same clause's definition. + +In addition to these methods added to the actual operations, the +`BlockArgOpenMPOpInterface` itself defines a set of methods based on the +previous ones and on the convention that entry block arguments for multiple +clauses are sorted alphabetically by clause name. These are listed below, and +they represent the main way in which clause-defined block arguments should be +accessed: + - `getBlockArgsStart()`: Returns the index within the list of + entry block arguments where the first element defined by the given clause + should be located. + - `getBlockArgs()`: Returns the list of entry block arguments + defined by the given clause. + - `getBlockArgsPairs()`: Returns a list of pairs where the first element is + the outside value, or operand, and the second element is the corresponding + entry block argument. ## Loop-Associated Directives diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 098c33c11c030..12da584926af8 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -470,12 +470,6 @@ class OpenMP_HasDeviceAddrClauseSkip< Variadic:$has_device_addr_vars ); - let extraClassDeclaration = [{ - unsigned numHasDeviceAddrBlockArgs() { - return getHasDeviceAddrVars().size(); - } - }]; - let description = [{ The optional `has_device_addr_vars` indicates that list items already have device addresses, so they may be directly accessed from the target device. @@ -565,12 +559,6 @@ class OpenMP_HostEvalClauseSkip< Variadic:$host_eval_vars ); - let extraClassDeclaration = [{ - unsigned numHostEvalBlockArgs() { - return getHostEvalVars().size(); - } - }]; - let description = [{ The optional `host_eval_vars` holds values defined outside of the region of the `IsolatedFromAbove` operation for which a corresponding entry block @@ -629,12 +617,10 @@ class OpenMP_InReductionClauseSkip< let extraClassDeclaration = [{ /// Returns the reduction variables. - SmallVector getReductionVars() { + SmallVector getAllReductionVars() { return SmallVector(getInReductionVars().begin(), getInReductionVars().end()); } - - unsigned numInReductionBlockArgs() { return getInReductionVars().size(); } }]; // Description varies depending on the operation. Assembly format not defined @@ -749,6 +735,9 @@ class OpenMP_MapClauseSkip< Variadic:$map_vars ); + // This assembly format should only be used by operations where `map` does not + // define entry block arguments. Otherwise, it must be printed and parsed + // together with the corresponding region. let optAssemblyFormat = [{ `map_entries` `(` $map_vars `:` type($map_vars) `)` }]; @@ -1060,8 +1049,6 @@ class OpenMP_DetachClauseSkip< : OpenMP_Clause { - let traits = [BlockArgOpenMPOpInterface]; - let arguments = (ins Optional:$event_handle); let optAssemblyFormat = [{ @@ -1126,10 +1113,6 @@ class OpenMP_PrivateClauseSkip< OptionalAttr:$private_syms ); - let extraClassDeclaration = [{ - unsigned numPrivateBlockArgs() { return getPrivateVars().size(); } - }]; - // TODO: Add description. // Assembly format not defined because this clause must be processed together // with the first region of the operation, as it defines entry block @@ -1186,7 +1169,6 @@ class OpenMP_ReductionClauseSkip< let extraClassDeclaration = [{ /// Returns the number of reduction variables. unsigned getNumReductionVars() { return getReductionVars().size(); } - unsigned numReductionBlockArgs() { return getReductionVars().size(); } }]; // Description varies depending on the operation. @@ -1316,14 +1298,10 @@ class OpenMP_TaskReductionClauseSkip< let extraClassDeclaration = [{ /// Returns the reduction variables. - SmallVector getReductionVars() { + SmallVector getAllReductionVars() { return SmallVector(getTaskReductionVars().begin(), getTaskReductionVars().end()); } - - unsigned numTaskReductionBlockArgs() { - return getTaskReductionVars().size(); - } }]; let description = [{ @@ -1413,12 +1391,6 @@ class OpenMP_UseDeviceAddrClauseSkip< Variadic:$use_device_addr_vars ); - let extraClassDeclaration = [{ - unsigned numUseDeviceAddrBlockArgs() { - return getUseDeviceAddrVars().size(); - } - }]; - let description = [{ The optional `use_device_addr_vars` specifies the address of the objects in the device data environment. @@ -1448,12 +1420,6 @@ class OpenMP_UseDevicePtrClauseSkip< Variadic:$use_device_ptr_vars ); - let extraClassDeclaration = [{ - unsigned numUseDevicePtrBlockArgs() { - return getUseDevicePtrVars().size(); - } - }]; - let description = [{ The optional `use_device_ptr_vars` specifies the device pointers to the corresponding list items in the device data environment. diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index c1dae8d543eef..401c4c11d8986 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -284,8 +284,8 @@ def SectionOp : OpenMP_Op<"section", traits = [ // Override BlockArgOpenMPOpInterface methods based on the parent // omp.sections operation. Only forward-declare here because SectionsOp is // not completely defined at this point. - unsigned numPrivateBlockArgs(); - unsigned numReductionBlockArgs(); + OperandRange getPrivateVars(); + OperandRange getReductionVars(); }] # clausesExtraClassDeclaration; let assemblyFormat = "$region attr-dict"; } @@ -824,11 +824,6 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [ /// Returns the reduction variables SmallVector getAllReductionVars(); - // Define BlockArgOpenMPOpInterface methods here because they are not - // inherited from the respective clauses. - unsigned numInReductionBlockArgs() { return getInReductionVars().size(); } - unsigned numReductionBlockArgs() { return getReductionVars().size(); } - void getEffects(SmallVectorImpl &effects); }] # clausesExtraClassDeclaration; @@ -1151,6 +1146,14 @@ def TargetDataOp: OpenMP_Op<"target_data", traits = [ OpBuilder<(ins CArg<"const TargetDataOperands &">:$clauses)> ]; + let extraClassDeclaration = [{ + // Override BlockArgOpenMPOpInterface method because `map` clauses have no + // associated entry block arguments in this operation. + unsigned numMapBlockArgs() { + return 0; + } + }] # clausesExtraClassDeclaration; + let assemblyFormat = clausesAssemblyFormat # [{ custom( $region, $use_device_addr_vars, type($use_device_addr_vars), @@ -1185,6 +1188,14 @@ def TargetEnterDataOp: OpenMP_Op<"target_enter_data", traits = [ OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataOperands &">:$clauses)> ]; + let extraClassDeclaration = [{ + // Override BlockArgOpenMPOpInterface method because `map` clauses have no + // associated entry block arguments in this operation. + unsigned numMapBlockArgs() { + return 0; + } + }] # clausesExtraClassDeclaration; + let hasVerifier = 1; } @@ -1213,6 +1224,14 @@ def TargetExitDataOp: OpenMP_Op<"target_exit_data", traits = [ OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataOperands &">:$clauses)> ]; + let extraClassDeclaration = [{ + // Override BlockArgOpenMPOpInterface method because `map` clauses have no + // associated entry block arguments in this operation. + unsigned numMapBlockArgs() { + return 0; + } + }] # clausesExtraClassDeclaration; + let hasVerifier = 1; } @@ -1249,6 +1268,14 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [ OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataOperands &">:$clauses)> ]; + let extraClassDeclaration = [{ + // Override BlockArgOpenMPOpInterface method because `map` clauses have no + // associated entry block arguments in this operation. + unsigned numMapBlockArgs() { + return 0; + } + }] # clausesExtraClassDeclaration; + let hasVerifier = 1; } @@ -1292,8 +1319,6 @@ def TargetOp : OpenMP_Op<"target", traits = [ ]; let extraClassDeclaration = [{ - unsigned numMapBlockArgs() { return getMapVars().size(); } - mlir::Value getMappedValueForPrivateVar(unsigned privVarIdx) { std::optional privateMapIdices = getPrivateMapsAttr(); @@ -1818,6 +1843,14 @@ def DeclareMapperInfoOp : OpenMP_Op<"declare_mapper.info", [ OpBuilder<(ins CArg<"const DeclareMapperInfoOperands &">:$clauses)> ]; + let extraClassDeclaration = [{ + // Override BlockArgOpenMPOpInterface method because `map` clauses have no + // associated entry block arguments in this operation. + unsigned numMapBlockArgs() { + return 0; + } + }] # clausesExtraClassDeclaration; + let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index 85996368fd946..0766b4e8d1472 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -23,19 +23,41 @@ include "mlir/IR/OpBase.td" // arguments corresponding to each of these clauses. class BlockArgOpenMPClause { - // Default-implemented method to be overriden by the corresponding clause. + // Default-implemented method, overriden by the corresponding clause. It + // returns the range of operands passed to the operation associated to the + // clause. + // + // For the override to work, the clause tablegen definition must contain a + // `Variadic<...> $clause_name_vars` argument. // // Usage example: // // ```c++ - // auto iface = cast(op); - // unsigned numInReductionArgs = iface.numInReductionBlockArgs(); + // OperandRange reductionVars = op.getReductionVars(); + // ``` + InterfaceMethod varsMethod = InterfaceMethod< + "Get operation operands associated to `" # clauseNameSnake # "`.", + "::mlir::OperandRange", "get" # clauseNameCamel # "Vars", (ins), [{}], [{ + return {0, 0}; + }] + >; + + // It returns the number of entry block arguments introduced by the given + // clause. + // + // By default, it will be the number of operands corresponding to that clause, + // but it can be overriden by operations where this might not be the case + // (e.g. `map` clause in `omp.target_update`). + // + // Usage example: + // + // ```c++ + // unsigned numInReductionArgs = op.numInReductionBlockArgs(); // ``` InterfaceMethod numArgsMethod = InterfaceMethod< "Get number of block arguments defined by `" # clauseNameSnake # "`.", - "unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], [{ - return 0; - }] + "unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], + "return $_op." # varsMethod.name # "().size();" >; // Unified access method for the start index of clause-associated entry block @@ -52,7 +74,7 @@ class BlockArgOpenMPClause(*$_op); - }] # "return iface." # previousClause.startMethod.name # "() + $_op." + }] # "return iface." # previousClause.startMethod.name # "() + iface." # previousClause.numArgsMethod.name # "();", "return 0;" ) @@ -72,7 +94,7 @@ class BlockArgOpenMPClause(*$_op); return $_op->getRegion(0).getArguments().slice( - }] # "iface." # startMethod.name # "(), $_op." # numArgsMethod.name # "());" + }] # "iface." # startMethod.name # "(), iface." # numArgsMethod.name # "());" >; } @@ -109,9 +131,26 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { BlockArgUseDeviceAddrClause, BlockArgUseDevicePtrClause ]; let methods = !listconcat( + !foreach(clause, clauses, clause.varsMethod), !foreach(clause, clauses, clause.numArgsMethod), !foreach(clause, clauses, clause.startMethod), - !foreach(clause, clauses, clause.blockArgsMethod) + !foreach(clause, clauses, clause.blockArgsMethod), + [ + InterfaceMethod< + "Populate a vector of pairs representing the matching between operands " + "and entry block arguments.", "void", "getBlockArgsPairs", + (ins "::llvm::SmallVectorImpl> &" : $pairs), + [{ + auto iface = ::llvm::cast(*$_op); + }] # !interleave(!foreach(clause, clauses, [{ + }] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{ + }] # " for (auto [var, arg] : ::llvm::zip_equal(" # + "iface." # clause.varsMethod.name # "()," # + "iface." # clause.blockArgsMethod.name # "()))" # [{ + pairs.emplace_back(var, arg); + } }]), "\n") + > + ] ); let verify = [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 65a296c5d4829..84b4d30076646 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2248,12 +2248,12 @@ LogicalResult TeamsOp::verify() { // SectionOp //===----------------------------------------------------------------------===// -unsigned SectionOp::numPrivateBlockArgs() { - return getParentOp().numPrivateBlockArgs(); +OperandRange SectionOp::getPrivateVars() { + return getParentOp().getPrivateVars(); } -unsigned SectionOp::numReductionBlockArgs() { - return getParentOp().numReductionBlockArgs(); +OperandRange SectionOp::getReductionVars() { + return getParentOp().getReductionVars(); } //===----------------------------------------------------------------------===//