@@ -291,10 +291,28 @@ tileDispatchUsingSCFFopOp(RewriterBase &rewriter, TilingInterface op,
291291
292292 IREETilingResult tilingResult;
293293 tilingResult.tiledLoops .resize (numLoops, false );
294- for (auto [index, tileSize] : llvm::enumerate (tileSizes)) {
295- if (!isConstantIntValue (tileSize, 0 )) {
296- tilingResult.tiledLoops .set (index);
294+ AffineExpr s0, s1, s2, s3; // lb, ub, step, tileSize
295+ bindSymbols (rewriter.getContext (), s0, s1, s2, s3);
296+ AffineExpr numTilesExprs = (s1 - s0).ceilDiv (s2 * s3);
297+ for (auto [index, iteratorType, range, tileSize] :
298+ llvm::enumerate (op.getLoopIteratorTypes (), iterationDomain, tileSizes)) {
299+ // If distribution is specified, only parallel loops are tiled.
300+ if (options.distribution && iteratorType != utils::IteratorType::parallel) {
301+ continue ;
302+ }
303+ // If tile size is 0, the loop isn't tiled.
304+ if (isConstantIntValue (tileSize, 0 )) {
305+ continue ;
297306 }
307+ // If number of tiles is statically known to be 1, the loop isn't tiled.
308+ OpFoldResult numTiles = affine::makeComposedFoldedAffineApply (
309+ rewriter, loc, numTilesExprs,
310+ {range.offset , range.size , range.stride , tileSize});
311+ if (isConstantIntValue (numTiles, 1 )) {
312+ continue ;
313+ }
314+
315+ tilingResult.tiledLoops .set (index);
298316 }
299317
300318 if (!tilingResult.tiledLoops .any ()) {
@@ -328,40 +346,30 @@ tileDispatchUsingSCFFopOp(RewriterBase &rewriter, TilingInterface op,
328346 iterationDomain.size (), linalg::DistributionMethod::None);
329347 SmallVector<linalg::ProcInfo> procInfo;
330348 if (options.distribution ) {
331- SmallVector<utils::IteratorType> iteratorTypes =
332- op.getLoopIteratorTypes ();
333-
334- // The parallel loops that are tiled are partitionable loops.
335349 SmallVector<Range> parallelLoopRanges;
336- SmallVector<unsigned > partitionedLoopIds;
337-
338- AffineExpr s0, s1, s2, s3; // lb, ub, step, tileSize
339- bindSymbols (rewriter.getContext (), s0, s1, s2, s3);
340- AffineExpr numTilesExprs = (s1 - s0).ceilDiv (s2 * s3);
341- for (auto [index, iteratorType] : llvm::enumerate (iteratorTypes)) {
342- if (iteratorType != utils::IteratorType::parallel ||
343- isConstantIntValue (tileSizes[index], 0 )) {
344- continue ;
345- }
346-
347- OpFoldResult numTiles = affine::makeComposedFoldedAffineApply (
348- rewriter, loc, numTilesExprs,
349- {iterationDomain[index].offset , iterationDomain[index].size ,
350- iterationDomain[index].stride , tileSizes[index]});
351- if (isConstantIntValue (numTiles, 1 )) {
352- continue ;
350+ for (auto loopIdx : llvm::seq<unsigned >(0 , numLoops)) {
351+ if (tilingResult.tiledLoops .test (loopIdx)) {
352+ AffineExpr s0, s1;
353+ bindSymbols (rewriter.getContext (), s0, s1);
354+ OpFoldResult parallelLoopStep = affine::makeComposedFoldedAffineApply (
355+ rewriter, loc, s0 * s1,
356+ {iterationDomain[loopIdx].stride , tileSizes[loopIdx]});
357+ Range r = {iterationDomain[loopIdx].offset ,
358+ iterationDomain[loopIdx].size , parallelLoopStep};
359+ parallelLoopRanges.emplace_back (std::move (r));
353360 }
354-
355- parallelLoopRanges.push_back (iterationDomain[index]);
356- partitionedLoopIds.push_back (index);
357361 }
358362
359- // Query the callback to get the {procId, nprocs} to use.
360363 procInfo =
361364 options.distribution ->procInfo (rewriter, loc, parallelLoopRanges);
362365
363- for (auto [index, loopIdx] : llvm::enumerate (partitionedLoopIds)) {
364- distributionMethods[loopIdx] = procInfo[index].distributionMethod ;
366+ unsigned partitionedLoopIdx = 0 ;
367+ for (auto loopIdx : llvm::seq<unsigned >(0 , numLoops)) {
368+ if (!tilingResult.tiledLoops .test (loopIdx)) {
369+ continue ;
370+ }
371+ distributionMethods[loopIdx] =
372+ procInfo[partitionedLoopIdx++].distributionMethod ;
365373 }
366374 }
367375
@@ -443,7 +451,8 @@ static SmallVector<Operation *> getAllFusableProducers(TilingInterface op) {
443451 worklist.pop_front ();
444452 for (OpOperand &operand : currOp->getOpOperands ()) {
445453 Operation *definingOp = operand.get ().getDefiningOp ();
446- auto tilingInterfaceProducer = dyn_cast<TilingInterface>(definingOp);
454+ auto tilingInterfaceProducer =
455+ dyn_cast_or_null<TilingInterface>(definingOp);
447456 if (!tilingInterfaceProducer || isa<tensor::PadOp>(definingOp) ||
448457 producers.count (tilingInterfaceProducer)) {
449458 continue ;
0 commit comments