diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index c863e5772032c..ad5580a161568 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -15,6 +15,83 @@ include "mlir/IR/OpBase.td" + +// Internal class to hold definitions of BlockArgOpenMPOpInterface methods, +// based on the name of the clause and what clause comes earlier in the list. +// +// The clause order will define the expected relative order between block +// arguments corresponding to each of these clauses. +class BlockArgOpenMPClause { + // Default-implemented method to be overriden by the corresponding clause. + // + // Usage example: + // + // ```c++ + // auto iface = cast(op); + // unsigned numInReductionArgs = iface.numInReductionBlockArgs(); + // ``` + InterfaceMethod numArgsMethod = InterfaceMethod< + "Get number of block arguments defined by `" # clauseNameSnake # "`.", + "unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], [{ + return 0; + }] + >; + + // Unified access method for the start index of clause-associated entry block + // arguments. + // + // Usage example: + // + // ```c++ + // auto iface = cast(op); + // unsigned firstMapIndex = iface.getMapBlockArgsStart(); + // ``` + InterfaceMethod startMethod = InterfaceMethod< + "Get start index of block arguments defined by `" # clauseNameSnake # "`.", + "unsigned", "get" # clauseNameCamel # "BlockArgsStart", (ins), + !if(!initialized(previousClause), [{ + auto iface = ::llvm::cast(*$_op); + }] # "return iface." # previousClause.startMethod.name # "() + $_op." + # previousClause.numArgsMethod.name # "();", + "return 0;" + ) + >; + + // Unified access method for clause-associated entry block arguments. + // + // Usage example: + // + // ```c++ + // auto iface = cast(op); + // ArrayRef reductionArgs = iface.getReductionBlockArgs(); + // ``` + InterfaceMethod blockArgsMethod = InterfaceMethod< + "Get block arguments defined by `" # clauseNameSnake # "`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "get" # clauseNameCamel # "BlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + }] # "iface." # startMethod.name # "(), $_op." # numArgsMethod.name # "());" + >; +} + +def BlockArgHostEvalClause : BlockArgOpenMPClause<"host_eval", "HostEval", ?>; +def BlockArgInReductionClause : BlockArgOpenMPClause< + "in_reduction", "InReduction", BlockArgHostEvalClause>; +def BlockArgMapClause : BlockArgOpenMPClause< + "map", "Map", BlockArgInReductionClause>; +def BlockArgPrivateClause : BlockArgOpenMPClause< + "private", "Private", BlockArgMapClause>; +def BlockArgReductionClause : BlockArgOpenMPClause< + "reduction", "Reduction", BlockArgPrivateClause>; +def BlockArgTaskReductionClause : BlockArgOpenMPClause< + "task_reduction", "TaskReduction", BlockArgReductionClause>; +def BlockArgUseDeviceAddrClause : BlockArgOpenMPClause< + "use_device_addr", "UseDeviceAddr", BlockArgTaskReductionClause>; +def BlockArgUseDevicePtrClause : BlockArgOpenMPClause< + "use_device_ptr", "UseDevicePtr", BlockArgUseDeviceAddrClause>; + def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { let description = [{ OpenMP operations that define entry block arguments as part of the @@ -23,153 +100,24 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { let cppNamespace = "::mlir::omp"; - let methods = [ - // Default-implemented methods to be overriden by the corresponding clauses. - InterfaceMethod<"Get number of block arguments defined by `host_eval`.", - "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{ - return 0; - }]>, - InterfaceMethod<"Get number of block arguments defined by `in_reduction`.", - "unsigned", "numInReductionBlockArgs", (ins), [{}], [{ - return 0; - }]>, - InterfaceMethod<"Get number of block arguments defined by `map`.", - "unsigned", "numMapBlockArgs", (ins), [{}], [{ - return 0; - }]>, - InterfaceMethod<"Get number of block arguments defined by `private`.", - "unsigned", "numPrivateBlockArgs", (ins), [{}], [{ - return 0; - }]>, - InterfaceMethod<"Get number of block arguments defined by `reduction`.", - "unsigned", "numReductionBlockArgs", (ins), [{}], [{ - return 0; - }]>, - InterfaceMethod<"Get number of block arguments defined by `task_reduction`.", - "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{ - return 0; - }]>, - InterfaceMethod<"Get number of block arguments defined by `use_device_addr`.", - "unsigned", "numUseDeviceAddrBlockArgs", (ins), [{}], [{ - return 0; - }]>, - InterfaceMethod<"Get number of block arguments defined by `use_device_ptr`.", - "unsigned", "numUseDevicePtrBlockArgs", (ins), [{}], [{ - return 0; - }]>, + defvar clauses = [ BlockArgHostEvalClause, BlockArgInReductionClause, + BlockArgMapClause, BlockArgPrivateClause, BlockArgReductionClause, + BlockArgTaskReductionClause, BlockArgUseDeviceAddrClause, + BlockArgUseDevicePtrClause ]; - // Unified access methods for start indices of clause-associated entry block - // arguments. - InterfaceMethod<"Get start index of block arguments defined by `host_eval`.", - "unsigned", "getHostEvalBlockArgsStart", (ins), [{ - return 0; - }]>, - InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.", - "unsigned", "getInReductionBlockArgsStart", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs(); - }]>, - InterfaceMethod<"Get start index of block arguments defined by `map`.", - "unsigned", "getMapBlockArgsStart", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return iface.getInReductionBlockArgsStart() + - $_op.numInReductionBlockArgs(); - }]>, - InterfaceMethod<"Get start index of block arguments defined by `private`.", - "unsigned", "getPrivateBlockArgsStart", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return iface.getMapBlockArgsStart() + $_op.numMapBlockArgs(); - }]>, - InterfaceMethod<"Get start index of block arguments defined by `reduction`.", - "unsigned", "getReductionBlockArgsStart", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return iface.getPrivateBlockArgsStart() + $_op.numPrivateBlockArgs(); - }]>, - InterfaceMethod<"Get start index of block arguments defined by `task_reduction`.", - "unsigned", "getTaskReductionBlockArgsStart", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs(); - }]>, - InterfaceMethod<"Get start index of block arguments defined by `use_device_addr`.", - "unsigned", "getUseDeviceAddrBlockArgsStart", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return iface.getTaskReductionBlockArgsStart() + $_op.numTaskReductionBlockArgs(); - }]>, - InterfaceMethod<"Get start index of block arguments defined by `use_device_ptr`.", - "unsigned", "getUseDevicePtrBlockArgsStart", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs(); - }]>, - - // Unified access methods for clause-associated entry block arguments. - InterfaceMethod<"Get block arguments defined by `host_eval`.", - "::llvm::MutableArrayRef<::mlir::BlockArgument>", - "getHostEvalBlockArgs", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return $_op->getRegion(0).getArguments().slice( - iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs()); - }]>, - InterfaceMethod<"Get block arguments defined by `in_reduction`.", - "::llvm::MutableArrayRef<::mlir::BlockArgument>", - "getInReductionBlockArgs", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return $_op->getRegion(0).getArguments().slice( - iface.getInReductionBlockArgsStart(), $_op.numInReductionBlockArgs()); - }]>, - InterfaceMethod<"Get block arguments defined by `map`.", - "::llvm::MutableArrayRef<::mlir::BlockArgument>", - "getMapBlockArgs", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return $_op->getRegion(0).getArguments().slice( - iface.getMapBlockArgsStart(), $_op.numMapBlockArgs()); - }]>, - InterfaceMethod<"Get block arguments defined by `private`.", - "::llvm::MutableArrayRef<::mlir::BlockArgument>", - "getPrivateBlockArgs", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return $_op->getRegion(0).getArguments().slice( - iface.getPrivateBlockArgsStart(), $_op.numPrivateBlockArgs()); - }]>, - InterfaceMethod<"Get block arguments defined by `reduction`.", - "::llvm::MutableArrayRef<::mlir::BlockArgument>", - "getReductionBlockArgs", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return $_op->getRegion(0).getArguments().slice( - iface.getReductionBlockArgsStart(), $_op.numReductionBlockArgs()); - }]>, - InterfaceMethod<"Get block arguments defined by `task_reduction`.", - "::llvm::MutableArrayRef<::mlir::BlockArgument>", - "getTaskReductionBlockArgs", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return $_op->getRegion(0).getArguments().slice( - iface.getTaskReductionBlockArgsStart(), - $_op.numTaskReductionBlockArgs()); - }]>, - InterfaceMethod<"Get block arguments defined by `use_device_addr`.", - "::llvm::MutableArrayRef<::mlir::BlockArgument>", - "getUseDeviceAddrBlockArgs", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return $_op->getRegion(0).getArguments().slice( - iface.getUseDeviceAddrBlockArgsStart(), - $_op.numUseDeviceAddrBlockArgs()); - }]>, - InterfaceMethod<"Get block arguments defined by `use_device_ptr`.", - "::llvm::MutableArrayRef<::mlir::BlockArgument>", - "getUseDevicePtrBlockArgs", (ins), [{ - auto iface = ::llvm::cast(*$_op); - return $_op->getRegion(0).getArguments().slice( - iface.getUseDevicePtrBlockArgsStart(), - $_op.numUseDevicePtrBlockArgs()); - }]>, - ]; + let methods = !listconcat( + !foreach(clause, clauses, clause.numArgsMethod), + !foreach(clause, clauses, clause.startMethod), + !foreach(clause, clauses, clause.blockArgsMethod) + ); let verify = [{ auto iface = ::llvm::cast($_op); - unsigned expectedArgs = iface.numHostEvalBlockArgs() + - iface.numInReductionBlockArgs() + iface.numMapBlockArgs() + - iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() + - iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() + - iface.numUseDevicePtrBlockArgs(); + }] # "unsigned expectedArgs = " + # !interleave( + !foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"), + " + " + ) # ";" # [{ if ($_op->getRegion(0).getNumArguments() < expectedArgs) return $_op->emitOpError() << "expected at least " << expectedArgs << " entry block argument(s)";