@@ -324,7 +324,27 @@ struct LinalgOpTilingInterface
324324// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
325325// ===----------------------------------------------------------------------===//
326326
327- // / External model implementation of PartialReductionInterface for LinalgOps.
327+ // / Return an AffineMap for a partial result for the given result number,
328+ // / assuming the partial tiling strategy is outer-reduction loop +
329+ // / inner-parallel tile. The returned AffineMap can be used as the replacement
330+ // / AffineMap for the inner-parallel tile linalg op for the given result number.
331+ // /
332+ // / The new AffineMap is the old AffineMap with reduction dimensions appended
333+ // / at end.
334+ static AffineMap getPartialResultAffineMap (LinalgOp linalgOp,
335+ ArrayRef<int > reductionDims,
336+ unsigned resultNumber) {
337+ AffineMap map =
338+ linalgOp.getMatchingIndexingMap (linalgOp.getDpsInitOperand (resultNumber));
339+ for (int redPos : reductionDims) {
340+ map = map.insertResult (getAffineDimExpr (redPos, linalgOp.getContext ()),
341+ map.getNumResults ());
342+ }
343+ return map;
344+ }
345+
/// External model implementation of PartialReductionInterface for
/// LinalgOps.
328348template <typename LinalgOpTy>
329349struct LinalgOpPartialReductionInterface
330350 : public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +358,24 @@ struct LinalgOpPartialReductionInterface
338358 if (linalgOp.hasPureBufferSemantics ())
339359 return op->emitOpError (" expected operation to have tensor semantics" );
340360
361+ // LinalgOp implements TilingInterface.
362+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation ());
363+ SmallVector<OpFoldResult> shape =
364+ llvm::map_to_vector (tilingInterfaceOp.getIterationDomain (b),
365+ [](Range x) { return x.size ; });
366+
367+ SmallVector<OpFoldResult> tiledShape;
368+ for (auto [tileSize, dimSize] : llvm::zip_equal (sizes, shape)) {
369+ if (isZeroIndex (tileSize)) {
370+ tiledShape.push_back (dimSize);
371+ } else {
372+ tiledShape.push_back (tileSize);
373+ }
374+ }
375+
341376 SmallVector<Value> inits;
342377 for (int initIdx = 0 , e = linalgOp.getNumDpsInits (); initIdx < e;
343378 ++initIdx) {
344- // Insert the new parallel dimension based on the index of the reduction
345- // loops. This could be controlled by user for more flexibility.
346379 SmallVector<Operation *, 4 > combinerOps;
347380 if (!matchReduction (linalgOp.getRegionOutputArgs (), initIdx,
348381 combinerOps) ||
@@ -355,33 +388,19 @@ struct LinalgOpPartialReductionInterface
355388 return op->emitOpError (
356389 " Failed to get an identity value for the reduction operation." );
357390
358- ArrayRef<int64_t > oldShape =
359- linalgOp.getShape (linalgOp.getDpsInitOperand (initIdx));
360-
361- // Calculate the new shape, we insert the new dimensions based on the
362- // index of the reduction dimensions.
363- SmallVector<int64_t > newOutputShape;
364- SmallVector<Value> dynamicDims;
365- int64_t currReductionDims = 0 ;
366- DenseSet<int > reductionDimsSet (reductionDims.begin (),
367- reductionDims.end ());
368- for (int64_t idx :
369- llvm::seq<int64_t >(0 , oldShape.size () + reductionDims.size ())) {
370- if (reductionDimsSet.contains (idx)) {
371- dispatchIndexOpFoldResults (sizes[idx], dynamicDims, newOutputShape);
372- currReductionDims++;
373- continue ;
374- }
375- int64_t oldIdx = idx - currReductionDims;
376- int64_t dim = oldShape[oldIdx];
377- newOutputShape.push_back (dim);
378- if (ShapedType::isDynamic (dim))
379- dynamicDims.push_back (b.create <tensor::DimOp>(
380- loc, linalgOp.getDpsInitOperand (initIdx)->get (), oldIdx));
391+ // Append the new partial result dimensions.
392+ AffineMap partialMap =
393+ getPartialResultAffineMap (linalgOp, reductionDims, initIdx);
394+ SmallVector<OpFoldResult> partialResultShape;
395+ for (AffineExpr dimExpr : partialMap.getResults ()) {
396+ auto dim = cast<AffineDimExpr>(dimExpr);
397+ partialResultShape.push_back (tiledShape[dim.getPosition ()]);
381398 }
382- Value emptyTensor = b.create <tensor::EmptyOp>(
383- loc, newOutputShape,
384- linalgOp.getRegionOutputArgs ()[initIdx].getType (), dynamicDims);
399+
400+ Type elType =
401+ getElementTypeOrSelf (linalgOp->getResult (initIdx).getType ());
402+ Value emptyTensor =
403+ b.create <tensor::EmptyOp>(loc, partialResultShape, elType);
385404 Value constantOp = b.create <arith::ConstantOp>(loc, *identity);
386405 auto identityTensor =
387406 b.create <linalg::FillOp>(loc, constantOp, emptyTensor);
@@ -407,11 +426,7 @@ struct LinalgOpPartialReductionInterface
407426 // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
408427 // this with a for range loop when we have it.
409428 AffineMap newMap =
410- linalgOp.getMatchingIndexingMap (linalgOp.getDpsInitOperand (idx));
411- for (int redPos : reductionDims) {
412- newMap = newMap.insertResult (b.getAffineDimExpr (redPos),
413- newMap.getNumResults ());
414- }
429+ getPartialResultAffineMap (linalgOp, reductionDims, idx);
415430 newInitMaps.push_back (newMap);
416431 }
417432
@@ -476,29 +491,75 @@ struct LinalgOpPartialReductionInterface
476491 Location loc, ValueRange partialReduce,
477492 ArrayRef<int > reductionDims) const {
478493 auto linalgOp = cast<LinalgOp>(op);
479- SmallVector<int64_t > reductionDimsInt64 (reductionDims);
480- auto reduction = b.create <linalg::ReduceOp>(
481- loc, partialReduce, linalgOp.getDpsInits (), reductionDimsInt64,
482- [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
483- int64_t numInits = linalgOp.getNumDpsInits ();
484- SmallVector<Value> yieldedValues;
485- for (int idx : llvm::seq<int >(0 , numInits)) {
494+
495+ // Permute the reduction dims as permuted by the partial result map.
496+
497+ int64_t numInits = linalgOp.getNumDpsInits ();
498+ SmallVector<Operation *> mergeOperations;
499+ SmallVector<Value> replacements;
500+ for (int idx : llvm::seq (numInits)) {
501+ // linalg.reduce's iteration space is the tiled result's iteration space
502+ // (and not the tiled operation's iteration space). To account for this,
503+ // permute the reduction dimensions based on the partial result map of the
504+ // tiled result.
505+ AffineMap partialMap =
506+ getPartialResultAffineMap (linalgOp, reductionDims, idx);
507+ SmallVector<int64_t > partialReductionDims;
508+ for (auto [resultNum, dimExpr] :
509+ llvm::enumerate (partialMap.getResults ())) {
510+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
511+ if (llvm::find (reductionDims, dim) != reductionDims.end ()) {
512+ partialReductionDims.push_back (resultNum);
513+ }
514+ }
515+
516+ Value partialResult = partialReduce[idx];
517+ Value init = linalgOp.getDpsInits ()[idx];
518+
519+ auto reduction = b.create <linalg::ReduceOp>(
520+ loc, partialResult, init, partialReductionDims,
521+ [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
486522 // Get the combiner op.
487523 SmallVector<Operation *, 4 > combinerOps;
488524 matchReduction (linalgOp.getRegionOutputArgs (), idx, combinerOps);
489525 Operation *clonedReductionOp = b.clone (*combinerOps[0 ]);
490526 // Combine the input at idx and output at numInits + idx.
491- clonedReductionOp->setOperand (0 , inputs[idx]);
492- clonedReductionOp->setOperand (1 , inputs[numInits + idx]);
493- // Yield.
494- yieldedValues.push_back (clonedReductionOp->getResult (0 ));
495- }
496- b.create <linalg::YieldOp>(loc, yieldedValues);
497- });
498- return MergeResult{
499- {reduction.getOperation ()},
500- llvm::map_to_vector (reduction->getResults (),
501- [](OpResult r) -> Value { return r; })};
527+ clonedReductionOp->setOperand (0 , inputs[0 ]);
528+ clonedReductionOp->setOperand (1 , inputs[1 ]);
529+ b.create <linalg::YieldOp>(loc, clonedReductionOp->getResult (0 ));
530+ });
531+
532+ mergeOperations.push_back (reduction);
533+ replacements.push_back (reduction->getResult (0 ));
534+ }
535+
536+ return MergeResult{mergeOperations, replacements};
537+ }
538+
539+ LogicalResult getPartialResultTilePosition (
540+ Operation *op, OpBuilder &b, unsigned resultNumber,
541+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
542+ SmallVector<OpFoldResult> &resultOffsets,
543+ SmallVector<OpFoldResult> &resultSizes,
544+ ArrayRef<int > reductionDims) const {
545+ auto linalgOp = cast<LinalgOp>(op);
546+
547+ AffineMap partialMap =
548+ getPartialResultAffineMap (linalgOp, reductionDims, resultNumber);
549+ for (AffineExpr dimExpr : partialMap.getResults ()) {
550+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
551+ resultSizes.push_back (sizes[dim]);
552+
553+ if (llvm::find (reductionDims, dim) != reductionDims.end ()) {
554+ // Reduction dims are reduced, and are always outputed in the same
555+ // place. So use offset 0 for them.
556+ resultOffsets.push_back (b.getIndexAttr (0 ));
557+ } else {
558+ resultOffsets.push_back (offsets[dim]);
559+ }
560+ }
561+
562+ return success ();
502563 }
503564};
504565
0 commit comments