Skip to content

Commit dd2b1a3

Browse files
shraiyshjeanPerier
authored andcommitted
[MLIR][OpenMP] Fixed the missing inclusive clause in omp.wsloop and fix order clause
This patch adds the inclusive clause (which was missed in previous reorganization - https://reviews.llvm.org/D110903) in omp.wsloop operation. Added a test for validating it. Also fixes the order clause, which was not accepting any values. It now accepts "concurrent" as a value, as specified in the standard. Reviewed By: kiranchandramohan, peixin, clementval Differential Revision: https://reviews.llvm.org/D112198
1 parent db7ae8c commit dd2b1a3

File tree

5 files changed

+188
-22
lines changed

5 files changed

+188
-22
lines changed

llvm/include/llvm/Frontend/OpenMP/OMP.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,13 @@ def OMPC_NonTemporal : Clause<"nontemporal"> {
285285
let isValueList = true;
286286
}
287287

288-
def OMP_ORDER_concurrent : ClauseVal<"default",2,0> { let isDefault = 1; }
288+
def OMP_ORDER_concurrent : ClauseVal<"concurrent",1,1> {}
289+
def OMP_ORDER_unknown : ClauseVal<"unknown",2,0> { let isDefault = 1; }
289290
def OMPC_Order : Clause<"order"> {
290291
let clangClass = "OMPOrderClause";
291292
let enumClauseValue = "OrderKind";
292293
let allowedClauseValues = [
294+
OMP_ORDER_unknown,
293295
OMP_ORDER_concurrent
294296
];
295297
}

llvm/unittests/Frontend/OpenMPParsingTest.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ TEST(OpenMPParsingTest, isAllowedClauseForDirective) {
5555
}
5656

5757
TEST(OpenMPParsingTest, getOrderKind) {
58-
EXPECT_EQ(getOrderKind("foobar"), OMP_ORDER_concurrent);
59-
EXPECT_EQ(getOrderKind("default"), OMP_ORDER_concurrent);
58+
EXPECT_EQ(getOrderKind("foobar"), OMP_ORDER_unknown);
59+
EXPECT_EQ(getOrderKind("unknown"), OMP_ORDER_unknown);
60+
EXPECT_EQ(getOrderKind("concurrent"), OMP_ORDER_concurrent);
6061
}
6162

6263
TEST(OpenMPParsingTest, getProcBindKind) {

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,6 @@ enum ClauseType {
497497
collapseClause,
498498
orderClause,
499499
orderedClause,
500-
inclusiveClause,
501500
memoryOrderClause,
502501
hintClause,
503502
COUNT
@@ -582,8 +581,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
582581
// segments
583582
if (clause == defaultClause || clause == procBindClause ||
584583
clause == nowaitClause || clause == collapseClause ||
585-
clause == orderClause || clause == orderedClause ||
586-
clause == inclusiveClause)
584+
clause == orderClause || clause == orderedClause)
587585
continue;
588586

589587
pos[clause] = currPos++;
@@ -601,7 +599,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
601599
bool allowRepeat = false) -> ParseResult {
602600
if (!llvm::is_contained(clauses, clause))
603601
return parser.emitError(parser.getCurrentLocation())
604-
<< clauseKeyword << "is not a valid clause for the " << opName
602+
<< clauseKeyword << " is not a valid clause for the " << opName
605603
<< " operation";
606604
if (done[clause] && !allowRepeat)
607605
return parser.emitError(parser.getCurrentLocation())
@@ -722,12 +720,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
722720
parser.parseKeyword(&order) || parser.parseRParen())
723721
return failure();
724722
auto attr = parser.getBuilder().getStringAttr(order);
725-
result.addAttribute("order", attr);
726-
} else if (clauseKeyword == "inclusive") {
727-
if (checkAllowed(inclusiveClause))
728-
return failure();
729-
auto attr = UnitAttr::get(parser.getBuilder().getContext());
730-
result.addAttribute("inclusive", attr);
723+
result.addAttribute("order_val", attr);
731724
} else if (clauseKeyword == "memory_order") {
732725
StringRef memoryOrder;
733726
if (checkAllowed(memoryOrderClause) || parser.parseLParen() ||
@@ -884,11 +877,11 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
884877
///
885878
/// wsloop ::= `omp.wsloop` loop-control clause-list
886879
/// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
887-
/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
880+
/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
888881
/// steps := `step` `(`ssa-id-list`)`
889882
/// clause-list ::= clause clause-list | empty
890883
/// clause ::= private | firstprivate | lastprivate | linear | schedule |
891-
// collapse | nowait | ordered | order | inclusive | reduction
884+
// collapse | nowait | ordered | order | reduction
892885
static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
893886

894887
// Parse an opening `(` followed by induction variables followed by `)`
@@ -915,6 +908,11 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
915908
parser.resolveOperands(upper, loopVarType, result.operands))
916909
return failure();
917910

911+
if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
912+
auto attr = UnitAttr::get(parser.getBuilder().getContext());
913+
result.addAttribute("inclusive", attr);
914+
}
915+
918916
// Parse step values.
919917
SmallVector<OpAsmParser::OperandType> steps;
920918
if (parser.parseKeyword("step") ||
@@ -945,7 +943,11 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
945943
static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
946944
auto args = op.getRegion().front().getArguments();
947945
p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
948-
<< ") to (" << op.upperBound() << ") step (" << op.step() << ") ";
946+
<< ") to (" << op.upperBound() << ") ";
947+
if (op.inclusive()) {
948+
p << "inclusive ";
949+
}
950+
p << "step (" << op.step() << ") ";
949951

950952
printDataVars(p, op.private_vars(), "private");
951953
printDataVars(p, op.firstprivate_vars(), "firstprivate");
@@ -972,15 +974,14 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
972974
if (auto ordered = op.ordered_val())
973975
p << "ordered(" << ordered << ") ";
974976

977+
if (auto order = op.order_val())
978+
p << "order(" << order << ") ";
979+
975980
if (!op.reduction_vars().empty()) {
976981
p << "reduction(";
977982
printReductionVarList(p, op.reductions(), op.reduction_vars());
978983
}
979984

980-
if (op.inclusive()) {
981-
p << "inclusive ";
982-
}
983-
984985
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
985986
}
986987

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,62 @@ func @copyin_once(%n : memref<i32>) {
6969
}
7070

7171
// -----
72-
72+
73+
func @lastprivate_not_allowed(%n : memref<i32>) {
74+
// expected-error@+1 {{lastprivate is not a valid clause for the omp.parallel operation}}
75+
omp.parallel lastprivate(%n : memref<i32>) {}
76+
return
77+
}
78+
79+
// -----
80+
81+
func @nowait_not_allowed(%n : memref<i32>) {
82+
// expected-error@+1 {{nowait is not a valid clause for the omp.parallel operation}}
83+
omp.parallel nowait {}
84+
return
85+
}
86+
87+
// -----
88+
89+
func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
90+
// expected-error@+1 {{linear is not a valid clause for the omp.parallel operation}}
91+
omp.parallel linear(%data_var = %linear_var : memref<i32>) {}
92+
return
93+
}
94+
95+
// -----
96+
97+
func @schedule_not_allowed() {
98+
// expected-error@+1 {{schedule is not a valid clause for the omp.parallel operation}}
99+
omp.parallel schedule(static) {}
100+
return
101+
}
102+
103+
// -----
104+
105+
func @collapse_not_allowed() {
106+
// expected-error@+1 {{collapse is not a valid clause for the omp.parallel operation}}
107+
omp.parallel collapse(3) {}
108+
return
109+
}
110+
111+
// -----
112+
113+
func @order_not_allowed() {
114+
// expected-error@+1 {{order is not a valid clause for the omp.parallel operation}}
115+
omp.parallel order(concurrent) {}
116+
return
117+
}
118+
119+
// -----
120+
121+
func @ordered_not_allowed() {
122+
// expected-error@+1 {{ordered is not a valid clause for the omp.parallel operation}}
123+
omp.parallel ordered(2) {}
124+
}
125+
126+
// -----
127+
73128
func @default_once() {
74129
// expected-error@+1 {{at most one default clause can appear on the omp.parallel operation}}
75130
omp.parallel default(private) default(firstprivate) {
@@ -90,6 +145,78 @@ func @proc_bind_once() {
90145

91146
// -----
92147

148+
func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) {
149+
// expected-error @below {{inclusive is not a valid clause}}
150+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) nowait inclusive {
151+
omp.yield
152+
}
153+
}
154+
155+
// -----
156+
157+
func @order_value(%lb : index, %ub : index, %step : index) {
158+
// expected-error @below {{attribute 'order_val' failed to satisfy constraint: OrderKind Clause}}
159+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(default) {
160+
omp.yield
161+
}
162+
}
163+
164+
// -----
165+
166+
func @shared_not_allowed(%lb : index, %ub : index, %step : index, %var : memref<i32>) {
167+
// expected-error @below {{shared is not a valid clause for the omp.wsloop operation}}
168+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) shared(%var) {
169+
omp.yield
170+
}
171+
}
172+
173+
// -----
174+
175+
func @copyin(%lb : index, %ub : index, %step : index, %var : memref<i32>) {
176+
// expected-error @below {{copyin is not a valid clause for the omp.wsloop operation}}
177+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) copyin(%var) {
178+
omp.yield
179+
}
180+
}
181+
182+
// -----
183+
184+
func @if_not_allowed(%lb : index, %ub : index, %step : index, %bool_var : i1) {
185+
// expected-error @below {{if is not a valid clause for the omp.wsloop operation}}
186+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) if(%bool_var: i1) {
187+
omp.yield
188+
}
189+
}
190+
191+
// -----
192+
193+
func @num_threads_not_allowed(%lb : index, %ub : index, %step : index, %int_var : i32) {
194+
// expected-error @below {{num_threads is not a valid clause for the omp.wsloop operation}}
195+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) num_threads(%int_var: i32) {
196+
omp.yield
197+
}
198+
}
199+
200+
// -----
201+
202+
func @default_not_allowed(%lb : index, %ub : index, %step : index) {
203+
// expected-error @below {{default is not a valid clause for the omp.wsloop operation}}
204+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) default(private) {
205+
omp.yield
206+
}
207+
}
208+
209+
// -----
210+
211+
func @proc_bind_not_allowed(%lb : index, %ub : index, %step : index) {
212+
// expected-error @below {{proc_bind is not a valid clause for the omp.wsloop operation}}
213+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) proc_bind(close) {
214+
omp.yield
215+
}
216+
}
217+
218+
// -----
219+
93220
// expected-error @below {{op expects initializer region with one argument of the reduction type}}
94221
omp.reduction.declare @add_f32 : f64
95222
init {

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,27 @@ func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads :
123123
omp.terminator
124124
}
125125

126-
return
126+
// CHECK: omp.parallel default(private)
127+
omp.parallel default(private) {
128+
omp.terminator
129+
}
130+
131+
// CHECK: omp.parallel default(firstprivate)
132+
omp.parallel default(firstprivate) {
133+
omp.terminator
134+
}
135+
136+
// CHECK: omp.parallel default(shared)
137+
omp.parallel default(shared) {
138+
omp.terminator
139+
}
140+
141+
// CHECK: omp.parallel default(none)
142+
omp.parallel default(none) {
143+
omp.terminator
144+
}
145+
146+
return
127147
}
128148

129149
// CHECK-LABEL: omp_wsloop
@@ -221,6 +241,21 @@ func @omp_wsloop_pretty(%lb : index, %ub : index, %step : index,
221241
omp.yield
222242
}
223243

244+
// CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}})
245+
omp.wsloop (%iv) : index = (%lb) to (%ub) inclusive step (%step) {
246+
omp.yield
247+
}
248+
249+
// CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) nowait
250+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) nowait {
251+
omp.yield
252+
}
253+
254+
// CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) nowait order(concurrent)
255+
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(concurrent) nowait {
256+
omp.yield
257+
}
258+
224259
return
225260
}
226261

0 commit comments

Comments
 (0)