Skip to content

Commit fdc8285

Browse files
Meinersburmahesh-attarde
authored andcommitted
[mlir][omp] Add omp.tile operation (llvm#160292)
Add the `omp.tile` loop transformations for the OpenMP dialect. Used for lowering a standalone `!$omp tile` in Flang.
1 parent f123229 commit fdc8285

File tree

14 files changed

+870
-43
lines changed

14 files changed

+870
-43
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,35 @@ class OpenMP_NumTeamsClauseSkip<
995995

996996
def OpenMP_NumTeamsClause : OpenMP_NumTeamsClauseSkip<>;
997997

998+
//===----------------------------------------------------------------------===//
999+
// V5.1: [10.1.2] `sizes` clause
1000+
//===----------------------------------------------------------------------===//
1001+
1002+
class OpenMP_SizesClauseSkip<
1003+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
1004+
bit description = false, bit extraClassDeclaration = false
1005+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
1006+
extraClassDeclaration> {
1007+
let arguments = (ins
1008+
Variadic<IntLikeType>:$sizes
1009+
);
1010+
1011+
let optAssemblyFormat = [{
1012+
`sizes` `(` $sizes `:` type($sizes) `)`
1013+
}];
1014+
1015+
let description = [{
1016+
The `sizes` clauses defines the size of a grid over a multi-dimensional
1017+
logical iteration space. This grid is used for loop transformations such as
1018+
`tile` and `strip`. The size per dimension can be a variable, but only
1019+
values that are not at least 2 make sense. It is not specified what happens
1020+
when smaller values are used, but should still result in a loop nest that
1021+
executes each logical iteration once.
1022+
}];
1023+
}
1024+
1025+
def OpenMP_SizesClause : OpenMP_SizesClauseSkip<>;
1026+
9981027
//===----------------------------------------------------------------------===//
9991028
// V5.2: [10.1.2] `num_threads` clause
10001029
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,44 @@ def OpenMP_MapBoundsType : OpenMP_Type<"MapBounds", "map_bounds_ty"> {
3838
let summary = "Type for representing omp map clause bounds information";
3939
}
4040

41+
//===---------------------------------------------------------------------===//
42+
// OpenMP Canonical Loop Info Type
43+
//===---------------------------------------------------------------------===//
44+
45+
def CanonicalLoopInfoType : OpenMP_Type<"CanonicalLoopInfo", "cli"> {
46+
let summary = "Type for representing a reference to a canonical loop";
47+
let description = [{
48+
A variable of type CanonicalLoopInfo refers to an OpenMP-compatible
49+
canonical loop in the same function. Values of this type are not
50+
available at runtime and therefore cannot be used by the program itself,
51+
i.e. an opaque type. It is similar to the transform dialect's
52+
`!transform.interface` type, but instead of implementing an interface
53+
for each transformation, the OpenMP dialect itself defines possible
54+
operations on this type.
55+
56+
A value of type CanonicalLoopInfoType (in the following: CLI) value can be
57+
58+
1. created by omp.new_cli.
59+
2. passed to omp.canonical_loop to associate the loop to that CLI. A CLI
60+
can only be associated once.
61+
3. passed to an omp loop transformation operation that modifies the loop
62+
associated with the CLI. The CLI is the "applyee" and the operation is
63+
the consumer. A CLI can only be consumed once.
64+
4. passed to an omp loop transformation operation to associate the cli with
65+
a result of that transformation. The CLI is the "generatee" and the
66+
operation is the generator.
67+
68+
A CLI cannot
69+
70+
1. be returned from a function.
71+
2. be passed to operations that are not specifically designed to take a
72+
CanonicalLoopInfoType, including AnyType.
73+
74+
A CLI directly corresponds to an object of
75+
OpenMPIRBuilder's CanonicalLoopInfo struct when lowering to LLVM-IR.
76+
}];
77+
}
78+
4179
//===----------------------------------------------------------------------===//
4280
// Base classes for OpenMP dialect operations.
4381
//===----------------------------------------------------------------------===//
@@ -211,8 +249,35 @@ class OpenMP_Op<string mnemonic, list<Trait> traits = [],
211249
// Doesn't actually create a C++ base class (only defines default values for
212250
// tablegen classes that derive from this). Use LoopTransformationInterface
213251
// instead for common operations.
214-
class OpenMPTransform_Op<string mnemonic, list<Trait> traits = []> :
215-
OpenMP_Op<mnemonic, !listconcat([DeclareOpInterfaceMethods<LoopTransformationInterface>], traits) > {
252+
class OpenMPTransform_Op<string mnemonic,
253+
list<Trait> traits = [],
254+
list<OpenMP_Clause> clauses = []> :
255+
OpenMP_Op<mnemonic,
256+
traits = !listconcat([DeclareOpInterfaceMethods<LoopTransformationInterface>], traits),
257+
clauses = clauses> {
258+
}
259+
260+
// Base clause for loop transformations using the standard syntax.
261+
//
262+
// omp.opname ($generatees) <- ($applyees) clause(...) clause(...) ... <attr-dicr>
263+
// omp.opname ($applyees) clause(...) clause(...) ... <attr-dict>
264+
//
265+
// $generatees is optional and is assumed to be empty if omitted
266+
class OpenMPTransformBase_Op<string mnemonic,
267+
list<Trait> traits = [],
268+
list<OpenMP_Clause> clauses = []> :
269+
OpenMPTransform_Op<mnemonic,
270+
traits = !listconcat(traits, [AttrSizedOperandSegments]),
271+
clauses = clauses> {
272+
273+
let arguments = !con(
274+
(ins Variadic<CanonicalLoopInfoType>:$generatees,
275+
Variadic<CanonicalLoopInfoType>:$applyees
276+
), clausesArgs);
277+
278+
let assemblyFormat = [{ custom<LoopTransformClis>($generatees, $applyees) }]
279+
# clausesAssemblyFormat
280+
# [{ attr-dict }];
216281
}
217282

218283
#endif // OPENMP_OP_BASE

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -357,44 +357,6 @@ def SingleOp : OpenMP_Op<"single", traits = [
357357
let hasVerifier = 1;
358358
}
359359

360-
//===---------------------------------------------------------------------===//
361-
// OpenMP Canonical Loop Info Type
362-
//===---------------------------------------------------------------------===//
363-
364-
def CanonicalLoopInfoType : OpenMP_Type<"CanonicalLoopInfo", "cli"> {
365-
let summary = "Type for representing a reference to a canonical loop";
366-
let description = [{
367-
A variable of type CanonicalLoopInfo refers to an OpenMP-compatible
368-
canonical loop in the same function. Values of this type are not
369-
available at runtime and therefore cannot be used by the program itself,
370-
i.e. an opaque type. It is similar to the transform dialect's
371-
`!transform.interface` type, but instead of implementing an interface
372-
for each transformation, the OpenMP dialect itself defines possible
373-
operations on this type.
374-
375-
A value of type CanonicalLoopInfoType (in the following: CLI) value can be
376-
377-
1. created by omp.new_cli.
378-
2. passed to omp.canonical_loop to associate the loop to that CLI. A CLI
379-
can only be associated once.
380-
3. passed to an omp loop transformation operation that modifies the loop
381-
associated with the CLI. The CLI is the "applyee" and the operation is
382-
the consumer. A CLI can only be consumed once.
383-
4. passed to an omp loop transformation operation to associate the cli with
384-
a result of that transformation. The CLI is the "generatee" and the
385-
operation is the generator.
386-
387-
A CLI cannot
388-
389-
1. be returned from a function.
390-
2. be passed to operations that are not specifically designed to take a
391-
CanonicalLoopInfoType, including AnyType.
392-
393-
A CLI directly corresponds to an object of
394-
OpenMPIRBuilder's CanonicalLoopInfo struct when lowering to LLVM-IR.
395-
}];
396-
}
397-
398360
//===---------------------------------------------------------------------===//
399361
// OpenMP Canonical Loop Info Creation
400362
//===---------------------------------------------------------------------===//
@@ -563,6 +525,31 @@ def UnrollHeuristicOp : OpenMPTransform_Op<"unroll_heuristic", []> {
563525
let hasCustomAssemblyFormat = 1;
564526
}
565527

528+
//===----------------------------------------------------------------------===//
529+
// OpenMP tile operation
530+
//===----------------------------------------------------------------------===//
531+
532+
def TileOp : OpenMPTransformBase_Op<"tile",
533+
clauses = [OpenMP_SizesClause]> {
534+
let summary = "OpenMP tile operation";
535+
let description = [{
536+
Represents the OpenMP tile directive introduced in OpenMP 5.1.
537+
538+
The construct partitions the logical iteration space of the affected loops
539+
into equally-sized tiles, then creates two sets of nested loops. The outer
540+
loops, called the grid loops, iterate over all tiles. The inner loops,
541+
called the intratile loops, iterate over the logical iterations of a tile.
542+
The sizes clause determines the size of a tile.
543+
544+
Currently, the affected loops must be rectangular (the tripcount of the
545+
inner loop must not depend on any iv of an surrounding affected loop) and
546+
perfectly nested (except for the innermost affected loop, no operations
547+
other than the nested loop and the terminator in the loop body).
548+
}] # clausesDescription;
549+
550+
let hasVerifier = 1;
551+
}
552+
566553
//===----------------------------------------------------------------------===//
567554
// 2.8.3 Workshare Construct
568555
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/ADT/TypeSwitch.h"
3434
#include "llvm/ADT/bit.h"
3535
#include "llvm/Frontend/OpenMP/OMPConstants.h"
36+
#include "llvm/Support/InterleavedRange.h"
3637
#include <cstddef>
3738
#include <iterator>
3839
#include <optional>
@@ -3385,6 +3386,9 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
33853386
Value result = getResult();
33863387
auto [newCli, gen, cons] = decodeCli(result);
33873388

3389+
// Structured binding `gen` cannot be captured in lambdas before C++20
3390+
OpOperand *generator = gen;
3391+
33883392
// Derive the CLI variable name from its generator:
33893393
// * "canonloop" for omp.canonical_loop
33903394
// * custom name for loop transformation generatees
@@ -3403,6 +3407,24 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
34033407
.Case([&](UnrollHeuristicOp op) -> std::string {
34043408
llvm_unreachable("heuristic unrolling does not generate a loop");
34053409
})
3410+
.Case([&](TileOp op) -> std::string {
3411+
auto [generateesFirst, generateesCount] =
3412+
op.getGenerateesODSOperandIndexAndLength();
3413+
unsigned firstGrid = generateesFirst;
3414+
unsigned firstIntratile = generateesFirst + generateesCount / 2;
3415+
unsigned end = generateesFirst + generateesCount;
3416+
unsigned opnum = generator->getOperandNumber();
3417+
// In the OpenMP apply and looprange clauses, indices are 1-based
3418+
if (firstGrid <= opnum && opnum < firstIntratile) {
3419+
unsigned gridnum = opnum - firstGrid + 1;
3420+
return ("grid" + Twine(gridnum)).str();
3421+
}
3422+
if (firstIntratile <= opnum && opnum < end) {
3423+
unsigned intratilenum = opnum - firstIntratile + 1;
3424+
return ("intratile" + Twine(intratilenum)).str();
3425+
}
3426+
llvm_unreachable("Unexpected generatee argument");
3427+
})
34063428
.Default([&](Operation *op) {
34073429
assert(false && "TODO: Custom name for this operation");
34083430
return "transformed";
@@ -3631,6 +3653,138 @@ UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
36313653
return {0, 0};
36323654
}
36333655

3656+
//===----------------------------------------------------------------------===//
3657+
// TileOp
3658+
//===----------------------------------------------------------------------===//
3659+
3660+
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
3661+
OperandRange generatees,
3662+
OperandRange applyees) {
3663+
if (!generatees.empty())
3664+
p << '(' << llvm::interleaved(generatees) << ')';
3665+
3666+
if (!applyees.empty())
3667+
p << " <- (" << llvm::interleaved(applyees) << ')';
3668+
}
3669+
3670+
static ParseResult parseLoopTransformClis(
3671+
OpAsmParser &parser,
3672+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &generateesOperands,
3673+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &applyeesOperands) {
3674+
if (parser.parseOptionalLess()) {
3675+
// Syntax 1: generatees present
3676+
3677+
if (parser.parseOperandList(generateesOperands,
3678+
mlir::OpAsmParser::Delimiter::Paren))
3679+
return failure();
3680+
3681+
if (parser.parseLess())
3682+
return failure();
3683+
} else {
3684+
// Syntax 2: generatees omitted
3685+
}
3686+
3687+
// Parse `<-` (`<` has already been parsed)
3688+
if (parser.parseMinus())
3689+
return failure();
3690+
3691+
if (parser.parseOperandList(applyeesOperands,
3692+
mlir::OpAsmParser::Delimiter::Paren))
3693+
return failure();
3694+
3695+
return success();
3696+
}
3697+
3698+
LogicalResult TileOp::verify() {
3699+
if (getApplyees().empty())
3700+
return emitOpError() << "must apply to at least one loop";
3701+
3702+
if (getSizes().size() != getApplyees().size())
3703+
return emitOpError() << "there must be one tile size for each applyee";
3704+
3705+
if (!getGeneratees().empty() &&
3706+
2 * getSizes().size() != getGeneratees().size())
3707+
return emitOpError()
3708+
<< "expecting two times the number of generatees than applyees";
3709+
3710+
DenseSet<Value> parentIVs;
3711+
3712+
Value parent = getApplyees().front();
3713+
for (auto &&applyee : llvm::drop_begin(getApplyees())) {
3714+
auto [parentCreate, parentGen, parentCons] = decodeCli(parent);
3715+
auto [create, gen, cons] = decodeCli(applyee);
3716+
3717+
if (!parentGen)
3718+
return emitOpError() << "applyee CLI has no generator";
3719+
3720+
auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
3721+
if (!parentGen)
3722+
return emitOpError()
3723+
<< "currently only supports omp.canonical_loop as applyee";
3724+
3725+
parentIVs.insert(parentLoop.getInductionVar());
3726+
3727+
if (!gen)
3728+
return emitOpError() << "applyee CLI has no generator";
3729+
auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3730+
if (!loop)
3731+
return emitOpError()
3732+
<< "currently only supports omp.canonical_loop as applyee";
3733+
3734+
// Canonical loop must be perfectly nested, i.e. the body of the parent must
3735+
// only contain the omp.canonical_loop of the nested loops, and
3736+
// omp.terminator
3737+
bool isPerfectlyNested = [&]() {
3738+
auto &parentBody = parentLoop.getRegion();
3739+
if (!parentBody.hasOneBlock())
3740+
return false;
3741+
auto &parentBlock = parentBody.getBlocks().front();
3742+
3743+
auto nestedLoopIt = parentBlock.begin();
3744+
if (nestedLoopIt == parentBlock.end() ||
3745+
(&*nestedLoopIt != loop.getOperation()))
3746+
return false;
3747+
3748+
auto termIt = std::next(nestedLoopIt);
3749+
if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3750+
return false;
3751+
3752+
if (std::next(termIt) != parentBlock.end())
3753+
return false;
3754+
3755+
return true;
3756+
}();
3757+
if (!isPerfectlyNested)
3758+
return emitOpError() << "tiled loop nest must be perfectly nested";
3759+
3760+
if (parentIVs.contains(loop.getTripCount()))
3761+
return emitOpError() << "tiled loop nest must be rectangular";
3762+
3763+
parent = applyee;
3764+
}
3765+
3766+
// TODO: The tile sizes must be computed before the loop, but checking this
3767+
// requires dominance analysis. For instance:
3768+
//
3769+
// %canonloop = omp.new_cli
3770+
// omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
3771+
// // write to %x
3772+
// omp.terminator
3773+
// }
3774+
// %ts = llvm.load %x
3775+
// omp.tile <- (%canonloop) sizes(%ts : i32)
3776+
3777+
return success();
3778+
}
3779+
3780+
std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3781+
return getODSOperandIndexAndLength(odsIndex_applyees);
3782+
}
3783+
3784+
std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3785+
return getODSOperandIndexAndLength(odsIndex_generatees);
3786+
}
3787+
36343788
//===----------------------------------------------------------------------===//
36353789
// Critical construct (2.17.1)
36363790
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)