Skip to content

Commit 873c7af

Browse files
committed
[MLIR][OpenMP] Simplify definition of the BlockArgOpenMPOpInterface, NFC
This patch removes code duplication from the definition of methods of the `BlockArgOpenMPOpInterface` and makes the order relationship between entry block argument generating clauses explicit. The goal of this change is to make the addition of clauses and methods to the interface less error-prone.
1 parent 5e26fb1 commit 873c7af

File tree

1 file changed

+70
-143
lines changed

1 file changed

+70
-143
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 70 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,62 @@
1515

1616
include "mlir/IR/OpBase.td"
1717

18+
19+
// Internal class to hold definitions of BlockArgOpenMPOpInterface methods,
20+
// based on the name of the clause and what clause comes earlier in the list.
21+
//
22+
// The clause order will define the expected relative order between block
23+
// arguments corresponding to each of these clauses.
24+
class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
25+
BlockArgOpenMPClause previousClause> {
26+
// Default-implemented method to be overriden by the corresponding clause.
27+
InterfaceMethod numArgsMethod = InterfaceMethod<
28+
"Get number of block arguments defined by `" # clauseNameSnake # "`.",
29+
"unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], [{
30+
return 0;
31+
}]
32+
>;
33+
34+
// Unified access method for the start index of clause-associated entry block
35+
// arguments.
36+
InterfaceMethod startMethod = InterfaceMethod<
37+
"Get start index of block arguments defined by `" # clauseNameSnake # "`.",
38+
"unsigned", "get" # clauseNameCamel # "BlockArgsStart", (ins),
39+
!if(!initialized(previousClause), [{
40+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
41+
}] # "return iface." # previousClause.startMethod.name # "() + $_op."
42+
# previousClause.numArgsMethod.name # "();",
43+
"return 0;"
44+
)
45+
>;
46+
47+
// Unified access method for clause-associated entry block arguments.
48+
InterfaceMethod blockArgsMethod = InterfaceMethod<
49+
"Get block arguments defined by `" # clauseNameSnake # "`.",
50+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
51+
"get" # clauseNameCamel # "BlockArgs", (ins), [{
52+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
53+
return $_op->getRegion(0).getArguments().slice(
54+
}] # "iface." # startMethod.name # "(), $_op." # numArgsMethod.name # "());"
55+
>;
56+
}
57+
58+
def BlockArgHostEvalClause : BlockArgOpenMPClause<"host_eval", "HostEval", ?>;
59+
def BlockArgInReductionClause : BlockArgOpenMPClause<
60+
"in_reduction", "InReduction", BlockArgHostEvalClause>;
61+
def BlockArgMapClause : BlockArgOpenMPClause<
62+
"map", "Map", BlockArgInReductionClause>;
63+
def BlockArgPrivateClause : BlockArgOpenMPClause<
64+
"private", "Private", BlockArgMapClause>;
65+
def BlockArgReductionClause : BlockArgOpenMPClause<
66+
"reduction", "Reduction", BlockArgPrivateClause>;
67+
def BlockArgTaskReductionClause : BlockArgOpenMPClause<
68+
"task_reduction", "TaskReduction", BlockArgReductionClause>;
69+
def BlockArgUseDeviceAddrClause : BlockArgOpenMPClause<
70+
"use_device_addr", "UseDeviceAddr", BlockArgTaskReductionClause>;
71+
def BlockArgUseDevicePtrClause : BlockArgOpenMPClause<
72+
"use_device_ptr", "UseDevicePtr", BlockArgUseDeviceAddrClause>;
73+
1874
def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
1975
let description = [{
2076
OpenMP operations that define entry block arguments as part of the
@@ -23,153 +79,24 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
2379

2480
let cppNamespace = "::mlir::omp";
2581

26-
let methods = [
27-
// Default-implemented methods to be overriden by the corresponding clauses.
28-
InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
29-
"unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
30-
return 0;
31-
}]>,
32-
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
33-
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
34-
return 0;
35-
}]>,
36-
InterfaceMethod<"Get number of block arguments defined by `map`.",
37-
"unsigned", "numMapBlockArgs", (ins), [{}], [{
38-
return 0;
39-
}]>,
40-
InterfaceMethod<"Get number of block arguments defined by `private`.",
41-
"unsigned", "numPrivateBlockArgs", (ins), [{}], [{
42-
return 0;
43-
}]>,
44-
InterfaceMethod<"Get number of block arguments defined by `reduction`.",
45-
"unsigned", "numReductionBlockArgs", (ins), [{}], [{
46-
return 0;
47-
}]>,
48-
InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
49-
"unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
50-
return 0;
51-
}]>,
52-
InterfaceMethod<"Get number of block arguments defined by `use_device_addr`.",
53-
"unsigned", "numUseDeviceAddrBlockArgs", (ins), [{}], [{
54-
return 0;
55-
}]>,
56-
InterfaceMethod<"Get number of block arguments defined by `use_device_ptr`.",
57-
"unsigned", "numUseDevicePtrBlockArgs", (ins), [{}], [{
58-
return 0;
59-
}]>,
82+
defvar clauses = [ BlockArgHostEvalClause, BlockArgInReductionClause,
83+
BlockArgMapClause, BlockArgPrivateClause, BlockArgReductionClause,
84+
BlockArgTaskReductionClause, BlockArgUseDeviceAddrClause,
85+
BlockArgUseDevicePtrClause ];
6086

61-
// Unified access methods for start indices of clause-associated entry block
62-
// arguments.
63-
InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
64-
"unsigned", "getHostEvalBlockArgsStart", (ins), [{
65-
return 0;
66-
}]>,
67-
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
68-
"unsigned", "getInReductionBlockArgsStart", (ins), [{
69-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
70-
return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
71-
}]>,
72-
InterfaceMethod<"Get start index of block arguments defined by `map`.",
73-
"unsigned", "getMapBlockArgsStart", (ins), [{
74-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
75-
return iface.getInReductionBlockArgsStart() +
76-
$_op.numInReductionBlockArgs();
77-
}]>,
78-
InterfaceMethod<"Get start index of block arguments defined by `private`.",
79-
"unsigned", "getPrivateBlockArgsStart", (ins), [{
80-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
81-
return iface.getMapBlockArgsStart() + $_op.numMapBlockArgs();
82-
}]>,
83-
InterfaceMethod<"Get start index of block arguments defined by `reduction`.",
84-
"unsigned", "getReductionBlockArgsStart", (ins), [{
85-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
86-
return iface.getPrivateBlockArgsStart() + $_op.numPrivateBlockArgs();
87-
}]>,
88-
InterfaceMethod<"Get start index of block arguments defined by `task_reduction`.",
89-
"unsigned", "getTaskReductionBlockArgsStart", (ins), [{
90-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
91-
return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs();
92-
}]>,
93-
InterfaceMethod<"Get start index of block arguments defined by `use_device_addr`.",
94-
"unsigned", "getUseDeviceAddrBlockArgsStart", (ins), [{
95-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
96-
return iface.getTaskReductionBlockArgsStart() + $_op.numTaskReductionBlockArgs();
97-
}]>,
98-
InterfaceMethod<"Get start index of block arguments defined by `use_device_ptr`.",
99-
"unsigned", "getUseDevicePtrBlockArgsStart", (ins), [{
100-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
101-
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
102-
}]>,
103-
104-
// Unified access methods for clause-associated entry block arguments.
105-
InterfaceMethod<"Get block arguments defined by `host_eval`.",
106-
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
107-
"getHostEvalBlockArgs", (ins), [{
108-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
109-
return $_op->getRegion(0).getArguments().slice(
110-
iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
111-
}]>,
112-
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
113-
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
114-
"getInReductionBlockArgs", (ins), [{
115-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
116-
return $_op->getRegion(0).getArguments().slice(
117-
iface.getInReductionBlockArgsStart(), $_op.numInReductionBlockArgs());
118-
}]>,
119-
InterfaceMethod<"Get block arguments defined by `map`.",
120-
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
121-
"getMapBlockArgs", (ins), [{
122-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
123-
return $_op->getRegion(0).getArguments().slice(
124-
iface.getMapBlockArgsStart(), $_op.numMapBlockArgs());
125-
}]>,
126-
InterfaceMethod<"Get block arguments defined by `private`.",
127-
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
128-
"getPrivateBlockArgs", (ins), [{
129-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
130-
return $_op->getRegion(0).getArguments().slice(
131-
iface.getPrivateBlockArgsStart(), $_op.numPrivateBlockArgs());
132-
}]>,
133-
InterfaceMethod<"Get block arguments defined by `reduction`.",
134-
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
135-
"getReductionBlockArgs", (ins), [{
136-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
137-
return $_op->getRegion(0).getArguments().slice(
138-
iface.getReductionBlockArgsStart(), $_op.numReductionBlockArgs());
139-
}]>,
140-
InterfaceMethod<"Get block arguments defined by `task_reduction`.",
141-
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
142-
"getTaskReductionBlockArgs", (ins), [{
143-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
144-
return $_op->getRegion(0).getArguments().slice(
145-
iface.getTaskReductionBlockArgsStart(),
146-
$_op.numTaskReductionBlockArgs());
147-
}]>,
148-
InterfaceMethod<"Get block arguments defined by `use_device_addr`.",
149-
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
150-
"getUseDeviceAddrBlockArgs", (ins), [{
151-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
152-
return $_op->getRegion(0).getArguments().slice(
153-
iface.getUseDeviceAddrBlockArgsStart(),
154-
$_op.numUseDeviceAddrBlockArgs());
155-
}]>,
156-
InterfaceMethod<"Get block arguments defined by `use_device_ptr`.",
157-
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
158-
"getUseDevicePtrBlockArgs", (ins), [{
159-
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
160-
return $_op->getRegion(0).getArguments().slice(
161-
iface.getUseDevicePtrBlockArgsStart(),
162-
$_op.numUseDevicePtrBlockArgs());
163-
}]>,
164-
];
87+
let methods = !listconcat(
88+
!foreach(clause, clauses, clause.numArgsMethod),
89+
!foreach(clause, clauses, clause.startMethod),
90+
!foreach(clause, clauses, clause.blockArgsMethod)
91+
);
16592

16693
let verify = [{
16794
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
168-
unsigned expectedArgs = iface.numHostEvalBlockArgs() +
169-
iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
170-
iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
171-
iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
172-
iface.numUseDevicePtrBlockArgs();
95+
}] # "unsigned expectedArgs = "
96+
# !interleave(
97+
!foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"),
98+
" + "
99+
) # ";" # [{
173100
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
174101
return $_op->emitOpError() << "expected at least " << expectedArgs
175102
<< " entry block argument(s)";

0 commit comments

Comments
 (0)