@@ -291,47 +291,61 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
291291 return arith::DivUIOp::create(builder, loc, sum, divisor);
292292}
293293
294- /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
295- /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
296- /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
297- /// unrolled iteration using annotateFn.
298- static void generateUnrolledLoop(
299- Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
294+ void mlir::generateUnrolledLoop(
295+ Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
300296 function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
301297 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
302- ValueRange iterArgs, ValueRange yieldedValues) {
298+ ValueRange iterArgs, ValueRange yieldedValues,
299+ IRMapping *clonedToSrcOpsMap) {
300+
301+ // Check whether the op was cloned from another source op, and return that
302+ // source op if found (or the op itself if not found).
303+ auto findOriginalSrcOp =
304+ [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
305+ Operation *srcOp = op;
306+ // If the source op itself derives from another op, traverse the chain to
307+ // find the original source op.
308+ while (srcOp && clonedToSrcOpsMap.contains(srcOp))
309+ srcOp = clonedToSrcOpsMap.lookup(srcOp);
310+ return srcOp;
311+ };
312+
303313 // Builder to insert unrolled bodies just before the terminator of the body of
304- // 'forOp'.
314+ // the loop.
305315 auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
306316
307- constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
317+ static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
308318 if (!annotateFn)
309- annotateFn = defaultAnnotateFn;
319+ annotateFn = noopAnnotateFn;
310320
311321 // Keep a pointer to the last non-terminator operation in the original block
312322 // so that we know what to clone (since we are doing this in-place).
313323 Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
314324
315- // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
325+ // Unroll the contents of the loop body (append unrollFactor - 1 additional
326+ // copies).
316327 SmallVector<Value, 4> lastYielded(yieldedValues);
317328
318329 for (unsigned i = 1; i < unrollFactor; i++) {
319- IRMapping operandMap;
320-
321330 // Prepare operand map.
331+ IRMapping operandMap;
322332 operandMap.map(iterArgs, lastYielded);
323333
324334 // If the induction variable is used, create a remapping to the value for
325335 // this unrolled instance.
326- if (!forOpIV.use_empty()) {
327- Value ivUnroll = ivRemapFn(i, forOpIV, builder);
328- operandMap.map(forOpIV, ivUnroll);
336+ if (!iv.use_empty()) {
337+ Value ivUnroll = ivRemapFn(i, iv, builder);
338+ operandMap.map(iv, ivUnroll);
329339 }
330340
331341 // Clone the original body of 'forOp'.
332342 for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
333- Operation *clonedOp = builder.clone(*it, operandMap);
343+ Operation *srcOp = &(*it);
344+ Operation *clonedOp = builder.clone(*srcOp, operandMap);
334345 annotateFn(i, clonedOp, builder);
346+ if (clonedToSrcOpsMap)
347+ clonedToSrcOpsMap->map(clonedOp,
348+                        findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
335349 }
336350
337351 // Update yielded values.
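
Note on the new clonedToSrcOpsMap parameter: when a caller passes a non-null IRMapping, each cloned op is recorded against the op it was cloned from, and findOriginalSrcOp collapses chains of clones so that repeated unrolling still resolves to an op of the original body. A minimal caller-side sketch, assuming an scf::ForOp 'loop' and illustrative 'remapIv', 'iterArgs', and 'yieldedValues' (these names are not part of the patch):

  IRMapping clonedToSrc;
  generateUnrolledLoop(loop.getBody(), loop.getInductionVar(),
                       /*unrollFactor=*/4, remapIv, /*annotateFn=*/nullptr,
                       iterArgs, yieldedValues, &clonedToSrc);
  loop.getBody()->walk([&](Operation *op) {
    // Cloned ops resolve to their source op; ops of the original body are not
    // in the map and resolve to themselves.
    Operation *src = clonedToSrc.contains(op) ? clonedToSrc.lookup(op) : op;
    (void)src;
  });
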
@@ -1544,3 +1558,116 @@ bool mlir::isPerfectlyNestedForLoops(
15441558 }
15451559 return true;
15461560}
1561+
1562+ std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
1563+                                                          bool isSigned) {
1564+ llvm::APSInt diff;
1565+ auto addOp = ub.getDefiningOp<arith::AddIOp>();
1566+ if (!addOp)
1567+ return std::nullopt;
1568+ if ((isSigned && !addOp.hasNoSignedWrap()) ||
1569+     (!isSigned && !addOp.hasNoUnsignedWrap()))
1570+ return std::nullopt;
1571+
1572+ if (addOp.getLhs() != lb ||
1573+     !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
1574+ return std::nullopt;
1575+ return diff;
1576+ }
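
As a worked illustration of what computeUbMinusLb matches (example IR, not taken from the patch): the upper bound must be produced by a no-wrap arith.addi whose lhs is the lower bound and whose rhs is a constant, and that constant is returned.

  // Given:
  //   %ub = arith.addi %lb, %c32 overflow<nsw> : index
  // computeUbMinusLb(%lb, %ub, /*isSigned=*/true) returns 32; a missing
  // nsw/nuw flag, a different defining op, or an lhs other than %lb all
  // yield std::nullopt.
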
1577+
1578+ llvm::SmallVector<int64_t>
1579+ mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
1580+ std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
1581+ std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
1582+ std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
1583+ if (!loBnds || !upBnds || !steps)
1584+ return {};
1585+ llvm::SmallVector<int64_t> tripCounts;
1586+ for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
1587+ std::optional<llvm::APInt> numIter = constantTripCount(
1588+     lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
1589+ if (!numIter)
1590+ return {};
1591+ tripCounts.push_back(numIter->getSExtValue());
1592+ }
1593+ return tripCounts;
1594+ }
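
A hedged usage sketch of getConstLoopTripCounts (the surrounding names and the divisibility check are illustrative, not from the patch): a transform can bail out when any dimension's trip count is unknown or does not divide evenly.

  SmallVector<int64_t> counts = mlir::getConstLoopTripCounts(loopOp);
  if (counts.empty() ||
      llvm::any_of(counts, [](int64_t c) { return c % 4 != 0; }))
    return failure(); // Dynamic extents or a non-divisible iteration space.
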
1595+
1596+ FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
1597+     scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
1598+     RewriterBase &rewriter,
1599+     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
1600+     IRMapping *clonedToSrcOpsMap) {
1601+ const unsigned numLoops = op.getNumLoops();
1602+ assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
1603+        "Expected positive unroll factors");
1604+ assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
1605+ " Expected non-empty unroll factors of size <= to the number of loops" );
1606+
1607+ // Bail out if all unroll factors are 1 (no unrolling to perform).
1608+ if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
1609+ return rewriter.notifyMatchFailure(
1610+     op, "Unrolling not applied if all factors are 1");
1611+
1612+ // Return if the loop body is empty.
1613+ if (llvm::hasSingleElement(op.getBody()->getOperations()))
1614+ return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
1615+
1616+ // If the provided unroll factors do not cover all the loop dims, they are
1617+ // applied to the inner loop dimensions.
1618+ const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
1619+
1620+ // Make sure that the unroll factors divide the iteration space evenly.
1621+ // TODO: Support unrolling loops with dynamic iteration spaces.
1622+ const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
1623+ if (tripCounts.empty())
1624+ return rewriter.notifyMatchFailure(
1625+     op, "Failed to compute constant trip counts for the loop. Note that "
1626+         "dynamic loop sizes are not supported.");
1627+
1628+ for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1629+ const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1630+ if (tripCounts[dimIdx] % unrollFactor)
1631+ return rewriter.notifyMatchFailure(
1632+     op, "Unroll factors don't divide the iteration space evenly");
1633+ }
1634+
1635+ std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps();
1636+ if (!maybeFoldSteps)
1637+ return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps");
1638+ llvm::SmallVector<size_t> steps{};
1639+ for (auto step : *maybeFoldSteps)
1640+ steps.push_back(static_cast<size_t>(*getConstantIntValue(step)));
1641+
1642+ for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1643+ const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1644+ if (unrollFactor == 1)
1645+ continue;
1646+ const size_t origStep = steps[dimIdx];
1647+ const int64_t newStep = origStep * unrollFactor;
1648+ IRMapping localClonedToSrcOpsMap; // Used if the caller passed no map.
1649+
1650+ ValueRange iterArgs = ValueRange(op.getRegionIterArgs());
1651+ auto yieldedValues = op.getBody()->getTerminator()->getOperands();
1652+
1653+ generateUnrolledLoop(
1654+     op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
1655+ [&](unsigned i, Value iv, OpBuilder b) {
1656+ // iv' = iv + step * i;
1657+ const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i);
1658+ const auto map =
1659+     b.getDimIdentityMap().dropResult(0).insertResult(expr, 0);
1660+ return affine::AffineApplyOp::create(b, iv.getLoc(), map,
1661+ ValueRange{iv});
1662+ },
1663+     /*annotateFn=*/annotateFn, iterArgs, yieldedValues, clonedToSrcOpsMap ? clonedToSrcOpsMap : &localClonedToSrcOpsMap);
1664+
1665+ // Update loop step.
1666+ auto prevInsertPoint = rewriter.saveInsertionPoint();
1667+ rewriter.setInsertionPoint(op);
1668+ op.getStepMutable()[dimIdx].assign(
1669+     arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep));
1670+ rewriter.restoreInsertionPoint(prevInsertPoint);
1671+ }
1672+ return op;
1673+ }
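
To make the overall effect concrete, a before/after sketch for unrolling a one-dimensional scf.parallel with unit step by a factor of 2 (illustrative IR, terminators elided; not taken from the patch): the step is doubled and the body is cloned once, with the cloned induction-variable uses remapped through an affine.apply that adds the original step.

  // Before:
  //   scf.parallel (%i) = (%c0) to (%c16) step (%c1) {
  //     "test.use"(%i) : (index) -> ()
  //   }
  //
  // After parallelLoopUnrollByFactors with unrollFactors = {2}:
  //   scf.parallel (%i) = (%c0) to (%c16) step (%c2) {
  //     "test.use"(%i) : (index) -> ()
  //     %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
  //     "test.use"(%i1) : (index) -> ()
  //   }
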