@@ -324,7 +324,20 @@ struct LinalgOpTilingInterface
324324// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
325325// ===----------------------------------------------------------------------===//
326326
327- // / External model implementation of PartialReductionInterface for LinalgOps.
327+ static AffineMap getPartialResultAffineMap (LinalgOp linalgOp,
328+ ArrayRef<int > reductionDims,
329+ unsigned resultNumber) {
330+ AffineMap map =
331+ linalgOp.getMatchingIndexingMap (linalgOp.getDpsInitOperand (resultNumber));
332+ for (int redPos : reductionDims) {
333+ map = map.insertResult (getAffineDimExpr (redPos, linalgOp.getContext ()),
334+ map.getNumResults ());
335+ }
336+ return map;
337+ }
338+
339+ // / External model implementation of PartialReductionInterface for
340+ // / LinalgOps.
328341template <typename LinalgOpTy>
329342struct LinalgOpPartialReductionInterface
330343 : public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +351,24 @@ struct LinalgOpPartialReductionInterface
338351 if (linalgOp.hasPureBufferSemantics ())
339352 return op->emitOpError (" expected operation to have tensor semantics" );
340353
354+ // LinalgOp implements TilingInterface.
355+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation ());
356+ SmallVector<OpFoldResult> shape =
357+ llvm::map_to_vector (tilingInterfaceOp.getIterationDomain (b),
358+ [](Range x) { return x.size ; });
359+
360+ SmallVector<OpFoldResult> tiledShape;
361+ for (auto [tileSize, dimSize] : llvm::zip_equal (sizes, shape)) {
362+ if (isZeroIndex (tileSize)) {
363+ tiledShape.push_back (dimSize);
364+ } else {
365+ tiledShape.push_back (tileSize);
366+ }
367+ }
368+
341369 SmallVector<Value> inits;
342370 for (int initIdx = 0 , e = linalgOp.getNumDpsInits (); initIdx < e;
343371 ++initIdx) {
344- // Insert the new parallel dimension based on the index of the reduction
345- // loops. This could be controlled by user for more flexibility.
346372 SmallVector<Operation *, 4 > combinerOps;
347373 if (!matchReduction (linalgOp.getRegionOutputArgs (), initIdx,
348374 combinerOps) ||
@@ -355,33 +381,19 @@ struct LinalgOpPartialReductionInterface
355381 return op->emitOpError (
356382 " Failed to get an identity value for the reduction operation." );
357383
358- ArrayRef<int64_t > oldShape =
359- linalgOp.getShape (linalgOp.getDpsInitOperand (initIdx));
360-
361- // Calculate the new shape, we insert the new dimensions based on the
362- // index of the reduction dimensions.
363- SmallVector<int64_t > newOutputShape;
364- SmallVector<Value> dynamicDims;
365- int64_t currReductionDims = 0 ;
366- DenseSet<int > reductionDimsSet (reductionDims.begin (),
367- reductionDims.end ());
368- for (int64_t idx :
369- llvm::seq<int64_t >(0 , oldShape.size () + reductionDims.size ())) {
370- if (reductionDimsSet.contains (idx)) {
371- dispatchIndexOpFoldResults (sizes[idx], dynamicDims, newOutputShape);
372- currReductionDims++;
373- continue ;
374- }
375- int64_t oldIdx = idx - currReductionDims;
376- int64_t dim = oldShape[oldIdx];
377- newOutputShape.push_back (dim);
378- if (ShapedType::isDynamic (dim))
379- dynamicDims.push_back (b.create <tensor::DimOp>(
380- loc, linalgOp.getDpsInitOperand (initIdx)->get (), oldIdx));
384+ // Append the new partial result dimensions.
385+ AffineMap partialMap =
386+ getPartialResultAffineMap (linalgOp, reductionDims, initIdx);
387+ SmallVector<OpFoldResult> partialResultShape;
388+ for (AffineExpr dimExpr : partialMap.getResults ()) {
389+ auto dim = cast<AffineDimExpr>(dimExpr);
390+ partialResultShape.push_back (tiledShape[dim.getPosition ()]);
381391 }
382- Value emptyTensor = b.create <tensor::EmptyOp>(
383- loc, newOutputShape,
384- linalgOp.getRegionOutputArgs ()[initIdx].getType (), dynamicDims);
392+
393+ Type elType =
394+ getElementTypeOrSelf (linalgOp->getResult (initIdx).getType ());
395+ Value emptyTensor =
396+ b.create <tensor::EmptyOp>(loc, partialResultShape, elType);
385397 Value constantOp = b.create <arith::ConstantOp>(loc, *identity);
386398 auto identityTensor =
387399 b.create <linalg::FillOp>(loc, constantOp, emptyTensor);
@@ -407,11 +419,7 @@ struct LinalgOpPartialReductionInterface
407419 // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
408420 // this with a for range loop when we have it.
409421 AffineMap newMap =
410- linalgOp.getMatchingIndexingMap (linalgOp.getDpsInitOperand (idx));
411- for (int redPos : reductionDims) {
412- newMap = newMap.insertResult (b.getAffineDimExpr (redPos),
413- newMap.getNumResults ());
414- }
422+ getPartialResultAffineMap (linalgOp, reductionDims, idx);
415423 newInitMaps.push_back (newMap);
416424 }
417425
@@ -476,29 +484,74 @@ struct LinalgOpPartialReductionInterface
476484 Location loc, ValueRange partialReduce,
477485 ArrayRef<int > reductionDims) const {
478486 auto linalgOp = cast<LinalgOp>(op);
479- SmallVector<int64_t > reductionDimsInt64 (reductionDims);
480- auto reduction = b.create <linalg::ReduceOp>(
481- loc, partialReduce, linalgOp.getDpsInits (), reductionDimsInt64,
482- [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
483- int64_t numInits = linalgOp.getNumDpsInits ();
484- SmallVector<Value> yieldedValues;
485- for (int idx : llvm::seq<int >(0 , numInits)) {
487+
488+ // Permute the reduction dims as permuted by the partial result map.
489+
490+ int64_t numInits = linalgOp.getNumDpsInits ();
491+ SmallVector<Operation *> mergeOperations;
492+ SmallVector<Value> replacements;
493+ for (int idx : llvm::seq (numInits)) {
494+ // linalg.reduce's iteration space is the result's iteration space (and
495+ // not the operation's iteration space). To account for this, permute the
496+ // reduction dimensions based on the partial result map.
497+ AffineMap partialMap =
498+ getPartialResultAffineMap (linalgOp, reductionDims, idx);
499+ SmallVector<int64_t > partialReductionDims;
500+ for (auto [resultNum, dimExpr] :
501+ llvm::enumerate (partialMap.getResults ())) {
502+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
503+ if (llvm::find (reductionDims, dim) != reductionDims.end ()) {
504+ partialReductionDims.push_back (resultNum);
505+ }
506+ }
507+
508+ Value partialResult = partialReduce[idx];
509+ Value init = linalgOp.getDpsInits ()[idx];
510+
511+ auto reduction = b.create <linalg::ReduceOp>(
512+ loc, partialResult, init, partialReductionDims,
513+ [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
486514 // Get the combiner op.
487515 SmallVector<Operation *, 4 > combinerOps;
488516 matchReduction (linalgOp.getRegionOutputArgs (), idx, combinerOps);
489517 Operation *clonedReductionOp = b.clone (*combinerOps[0 ]);
490518 // Combine the input at idx and output at numInits + idx.
491- clonedReductionOp->setOperand (0 , inputs[idx]);
492- clonedReductionOp->setOperand (1 , inputs[numInits + idx]);
493- // Yield.
494- yieldedValues.push_back (clonedReductionOp->getResult (0 ));
495- }
496- b.create <linalg::YieldOp>(loc, yieldedValues);
497- });
498- return MergeResult{
499- {reduction.getOperation ()},
500- llvm::map_to_vector (reduction->getResults (),
501- [](OpResult r) -> Value { return r; })};
519+ clonedReductionOp->setOperand (0 , inputs[0 ]);
520+ clonedReductionOp->setOperand (1 , inputs[1 ]);
521+ b.create <linalg::YieldOp>(loc, clonedReductionOp->getResult (0 ));
522+ });
523+
524+ mergeOperations.push_back (reduction);
525+ replacements.push_back (reduction->getResult (0 ));
526+ }
527+
528+ return MergeResult{mergeOperations, replacements};
529+ }
530+
531+ LogicalResult getPartialResultTilePosition (
532+ Operation *op, OpBuilder &b, unsigned resultNumber,
533+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
534+ SmallVector<OpFoldResult> &resultOffsets,
535+ SmallVector<OpFoldResult> &resultSizes,
536+ ArrayRef<int > reductionDims) const {
537+ auto linalgOp = cast<LinalgOp>(op);
538+
539+ AffineMap partialMap =
540+ getPartialResultAffineMap (linalgOp, reductionDims, resultNumber);
541+ for (AffineExpr dimExpr : partialMap.getResults ()) {
542+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
543+ resultSizes.push_back (sizes[dim]);
544+
545+ if (llvm::find (reductionDims, dim) != reductionDims.end ()) {
546+ // Reduction dims are reduced, and are always outputed in the same
547+ // place. So use offset 0 for them.
548+ resultOffsets.push_back (b.getIndexAttr (0 ));
549+ } else {
550+ resultOffsets.push_back (offsets[dim]);
551+ }
552+ }
553+
554+ return success ();
502555 }
503556};
504557
0 commit comments