1515
1616include "mlir/IR/OpBase.td"
1717
18+
19+ // Internal class to hold definitions of BlockArgOpenMPOpInterface methods,
20+ // based on the name of the clause and what clause comes earlier in the list.
21+ //
22+ // The clause order will define the expected relative order between block
23+ // arguments corresponding to each of these clauses.
24+ class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
25+ BlockArgOpenMPClause previousClause> {
26+ // Default-implemented method to be overriden by the corresponding clause.
27+ InterfaceMethod numArgsMethod = InterfaceMethod<
28+ "Get number of block arguments defined by `" # clauseNameSnake # "`.",
29+ "unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], [{
30+ return 0;
31+ }]
32+ >;
33+
34+ // Unified access method for the start index of clause-associated entry block
35+ // arguments.
36+ InterfaceMethod startMethod = InterfaceMethod<
37+ "Get start index of block arguments defined by `" # clauseNameSnake # "`.",
38+ "unsigned", "get" # clauseNameCamel # "BlockArgsStart", (ins),
39+ !if(!initialized(previousClause), [{
40+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
41+ }] # "return iface." # previousClause.startMethod.name # "() + $_op."
42+ # previousClause.numArgsMethod.name # "();",
43+ "return 0;"
44+ )
45+ >;
46+
47+ // Unified access method for clause-associated entry block arguments.
48+ InterfaceMethod blockArgsMethod = InterfaceMethod<
49+ "Get block arguments defined by `" # clauseNameSnake # "`.",
50+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
51+ "get" # clauseNameCamel # "BlockArgs", (ins), [{
52+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
53+ return $_op->getRegion(0).getArguments().slice(
54+ }] # "iface." # startMethod.name # "(), $_op." # numArgsMethod.name # "());"
55+ >;
56+ }
57+
58+ def BlockArgHostEvalClause : BlockArgOpenMPClause<"host_eval", "HostEval", ?>;
59+ def BlockArgInReductionClause : BlockArgOpenMPClause<
60+ "in_reduction", "InReduction", BlockArgHostEvalClause>;
61+ def BlockArgMapClause : BlockArgOpenMPClause<
62+ "map", "Map", BlockArgInReductionClause>;
63+ def BlockArgPrivateClause : BlockArgOpenMPClause<
64+ "private", "Private", BlockArgMapClause>;
65+ def BlockArgReductionClause : BlockArgOpenMPClause<
66+ "reduction", "Reduction", BlockArgPrivateClause>;
67+ def BlockArgTaskReductionClause : BlockArgOpenMPClause<
68+ "task_reduction", "TaskReduction", BlockArgReductionClause>;
69+ def BlockArgUseDeviceAddrClause : BlockArgOpenMPClause<
70+ "use_device_addr", "UseDeviceAddr", BlockArgTaskReductionClause>;
71+ def BlockArgUseDevicePtrClause : BlockArgOpenMPClause<
72+ "use_device_ptr", "UseDevicePtr", BlockArgUseDeviceAddrClause>;
73+
1874def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
1975 let description = [{
2076 OpenMP operations that define entry block arguments as part of the
@@ -23,153 +79,24 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
2379
2480 let cppNamespace = "::mlir::omp";
2581
26- let methods = [
27- // Default-implemented methods to be overriden by the corresponding clauses.
28- InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
29- "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
30- return 0;
31- }]>,
32- InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
33- "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
34- return 0;
35- }]>,
36- InterfaceMethod<"Get number of block arguments defined by `map`.",
37- "unsigned", "numMapBlockArgs", (ins), [{}], [{
38- return 0;
39- }]>,
40- InterfaceMethod<"Get number of block arguments defined by `private`.",
41- "unsigned", "numPrivateBlockArgs", (ins), [{}], [{
42- return 0;
43- }]>,
44- InterfaceMethod<"Get number of block arguments defined by `reduction`.",
45- "unsigned", "numReductionBlockArgs", (ins), [{}], [{
46- return 0;
47- }]>,
48- InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
49- "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
50- return 0;
51- }]>,
52- InterfaceMethod<"Get number of block arguments defined by `use_device_addr`.",
53- "unsigned", "numUseDeviceAddrBlockArgs", (ins), [{}], [{
54- return 0;
55- }]>,
56- InterfaceMethod<"Get number of block arguments defined by `use_device_ptr`.",
57- "unsigned", "numUseDevicePtrBlockArgs", (ins), [{}], [{
58- return 0;
59- }]>,
82+ defvar clauses = [ BlockArgHostEvalClause, BlockArgInReductionClause,
83+ BlockArgMapClause, BlockArgPrivateClause, BlockArgReductionClause,
84+ BlockArgTaskReductionClause, BlockArgUseDeviceAddrClause,
85+ BlockArgUseDevicePtrClause ];
6086
61- // Unified access methods for start indices of clause-associated entry block
62- // arguments.
63- InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
64- "unsigned", "getHostEvalBlockArgsStart", (ins), [{
65- return 0;
66- }]>,
67- InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
68- "unsigned", "getInReductionBlockArgsStart", (ins), [{
69- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
70- return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
71- }]>,
72- InterfaceMethod<"Get start index of block arguments defined by `map`.",
73- "unsigned", "getMapBlockArgsStart", (ins), [{
74- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
75- return iface.getInReductionBlockArgsStart() +
76- $_op.numInReductionBlockArgs();
77- }]>,
78- InterfaceMethod<"Get start index of block arguments defined by `private`.",
79- "unsigned", "getPrivateBlockArgsStart", (ins), [{
80- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
81- return iface.getMapBlockArgsStart() + $_op.numMapBlockArgs();
82- }]>,
83- InterfaceMethod<"Get start index of block arguments defined by `reduction`.",
84- "unsigned", "getReductionBlockArgsStart", (ins), [{
85- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
86- return iface.getPrivateBlockArgsStart() + $_op.numPrivateBlockArgs();
87- }]>,
88- InterfaceMethod<"Get start index of block arguments defined by `task_reduction`.",
89- "unsigned", "getTaskReductionBlockArgsStart", (ins), [{
90- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
91- return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs();
92- }]>,
93- InterfaceMethod<"Get start index of block arguments defined by `use_device_addr`.",
94- "unsigned", "getUseDeviceAddrBlockArgsStart", (ins), [{
95- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
96- return iface.getTaskReductionBlockArgsStart() + $_op.numTaskReductionBlockArgs();
97- }]>,
98- InterfaceMethod<"Get start index of block arguments defined by `use_device_ptr`.",
99- "unsigned", "getUseDevicePtrBlockArgsStart", (ins), [{
100- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
101- return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
102- }]>,
103-
104- // Unified access methods for clause-associated entry block arguments.
105- InterfaceMethod<"Get block arguments defined by `host_eval`.",
106- "::llvm::MutableArrayRef<::mlir::BlockArgument>",
107- "getHostEvalBlockArgs", (ins), [{
108- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
109- return $_op->getRegion(0).getArguments().slice(
110- iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
111- }]>,
112- InterfaceMethod<"Get block arguments defined by `in_reduction`.",
113- "::llvm::MutableArrayRef<::mlir::BlockArgument>",
114- "getInReductionBlockArgs", (ins), [{
115- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
116- return $_op->getRegion(0).getArguments().slice(
117- iface.getInReductionBlockArgsStart(), $_op.numInReductionBlockArgs());
118- }]>,
119- InterfaceMethod<"Get block arguments defined by `map`.",
120- "::llvm::MutableArrayRef<::mlir::BlockArgument>",
121- "getMapBlockArgs", (ins), [{
122- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
123- return $_op->getRegion(0).getArguments().slice(
124- iface.getMapBlockArgsStart(), $_op.numMapBlockArgs());
125- }]>,
126- InterfaceMethod<"Get block arguments defined by `private`.",
127- "::llvm::MutableArrayRef<::mlir::BlockArgument>",
128- "getPrivateBlockArgs", (ins), [{
129- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
130- return $_op->getRegion(0).getArguments().slice(
131- iface.getPrivateBlockArgsStart(), $_op.numPrivateBlockArgs());
132- }]>,
133- InterfaceMethod<"Get block arguments defined by `reduction`.",
134- "::llvm::MutableArrayRef<::mlir::BlockArgument>",
135- "getReductionBlockArgs", (ins), [{
136- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
137- return $_op->getRegion(0).getArguments().slice(
138- iface.getReductionBlockArgsStart(), $_op.numReductionBlockArgs());
139- }]>,
140- InterfaceMethod<"Get block arguments defined by `task_reduction`.",
141- "::llvm::MutableArrayRef<::mlir::BlockArgument>",
142- "getTaskReductionBlockArgs", (ins), [{
143- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
144- return $_op->getRegion(0).getArguments().slice(
145- iface.getTaskReductionBlockArgsStart(),
146- $_op.numTaskReductionBlockArgs());
147- }]>,
148- InterfaceMethod<"Get block arguments defined by `use_device_addr`.",
149- "::llvm::MutableArrayRef<::mlir::BlockArgument>",
150- "getUseDeviceAddrBlockArgs", (ins), [{
151- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
152- return $_op->getRegion(0).getArguments().slice(
153- iface.getUseDeviceAddrBlockArgsStart(),
154- $_op.numUseDeviceAddrBlockArgs());
155- }]>,
156- InterfaceMethod<"Get block arguments defined by `use_device_ptr`.",
157- "::llvm::MutableArrayRef<::mlir::BlockArgument>",
158- "getUseDevicePtrBlockArgs", (ins), [{
159- auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
160- return $_op->getRegion(0).getArguments().slice(
161- iface.getUseDevicePtrBlockArgsStart(),
162- $_op.numUseDevicePtrBlockArgs());
163- }]>,
164- ];
87+ let methods = !listconcat(
88+ !foreach(clause, clauses, clause.numArgsMethod),
89+ !foreach(clause, clauses, clause.startMethod),
90+ !foreach(clause, clauses, clause.blockArgsMethod)
91+ );
16592
16693 let verify = [{
16794 auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
168- unsigned expectedArgs = iface.numHostEvalBlockArgs() +
169- iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
170- iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
171- iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
172- iface.numUseDevicePtrBlockArgs();
95+ }] # " unsigned expectedArgs = "
96+ # !interleave(
97+ !foreach(clause, clauses, " iface." # clause.numArgsMethod.name # "()"),
98+ " + "
99+ ) # ";" # [{
173100 if ($_op->getRegion(0).getNumArguments() < expectedArgs)
174101 return $_op->emitOpError() << "expected at least " << expectedArgs
175102 << " entry block argument(s)";
0 commit comments