Skip to content

Commit 995fc53

Browse files
committed
Address review comments
1 parent 7820ef5 commit 995fc53

File tree

10 files changed

+120
-104
lines changed

10 files changed

+120
-104
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,25 @@ bool ClauseProcessor::processCancelDirectiveName(
271271
return true;
272272
}
273273

274-
bool ClauseProcessor::processCollapse(
274+
bool ClauseProcessor::processLoopNests(
275275
mlir::Location currentLocation, lower::pft::Evaluation &eval,
276276
mlir::omp::LoopRelatedClauseOps &result,
277277
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
278-
return collectLoopRelatedInfo(converter, currentLocation, eval, clauses,
279-
result, iv);
278+
int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
279+
clauses, result, iv);
280+
return numCollapse > 1;
281+
}
282+
283+
bool ClauseProcessor::processCollapse(
284+
mlir::Location currentLocation, lower::pft::Evaluation &eval,
285+
mlir::omp::LoopNestOperands &result,
286+
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
287+
288+
int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
289+
clauses, result, iv);
290+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
291+
result.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
292+
return numCollapse > 1;
280293
}
281294

282295
bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
@@ -522,6 +535,19 @@ bool ClauseProcessor::processProcBind(
522535
return false;
523536
}
524537

538+
bool ClauseProcessor::processTileSizes(
539+
lower::pft::Evaluation &eval, mlir::omp::LoopNestOperands &result) const {
540+
bool found = false;
541+
llvm::SmallVector<int64_t> sizeValues;
542+
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
543+
collectTileSizesFromOpenMPConstruct(ompCons, sizeValues, semaCtx);
544+
if (sizeValues.size() > 0) {
545+
found = true;
546+
result.tileSizes = sizeValues;
547+
}
548+
return found;
549+
}
550+
525551
bool ClauseProcessor::processSafelen(
526552
mlir::omp::SafelenClauseOps &result) const {
527553
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,12 @@ class ClauseProcessor {
6363
mlir::omp::CancelDirectiveNameClauseOps &result) const;
6464
bool
6565
processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
66-
mlir::omp::LoopRelatedClauseOps &result,
66+
mlir::omp::LoopNestOperands &result,
6767
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
68+
bool
69+
processLoopNests(mlir::Location currentLocation, lower::pft::Evaluation &eval,
70+
mlir::omp::LoopRelatedClauseOps &result,
71+
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
6872
bool processDevice(lower::StatementContext &stmtCtx,
6973
mlir::omp::DeviceClauseOps &result) const;
7074
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
@@ -98,6 +102,8 @@ class ClauseProcessor {
98102
bool processPriority(lower::StatementContext &stmtCtx,
99103
mlir::omp::PriorityClauseOps &result) const;
100104
bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
105+
bool processTileSizes(lower::pft::Evaluation &eval,
106+
mlir::omp::LoopNestOperands &result) const;
101107
bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
102108
bool processSchedule(lower::StatementContext &stmtCtx,
103109
mlir::omp::ScheduleClauseOps &result) const;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
504504
[[fallthrough]];
505505
case OMPD_distribute:
506506
case OMPD_distribute_simd:
507-
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
507+
cp.processLoopNests(loc, eval, hostInfo->ops, hostInfo->iv);
508508
break;
509509

510510
case OMPD_teams:
@@ -523,7 +523,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
523523
[[fallthrough]];
524524
case OMPD_target_teams_distribute:
525525
case OMPD_target_teams_distribute_simd:
526-
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
526+
cp.processLoopNests(loc, eval, hostInfo->ops, hostInfo->iv);
527527
cp.processNumTeams(stmtCtx, hostInfo->ops);
528528
break;
529529

@@ -534,7 +534,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
534534
cp.processNumTeams(stmtCtx, hostInfo->ops);
535535
[[fallthrough]];
536536
case OMPD_loop:
537-
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
537+
cp.processLoopNests(loc, eval, hostInfo->ops, hostInfo->iv);
538538
break;
539539

540540
case OMPD_teams_workdistribute:
@@ -1573,20 +1573,7 @@ genLoopNestClauses(lower::AbstractConverter &converter,
15731573
cp.processCollapse(loc, eval, clauseOps, iv);
15741574

15751575
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
1576-
1577-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1578-
for (auto &clause : clauses)
1579-
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
1580-
const auto &collapse = std::get<clause::Collapse>(clause.u);
1581-
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
1582-
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
1583-
}
1584-
1585-
llvm::SmallVector<int64_t> sizeValues;
1586-
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
1587-
collectTileSizesFromOpenMPConstruct(ompCons, sizeValues, semaCtx);
1588-
if (sizeValues.size() > 0)
1589-
clauseOps.tileSizes = sizeValues;
1576+
cp.processTileSizes(eval, clauseOps);
15901577
}
15911578

15921579
static void genLoopClauses(

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -638,12 +638,12 @@ void collectTileSizesFromOpenMPConstruct(
638638
}
639639
}
640640

641-
bool collectLoopRelatedInfo(
641+
int64_t collectLoopRelatedInfo(
642642
lower::AbstractConverter &converter, mlir::Location currentLocation,
643643
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
644644
mlir::omp::LoopRelatedClauseOps &result,
645645
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
646-
bool found = false;
646+
int64_t numCollapse = 1;
647647
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
648648

649649
// Collect the loops to collapse.
@@ -656,7 +656,7 @@ bool collectLoopRelatedInfo(
656656
if (auto *clause =
657657
ClauseFinder::findUniqueClause<omp::clause::Collapse>(clauses)) {
658658
collapseValue = evaluate::ToInt64(clause->v).value();
659-
found = true;
659+
numCollapse = collapseValue;
660660
}
661661

662662
// Collect sizes from tile directive if present
@@ -685,7 +685,6 @@ bool collectLoopRelatedInfo(
685685
if (const auto tclause{
686686
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
687687
sizesLengthValue = tclause->v.size();
688-
found = true;
689688
}
690689
}
691690
}
@@ -728,7 +727,7 @@ bool collectLoopRelatedInfo(
728727

729728
convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
730729

731-
return found;
730+
return numCollapse;
732731
}
733732

734733
} // namespace omp

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ void genObjectList(const ObjectList &objects,
159159
void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
160160
mlir::Location loc);
161161

162-
bool collectLoopRelatedInfo(
162+
int64_t collectLoopRelatedInfo(
163163
lower::AbstractConverter &converter, mlir::Location currentLocation,
164164
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
165165
mlir::omp::LoopRelatedClauseOps &result,

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,23 @@ class OpenMP_BindClauseSkip<
209209

210210
def OpenMP_BindClause : OpenMP_BindClauseSkip<>;
211211

212+
//===----------------------------------------------------------------------===//
213+
// V5.2: [4.4.3] `collapse` clause
214+
//===----------------------------------------------------------------------===//
215+
216+
class OpenMP_CollapseClauseSkip<
217+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
218+
bit description = false, bit extraClassDeclaration = false
219+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
220+
extraClassDeclaration> {
221+
let arguments = (ins
222+
ConfinedAttr<DefaultValuedOptionalAttr<I64Attr, "1">, [IntMinValue<1>]>
223+
:$collapse_num_loops
224+
);
225+
}
226+
227+
def OpenMP_CollapseClause : OpenMP_CollapseClauseSkip<>;
228+
212229
//===----------------------------------------------------------------------===//
213230
// V5.2: [5.7.2] `copyprivate` clause
214231
//===----------------------------------------------------------------------===//
@@ -317,38 +334,6 @@ class OpenMP_DeviceClauseSkip<
317334

318335
def OpenMP_DeviceClause : OpenMP_DeviceClauseSkip<>;
319336

320-
//===----------------------------------------------------------------------===//
321-
// V5.2: [XX.X] `collapse` clause
322-
//===----------------------------------------------------------------------===//
323-
324-
class OpenMP_CollapseClauseSkip<
325-
bit traits = false, bit arguments = false, bit assemblyFormat = false,
326-
bit description = false, bit extraClassDeclaration = false
327-
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
328-
extraClassDeclaration> {
329-
let arguments = (ins
330-
DefaultValuedOptionalAttr<I64Attr, "1">:$num_collapse
331-
);
332-
}
333-
334-
def OpenMP_CollapseClause : OpenMP_CollapseClauseSkip<>;
335-
336-
//===----------------------------------------------------------------------===//
337-
// V5.2: [xx.x] `sizes` clause
338-
//===----------------------------------------------------------------------===//
339-
340-
class OpenMP_TileSizesClauseSkip<
341-
bit traits = false, bit arguments = false, bit assemblyFormat = false,
342-
bit description = false, bit extraClassDeclaration = false
343-
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
344-
extraClassDeclaration> {
345-
let arguments = (ins
346-
OptionalAttr<DenseI64ArrayAttr>:$tile_sizes
347-
);
348-
}
349-
350-
def OpenMP_TileSizesClause : OpenMP_TileSizesClauseSkip<>;
351-
352337
//===----------------------------------------------------------------------===//
353338
// V5.2: [11.6.1] `dist_schedule` clause
354339
//===----------------------------------------------------------------------===//
@@ -1355,6 +1340,22 @@ class OpenMP_SimdlenClauseSkip<
13551340

13561341
def OpenMP_SimdlenClause : OpenMP_SimdlenClauseSkip<>;
13571342

1343+
//===----------------------------------------------------------------------===//
1344+
// V5.2: [9.1.1] `sizes` clause
1345+
//===----------------------------------------------------------------------===//
1346+
1347+
class OpenMP_TileSizesClauseSkip<
1348+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
1349+
bit description = false, bit extraClassDeclaration = false
1350+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
1351+
extraClassDeclaration> {
1352+
let arguments = (ins
1353+
OptionalAttr<DenseI64ArrayAttr>:$tile_sizes
1354+
);
1355+
}
1356+
1357+
def OpenMP_TileSizesClause : OpenMP_TileSizesClauseSkip<>;
1358+
13581359
//===----------------------------------------------------------------------===//
13591360
// V5.2: [5.5.9] `task_reduction` clause
13601361
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -614,15 +614,19 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [
614614
def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
615615
RecursiveMemoryEffects, SameVariadicOperandSize
616616
], clauses = [
617-
OpenMP_LoopRelatedClause,
618617
OpenMP_CollapseClause,
618+
OpenMP_LoopRelatedClause,
619619
OpenMP_TileSizesClause
620620
], singleRegion = true> {
621621
let summary = "rectangular loop nest";
622622
let description = [{
623-
This operation represents a collapsed rectangular loop nest. For each
624-
rectangular loop of the nest represented by an instance of this operation,
625-
lower and upper bounds, as well as a step variable, must be defined.
623+
This operation represents a rectangular loop nest which may be collapsed
624+
and/or tiled. For each rectangular loop of the nest represented by an
625+
instance of this operation, lower and upper bounds, as well as a step
626+
variable, must be defined. The collapse clause specifies how many loops
627+
should be collapsed (1 if no collapse is done) after any tiling is
628+
performed. The tiling sizes are represented by the tile sizes clause.
629+
626630

627631
The lower and upper bounds specify a half-open range: the range includes the
628632
lower bound but does not include the upper bound. If the `loop_inclusive`
@@ -635,7 +639,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
635639
`loop_steps` arguments.
636640

637641
```mlir
638-
omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
642+
omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) collapse(2) tiles(5,5) {
639643
%a = load %arrA[%i1, %i2] : memref<?x?xf32>
640644
%b = load %arrB[%i1, %i2] : memref<?x?xf32>
641645
%sum = arith.addf %a, %b : f32

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -492,9 +492,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
492492

493493
// Create loop nest and populate region with contents of scf.parallel.
494494
auto loopOp = omp::LoopNestOp::create(
495-
rewriter, parallelOp.getLoc(), parallelOp.getLowerBound(),
496-
parallelOp.getUpperBound(), parallelOp.getStep(), false,
497-
parallelOp.getLowerBound().size(), nullptr);
495+
rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
496+
parallelOp.getLowerBound(), parallelOp.getUpperBound(),
497+
parallelOp.getStep(), /*loop_inclusive=*/false,
498+
/*tile_sizes=*/nullptr);
498499

499500
rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
500501
loopOp.getRegion().begin());

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2980,7 +2980,7 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
29802980
return failure();
29812981
if (value > 1)
29822982
result.addAttribute(
2983-
"num_collapse",
2983+
"collapse_num_loops",
29842984
IntegerAttr::get(parser.getBuilder().getI64Type(), value));
29852985

29862986
// Parse tiles
@@ -3024,7 +3024,7 @@ void LoopNestOp::print(OpAsmPrinter &p) {
30243024
if (getLoopInclusive())
30253025
p << "inclusive ";
30263026
p << "step (" << getLoopSteps() << ") ";
3027-
if (int64_t numCollapse = getNumCollapse())
3027+
if (int64_t numCollapse = getCollapseNumLoops())
30283028
if (numCollapse > 1)
30293029
p << "collapse(" << numCollapse << ") ";
30303030

@@ -3037,9 +3037,9 @@ void LoopNestOp::print(OpAsmPrinter &p) {
30373037
void LoopNestOp::build(OpBuilder &builder, OperationState &state,
30383038
const LoopNestOperands &clauses) {
30393039
MLIRContext *ctx = builder.getContext();
3040-
LoopNestOp::build(builder, state, clauses.loopLowerBounds,
3041-
clauses.loopUpperBounds, clauses.loopSteps,
3042-
clauses.loopInclusive, clauses.numCollapse,
3040+
LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3041+
clauses.loopLowerBounds, clauses.loopUpperBounds,
3042+
clauses.loopSteps, clauses.loopInclusive,
30433043
makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
30443044
}
30453045

@@ -3058,7 +3058,7 @@ LogicalResult LoopNestOp::verify() {
30583058

30593059
uint64_t numIVs = getIVs().size();
30603060

3061-
if (const auto &numCollapse = getNumCollapse())
3061+
if (const auto &numCollapse = getCollapseNumLoops())
30623062
if (numCollapse > numIVs)
30633063
return emitOpError()
30643064
<< "collapse value is larger than the number of loops";

0 commit comments

Comments
 (0)