Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 5 additions & 39 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,6 @@ class OpenMP_HasDeviceAddrClauseSkip<
Variadic<OpenMP_PointerLikeType>:$has_device_addr_vars
);

let extraClassDeclaration = [{
unsigned numHasDeviceAddrBlockArgs() {
return getHasDeviceAddrVars().size();
}
}];

let description = [{
The optional `has_device_addr_vars` indicates that list items already have
device addresses, so they may be directly accessed from the target device.
Expand Down Expand Up @@ -565,12 +559,6 @@ class OpenMP_HostEvalClauseSkip<
Variadic<AnyType>:$host_eval_vars
);

let extraClassDeclaration = [{
unsigned numHostEvalBlockArgs() {
return getHostEvalVars().size();
}
}];

let description = [{
The optional `host_eval_vars` holds values defined outside of the region of
the `IsolatedFromAbove` operation for which a corresponding entry block
Expand Down Expand Up @@ -629,12 +617,10 @@ class OpenMP_InReductionClauseSkip<

let extraClassDeclaration = [{
/// Returns the reduction variables.
SmallVector<Value> getReductionVars() {
SmallVector<Value> getAllReductionVars() {
return SmallVector<Value>(getInReductionVars().begin(),
getInReductionVars().end());
}

unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
}];

// Description varies depending on the operation. Assembly format not defined
Expand Down Expand Up @@ -749,6 +735,9 @@ class OpenMP_MapClauseSkip<
Variadic<OpenMP_PointerLikeType>:$map_vars
);

// This assembly format should only be used by operations where `map` does not
// define entry block arguments. Otherwise, it must be printed and parsed
// together with the corresponding region.
let optAssemblyFormat = [{
`map_entries` `(` $map_vars `:` type($map_vars) `)`
}];
Expand Down Expand Up @@ -1060,8 +1049,6 @@ class OpenMP_DetachClauseSkip<
: OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {

let traits = [BlockArgOpenMPOpInterface];

let arguments = (ins Optional<OpenMP_PointerLikeType>:$event_handle);

let optAssemblyFormat = [{
Expand Down Expand Up @@ -1126,10 +1113,6 @@ class OpenMP_PrivateClauseSkip<
OptionalAttr<SymbolRefArrayAttr>:$private_syms
);

let extraClassDeclaration = [{
unsigned numPrivateBlockArgs() { return getPrivateVars().size(); }
}];

// TODO: Add description.
// Assembly format not defined because this clause must be processed together
// with the first region of the operation, as it defines entry block
Expand Down Expand Up @@ -1186,7 +1169,6 @@ class OpenMP_ReductionClauseSkip<
let extraClassDeclaration = [{
/// Returns the number of reduction variables.
unsigned getNumReductionVars() { return getReductionVars().size(); }
unsigned numReductionBlockArgs() { return getReductionVars().size(); }
}];

// Description varies depending on the operation.
Expand Down Expand Up @@ -1316,14 +1298,10 @@ class OpenMP_TaskReductionClauseSkip<

let extraClassDeclaration = [{
/// Returns the reduction variables.
SmallVector<Value> getReductionVars() {
SmallVector<Value> getAllReductionVars() {
return SmallVector<Value>(getTaskReductionVars().begin(),
getTaskReductionVars().end());
}

unsigned numTaskReductionBlockArgs() {
return getTaskReductionVars().size();
}
}];

let description = [{
Expand Down Expand Up @@ -1413,12 +1391,6 @@ class OpenMP_UseDeviceAddrClauseSkip<
Variadic<OpenMP_PointerLikeType>:$use_device_addr_vars
);

let extraClassDeclaration = [{
unsigned numUseDeviceAddrBlockArgs() {
return getUseDeviceAddrVars().size();
}
}];

let description = [{
The optional `use_device_addr_vars` specifies the address of the objects in
the device data environment.
Expand Down Expand Up @@ -1448,12 +1420,6 @@ class OpenMP_UseDevicePtrClauseSkip<
Variadic<OpenMP_PointerLikeType>:$use_device_ptr_vars
);

let extraClassDeclaration = [{
unsigned numUseDevicePtrBlockArgs() {
return getUseDevicePtrVars().size();
}
}];

let description = [{
The optional `use_device_ptr_vars` specifies the device pointers to the
corresponding list items in the device data environment.
Expand Down
51 changes: 42 additions & 9 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ def SectionOp : OpenMP_Op<"section", traits = [
// Override BlockArgOpenMPOpInterface methods based on the parent
// omp.sections operation. Only forward-declare here because SectionsOp is
// not completely defined at this point.
unsigned numPrivateBlockArgs();
unsigned numReductionBlockArgs();
OperandRange getPrivateVars();
OperandRange getReductionVars();
}] # clausesExtraClassDeclaration;
let assemblyFormat = "$region attr-dict";
}
Expand Down Expand Up @@ -824,11 +824,6 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
/// Returns the reduction variables
SmallVector<Value> getAllReductionVars();

// Define BlockArgOpenMPOpInterface methods here because they are not
// inherited from the respective clauses.
unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
unsigned numReductionBlockArgs() { return getReductionVars().size(); }

void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
}] # clausesExtraClassDeclaration;

Expand Down Expand Up @@ -1151,6 +1146,14 @@ def TargetDataOp: OpenMP_Op<"target_data", traits = [
OpBuilder<(ins CArg<"const TargetDataOperands &">:$clauses)>
];

let extraClassDeclaration = [{
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
// associated entry block arguments in this operation.
unsigned numMapBlockArgs() {
return 0;
}
}] # clausesExtraClassDeclaration;

let assemblyFormat = clausesAssemblyFormat # [{
custom<UseDeviceAddrUseDevicePtrRegion>(
$region, $use_device_addr_vars, type($use_device_addr_vars),
Expand Down Expand Up @@ -1185,6 +1188,14 @@ def TargetEnterDataOp: OpenMP_Op<"target_enter_data", traits = [
OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataOperands &">:$clauses)>
];

let extraClassDeclaration = [{
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
// associated entry block arguments in this operation.
unsigned numMapBlockArgs() {
return 0;
}
}] # clausesExtraClassDeclaration;

let hasVerifier = 1;
}

Expand Down Expand Up @@ -1213,6 +1224,14 @@ def TargetExitDataOp: OpenMP_Op<"target_exit_data", traits = [
OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataOperands &">:$clauses)>
];

let extraClassDeclaration = [{
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
// associated entry block arguments in this operation.
unsigned numMapBlockArgs() {
return 0;
}
}] # clausesExtraClassDeclaration;

let hasVerifier = 1;
}

Expand Down Expand Up @@ -1249,6 +1268,14 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataOperands &">:$clauses)>
];

let extraClassDeclaration = [{
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
// associated entry block arguments in this operation.
unsigned numMapBlockArgs() {
return 0;
}
}] # clausesExtraClassDeclaration;

let hasVerifier = 1;
}

Expand Down Expand Up @@ -1292,8 +1319,6 @@ def TargetOp : OpenMP_Op<"target", traits = [
];

let extraClassDeclaration = [{
unsigned numMapBlockArgs() { return getMapVars().size(); }

mlir::Value getMappedValueForPrivateVar(unsigned privVarIdx) {
std::optional<DenseI64ArrayAttr> privateMapIdices = getPrivateMapsAttr();

Expand Down Expand Up @@ -1818,6 +1843,14 @@ def DeclareMapperInfoOp : OpenMP_Op<"declare_mapper.info", [
OpBuilder<(ins CArg<"const DeclareMapperInfoOperands &">:$clauses)>
];

let extraClassDeclaration = [{
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
// associated entry block arguments in this operation.
unsigned numMapBlockArgs() {
return 0;
}
}] # clausesExtraClassDeclaration;

let hasVerifier = 1;
}

Expand Down
57 changes: 48 additions & 9 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,41 @@ include "mlir/IR/OpBase.td"
// arguments corresponding to each of these clauses.
class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
BlockArgOpenMPClause previousClause> {
// Default-implemented method to be overriden by the corresponding clause.
// Default-implemented method, overriden by the corresponding clause. It
// returns the range of operands passed to the operation associated to the
// clause.
//
// For the override to work, the clause tablegen definition must contain a
// `Variadic<...> $clause_name_vars` argument.
//
// Usage example:
//
// ```c++
// auto iface = cast<BlockArgOpenMPOpInterface>(op);
// unsigned numInReductionArgs = iface.numInReductionBlockArgs();
// OperandRange reductionVars = op.getReductionVars();
// ```
InterfaceMethod varsMethod = InterfaceMethod<
"Get operation operands associated to `" # clauseNameSnake # "`.",
"::mlir::OperandRange", "get" # clauseNameCamel # "Vars", (ins), [{}], [{
return {0, 0};
}]
>;

// It returns the number of entry block arguments introduced by the given
// clause.
//
// By default, it will be the number of operands corresponding to that clause,
// but it can be overriden by operations where this might not be the case
// (e.g. `map` clause in `omp.target_update`).
//
// Usage example:
//
// ```c++
// unsigned numInReductionArgs = op.numInReductionBlockArgs();
// ```
InterfaceMethod numArgsMethod = InterfaceMethod<
"Get number of block arguments defined by `" # clauseNameSnake # "`.",
"unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], [{
return 0;
}]
"unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}],
"return $_op." # varsMethod.name # "().size();"
>;

// Unified access method for the start index of clause-associated entry block
Expand All @@ -52,7 +74,7 @@ class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
"unsigned", "get" # clauseNameCamel # "BlockArgsStart", (ins),
!if(!initialized(previousClause), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
}] # "return iface." # previousClause.startMethod.name # "() + $_op."
}] # "return iface." # previousClause.startMethod.name # "() + iface."
# previousClause.numArgsMethod.name # "();",
"return 0;"
)
Expand All @@ -72,7 +94,7 @@ class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
"get" # clauseNameCamel # "BlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
}] # "iface." # startMethod.name # "(), $_op." # numArgsMethod.name # "());"
}] # "iface." # startMethod.name # "(), iface." # numArgsMethod.name # "());"
>;
}

Expand Down Expand Up @@ -109,9 +131,26 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
BlockArgUseDeviceAddrClause, BlockArgUseDevicePtrClause ];

let methods = !listconcat(
!foreach(clause, clauses, clause.varsMethod),
!foreach(clause, clauses, clause.numArgsMethod),
!foreach(clause, clauses, clause.startMethod),
!foreach(clause, clauses, clause.blockArgsMethod)
!foreach(clause, clauses, clause.blockArgsMethod),
[
InterfaceMethod<
"Populate a vector of pairs representing the matching between operands "
"and entry block arguments.", "void", "getBlockArgsPairs",
(ins "::llvm::SmallVectorImpl<std::pair<::mlir::Value, ::mlir::BlockArgument>> &" : $pairs),
[{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
}] # !interleave(!foreach(clause, clauses, [{
}] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{
}] # " for (auto [var, arg] : ::llvm::zip_equal(" #
"iface." # clause.varsMethod.name # "()," #
"iface." # clause.blockArgsMethod.name # "()))" # [{
pairs.emplace_back(var, arg);
} }]), "\n")
>
]
);

let verify = [{
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2248,12 +2248,12 @@ LogicalResult TeamsOp::verify() {
// SectionOp
//===----------------------------------------------------------------------===//

unsigned SectionOp::numPrivateBlockArgs() {
return getParentOp().numPrivateBlockArgs();
OperandRange SectionOp::getPrivateVars() {
return getParentOp().getPrivateVars();
}

unsigned SectionOp::numReductionBlockArgs() {
return getParentOp().numReductionBlockArgs();
OperandRange SectionOp::getReductionVars() {
return getParentOp().getReductionVars();
}

//===----------------------------------------------------------------------===//
Expand Down