33
33
#include " llvm/ADT/TypeSwitch.h"
34
34
#include " llvm/ADT/bit.h"
35
35
#include " llvm/Frontend/OpenMP/OMPConstants.h"
36
+ #include " llvm/Support/InterleavedRange.h"
36
37
#include < cstddef>
37
38
#include < iterator>
38
39
#include < optional>
@@ -3385,6 +3386,9 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3385
3386
Value result = getResult ();
3386
3387
auto [newCli, gen, cons] = decodeCli (result);
3387
3388
3389
+ // Structured binding `gen` cannot be captured in lambdas before C++20
3390
+ OpOperand *generator = gen;
3391
+
3388
3392
// Derive the CLI variable name from its generator:
3389
3393
// * "canonloop" for omp.canonical_loop
3390
3394
// * custom name for loop transformation generatees
@@ -3403,6 +3407,24 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3403
3407
.Case ([&](UnrollHeuristicOp op) -> std::string {
3404
3408
llvm_unreachable (" heuristic unrolling does not generate a loop" );
3405
3409
})
3410
+ .Case ([&](TileOp op) -> std::string {
3411
+ auto [generateesFirst, generateesCount] =
3412
+ op.getGenerateesODSOperandIndexAndLength ();
3413
+ unsigned firstGrid = generateesFirst;
3414
+ unsigned firstIntratile = generateesFirst + generateesCount / 2 ;
3415
+ unsigned end = generateesFirst + generateesCount;
3416
+ unsigned opnum = generator->getOperandNumber ();
3417
+ // In the OpenMP apply and looprange clauses, indices are 1-based
3418
+ if (firstGrid <= opnum && opnum < firstIntratile) {
3419
+ unsigned gridnum = opnum - firstGrid + 1 ;
3420
+ return (" grid" + Twine (gridnum)).str ();
3421
+ }
3422
+ if (firstIntratile <= opnum && opnum < end) {
3423
+ unsigned intratilenum = opnum - firstIntratile + 1 ;
3424
+ return (" intratile" + Twine (intratilenum)).str ();
3425
+ }
3426
+ llvm_unreachable (" Unexpected generatee argument" );
3427
+ })
3406
3428
.Default ([&](Operation *op) {
3407
3429
assert (false && " TODO: Custom name for this operation" );
3408
3430
return " transformed" ;
@@ -3631,6 +3653,138 @@ UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3631
3653
return {0 , 0 };
3632
3654
}
3633
3655
3656
+ // ===----------------------------------------------------------------------===//
3657
+ // TileOp
3658
+ // ===----------------------------------------------------------------------===//
3659
+
3660
+ static void printLoopTransformClis (OpAsmPrinter &p, TileOp op,
3661
+ OperandRange generatees,
3662
+ OperandRange applyees) {
3663
+ if (!generatees.empty ())
3664
+ p << ' (' << llvm::interleaved (generatees) << ' )' ;
3665
+
3666
+ if (!applyees.empty ())
3667
+ p << " <- (" << llvm::interleaved (applyees) << ' )' ;
3668
+ }
3669
+
3670
+ static ParseResult parseLoopTransformClis (
3671
+ OpAsmParser &parser,
3672
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &generateesOperands,
3673
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &applyeesOperands) {
3674
+ if (parser.parseOptionalLess ()) {
3675
+ // Syntax 1: generatees present
3676
+
3677
+ if (parser.parseOperandList (generateesOperands,
3678
+ mlir::OpAsmParser::Delimiter::Paren))
3679
+ return failure ();
3680
+
3681
+ if (parser.parseLess ())
3682
+ return failure ();
3683
+ } else {
3684
+ // Syntax 2: generatees omitted
3685
+ }
3686
+
3687
+ // Parse `<-` (`<` has already been parsed)
3688
+ if (parser.parseMinus ())
3689
+ return failure ();
3690
+
3691
+ if (parser.parseOperandList (applyeesOperands,
3692
+ mlir::OpAsmParser::Delimiter::Paren))
3693
+ return failure ();
3694
+
3695
+ return success ();
3696
+ }
3697
+
3698
+ LogicalResult TileOp::verify () {
3699
+ if (getApplyees ().empty ())
3700
+ return emitOpError () << " must apply to at least one loop" ;
3701
+
3702
+ if (getSizes ().size () != getApplyees ().size ())
3703
+ return emitOpError () << " there must be one tile size for each applyee" ;
3704
+
3705
+ if (!getGeneratees ().empty () &&
3706
+ 2 * getSizes ().size () != getGeneratees ().size ())
3707
+ return emitOpError ()
3708
+ << " expecting two times the number of generatees than applyees" ;
3709
+
3710
+ DenseSet<Value> parentIVs;
3711
+
3712
+ Value parent = getApplyees ().front ();
3713
+ for (auto &&applyee : llvm::drop_begin (getApplyees ())) {
3714
+ auto [parentCreate, parentGen, parentCons] = decodeCli (parent);
3715
+ auto [create, gen, cons] = decodeCli (applyee);
3716
+
3717
+ if (!parentGen)
3718
+ return emitOpError () << " applyee CLI has no generator" ;
3719
+
3720
+ auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner ());
3721
+ if (!parentGen)
3722
+ return emitOpError ()
3723
+ << " currently only supports omp.canonical_loop as applyee" ;
3724
+
3725
+ parentIVs.insert (parentLoop.getInductionVar ());
3726
+
3727
+ if (!gen)
3728
+ return emitOpError () << " applyee CLI has no generator" ;
3729
+ auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner ());
3730
+ if (!loop)
3731
+ return emitOpError ()
3732
+ << " currently only supports omp.canonical_loop as applyee" ;
3733
+
3734
+ // Canonical loop must be perfectly nested, i.e. the body of the parent must
3735
+ // only contain the omp.canonical_loop of the nested loops, and
3736
+ // omp.terminator
3737
+ bool isPerfectlyNested = [&]() {
3738
+ auto &parentBody = parentLoop.getRegion ();
3739
+ if (!parentBody.hasOneBlock ())
3740
+ return false ;
3741
+ auto &parentBlock = parentBody.getBlocks ().front ();
3742
+
3743
+ auto nestedLoopIt = parentBlock.begin ();
3744
+ if (nestedLoopIt == parentBlock.end () ||
3745
+ (&*nestedLoopIt != loop.getOperation ()))
3746
+ return false ;
3747
+
3748
+ auto termIt = std::next (nestedLoopIt);
3749
+ if (termIt == parentBlock.end () || !isa<TerminatorOp>(termIt))
3750
+ return false ;
3751
+
3752
+ if (std::next (termIt) != parentBlock.end ())
3753
+ return false ;
3754
+
3755
+ return true ;
3756
+ }();
3757
+ if (!isPerfectlyNested)
3758
+ return emitOpError () << " tiled loop nest must be perfectly nested" ;
3759
+
3760
+ if (parentIVs.contains (loop.getTripCount ()))
3761
+ return emitOpError () << " tiled loop nest must be rectangular" ;
3762
+
3763
+ parent = applyee;
3764
+ }
3765
+
3766
+ // TODO: The tile sizes must be computed before the loop, but checking this
3767
+ // requires dominance analysis. For instance:
3768
+ //
3769
+ // %canonloop = omp.new_cli
3770
+ // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
3771
+ // // write to %x
3772
+ // omp.terminator
3773
+ // }
3774
+ // %ts = llvm.load %x
3775
+ // omp.tile <- (%canonloop) sizes(%ts : i32)
3776
+
3777
+ return success ();
3778
+ }
3779
+
3780
+ std::pair<unsigned , unsigned > TileOp ::getApplyeesODSOperandIndexAndLength () {
3781
+ return getODSOperandIndexAndLength (odsIndex_applyees);
3782
+ }
3783
+
3784
+ std::pair<unsigned , unsigned > TileOp::getGenerateesODSOperandIndexAndLength () {
3785
+ return getODSOperandIndexAndLength (odsIndex_generatees);
3786
+ }
3787
+
3634
3788
// ===----------------------------------------------------------------------===//
3635
3789
// Critical construct (2.17.1)
3636
3790
// ===----------------------------------------------------------------------===//
0 commit comments