@@ -328,6 +328,17 @@ struct LinalgOpTilingInterface
328328// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
329329// ===----------------------------------------------------------------------===//
330330
331+ // / In a given set vector, get the position of a particular element.
332+ std::optional<int > getPositionIn (const llvm::SetVector<unsigned > &reductionDims,
333+ unsigned value) {
334+ for (auto [index, reductionDim] : llvm::enumerate (reductionDims)) {
335+ if (reductionDim == value) {
336+ return index;
337+ }
338+ }
339+ return std::nullopt ;
340+ }
341+
331342// / Return an AffineMaps to use for the `outs` operands of the linalg op
332343// / generated for partial results. The new AffineMap is the AffineMap of the
333344// / untiled op with reduction dimensions appended at end in order in which they
@@ -348,28 +359,79 @@ getPartialResultAffineMaps(LinalgOp linalgOp,
348359 return partialReductionMaps;
349360}
350361
351- // / Return the slice of the `initValue` to use as input to the partial reduction
352- // / op generated.
353- static Operation *getInitSliceForOuterReduction (
354- OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
362+ struct InitSliceInfo {
363+ SmallVector<int64_t > resultShape;
364+ SmallVector<OpFoldResult> offsets;
365+ SmallVector<OpFoldResult> sizes;
366+ SmallVector<OpFoldResult> strides;
367+ };
368+
369+ // / Return the result type, offsets, sizes and strides of the slice of the
370+ // / `initValue` to use as input to the partial reduction op generated with
371+ // / outer reduction strategy.
372+ static InitSliceInfo getInitSliceInfoForOuterReduction (
373+ MLIRContext *context, ArrayRef<OpFoldResult> offsets,
355374 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned > &reductionDims,
356375 AffineMap partialReductionMap) {
357376 int64_t initRank = partialReductionMap.getNumResults ();
358377 SmallVector<OpFoldResult> initOffsets, initSizes;
359- SmallVector<OpFoldResult> initStrides (initRank, b.getIndexAttr (1 ));
378+ Attribute zero = IntegerAttr::get (IndexType::get (context), 0 );
379+ Attribute one = IntegerAttr::get (IndexType::get (context), 1 );
380+ SmallVector<OpFoldResult> initStrides (initRank, one);
360381 for (AffineExpr dimExpr : partialReductionMap.getResults ()) {
361382 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
362383 if (reductionDims.contains (dim)) {
363- initOffsets.push_back (b. getIndexAttr ( 0 ) );
384+ initOffsets.push_back (zero );
364385 } else {
365386 initOffsets.push_back (offsets[dim]);
366387 }
367388 initSizes.push_back (sizes[dim]);
368389 }
369- // TODO: Use SubsetExtractOpInterface here once available.
370- auto extractSlice = b.create <tensor::ExtractSliceOp>(
371- loc, initValue, initOffsets, initSizes, initStrides);
372- return extractSlice;
390+ SmallVector<int64_t > resultShape;
391+ std::tie (resultShape, std::ignore) = decomposeMixedValues (initSizes);
392+ return {resultShape, initOffsets, initSizes, initStrides};
393+ }
394+
395+ // / Return the result type, offsets, sizes and strides of the slice of the
396+ // / `initValue` to use as input to the partial reduction op generated with
397+ // / outer parallel strategy.
398+ static InitSliceInfo getInitSliceInfoForOuterParallel (
399+ MLIRContext *context, ValueRange ivs, ArrayRef<OpFoldResult> offsets,
400+ ArrayRef<OpFoldResult> sizes, const SetVector<unsigned > &reductionDims,
401+ AffineMap partialReductionMap) {
402+ int64_t initRank = partialReductionMap.getNumResults ();
403+ SmallVector<OpFoldResult> initOffsets, initSizes;
404+ Attribute one = IntegerAttr::get (IndexType::get (context), 1 );
405+ SmallVector<OpFoldResult> initStrides (initRank, one);
406+ SmallVector<OpFoldResult> resultShape;
407+ for (AffineExpr dimExpr : partialReductionMap.getResults ()) {
408+ unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
409+ if (std::optional<int > dimPos = getPositionIn (reductionDims, dim)) {
410+ initOffsets.push_back (ivs[dimPos.value ()]);
411+ initSizes.push_back (one);
412+ } else {
413+ initOffsets.push_back (offsets[dim]);
414+ initSizes.push_back (sizes[dim]);
415+ resultShape.push_back (sizes[dim]);
416+ }
417+ }
418+ SmallVector<int64_t > staticShapes;
419+ std::tie (staticShapes, std::ignore) = decomposeMixedValues (resultShape);
420+ return {staticShapes, initOffsets, initSizes, initStrides};
421+ }
422+
423+ static InitSliceInfo getInitSliceInfo (
424+ MLIRContext *context, ReductionTilingStrategy strategy, ValueRange ivs,
425+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
426+ const SetVector<unsigned > &reductionDims, AffineMap partialReductionMap) {
427+ if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
428+ return getInitSliceInfoForOuterReduction (
429+ context, offsets, sizes, reductionDims, partialReductionMap);
430+ }
431+ assert (strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
432+ " unexpected ReductionTilingStrategy" );
433+ return getInitSliceInfoForOuterParallel (context, ivs, offsets, sizes,
434+ reductionDims, partialReductionMap);
373435}
374436
375437// / External model implementation of PartialReductionInterface for
@@ -439,18 +501,11 @@ struct LinalgOpPartialReductionInterface
439501 return inits;
440502 }
441503
442- FailureOr<TilingResult>
443- tileToPartialReduction (Operation *op, OpBuilder &b, Location loc,
444- ReductionTilingStrategy tilingStrategy,
445- ValueRange init, ArrayRef<OpFoldResult> offsets,
446- ArrayRef<OpFoldResult> sizes,
447- const SetVector<unsigned > &reductionDims) const {
448- if (tilingStrategy !=
449- ReductionTilingStrategy::PartialReductionOuterReduction) {
450- // TODO: Add support for `PartialReductionOuterParallel` strategy.
451- return op->emitOpError (" unsupported partial reduction tiling with "
452- " `PartialReductionOuterParallel` strategy" );
453- }
504+ FailureOr<TilingResult> tileToPartialReduction (
505+ Operation *op, OpBuilder &b, Location loc,
506+ ReductionTilingStrategy tilingStrategy, ValueRange init, ValueRange ivs,
507+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
508+ const SetVector<unsigned > &reductionDims) const {
454509 OpBuilder::InsertionGuard guard (b);
455510 auto linalgOp = cast<LinalgOp>(op);
456511
@@ -459,7 +514,16 @@ struct LinalgOpPartialReductionInterface
459514
460515 // Step 1. Extend init maps to have reduction dimension dims, since we
461516 // are converting them to parallel dimensions.
462- SmallVector<AffineMap> newInitMaps = partialReductionMaps;
517+ SmallVector<AffineMap> newInitMaps;
518+ if (tilingStrategy ==
519+ ReductionTilingStrategy::PartialReductionOuterReduction) {
520+ newInitMaps = llvm::to_vector (partialReductionMaps);
521+ } else {
522+ newInitMaps = llvm::map_to_vector (
523+ linalgOp.getDpsInitsMutable (), [&](OpOperand &opOperand) {
524+ return linalgOp.getMatchingIndexingMap (&opOperand);
525+ });
526+ }
463527
464528 // Step 2a: Extract a slice of the input operands.
465529 SmallVector<Value> tiledInputs = makeTiledShapes (
@@ -473,10 +537,17 @@ struct LinalgOpPartialReductionInterface
473537 SmallVector<Value, 1 > tiledInits;
474538 for (auto [partialReductionMap, valueToTile] :
475539 llvm::zip_equal (partialReductionMaps, init)) {
476- Operation *sliceOp =
477- getInitSliceForOuterReduction (b, loc, valueToTile, offsets, sizes,
478- reductionDims, partialReductionMap);
479- tiledInits.push_back (sliceOp->getResult (0 ));
540+ InitSliceInfo sliceInfo =
541+ getInitSliceInfo (b.getContext (), tilingStrategy, ivs, offsets, sizes,
542+ reductionDims, partialReductionMap);
543+ auto valueToTileType = cast<RankedTensorType>(valueToTile.getType ());
544+ RankedTensorType sliceResultType = RankedTensorType::get (
545+ sliceInfo.resultShape , valueToTileType.getElementType (),
546+ valueToTileType.getEncoding ());
547+ auto sliceOp = b.create <tensor::ExtractSliceOp>(
548+ loc, sliceResultType, valueToTile, sliceInfo.offsets , sliceInfo.sizes ,
549+ sliceInfo.strides );
550+ tiledInits.push_back (sliceOp.getResult ());
480551 generatedSlices.push_back (sliceOp);
481552 }
482553
@@ -491,19 +562,31 @@ struct LinalgOpPartialReductionInterface
491562 // Step 3. Change the reduction dim iterator types.
492563 SmallVector<utils::IteratorType> newIteratorTypes =
493564 linalgOp.getIteratorTypesArray ();
494- for (int dim : reductionDims)
495- newIteratorTypes[dim] = utils::IteratorType::parallel;
565+ if (tilingStrategy ==
566+ ReductionTilingStrategy::PartialReductionOuterReduction) {
567+ for (int dim : reductionDims)
568+ newIteratorTypes[dim] = utils::IteratorType::parallel;
569+ }
496570
497571 // Step 4. Create the new generic op.
572+ Operation *partialReductionOp;
498573 auto resultTypes = ValueRange (tiledInits).getTypes ();
499- auto genericOp = b.create <GenericOp>(loc, resultTypes, tiledInputs,
500- tiledInits, newMaps, newIteratorTypes);
501- IRMapping mapping;
502- op->getRegion (0 ).cloneInto (&genericOp.getRegion (),
503- genericOp.getRegion ().begin (), mapping);
574+ if (tilingStrategy ==
575+ ReductionTilingStrategy::PartialReductionOuterReduction) {
576+ auto genericOp = b.create <GenericOp>(
577+ loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes);
578+ IRMapping mapping;
579+ op->getRegion (0 ).cloneInto (&genericOp.getRegion (),
580+ genericOp.getRegion ().begin (), mapping);
581+ partialReductionOp = genericOp.getOperation ();
582+ } else {
583+ SmallVector<Value> operands = std::move (tiledInputs);
584+ llvm::append_range (operands, tiledInits);
585+ partialReductionOp = mlir::clone (b, op, resultTypes, operands);
586+ }
504587 return TilingResult{
505- {genericOp. getOperation () },
506- llvm::map_to_vector (genericOp ->getResults (),
588+ {partialReductionOp },
589+ llvm::map_to_vector (partialReductionOp ->getResults (),
507590 [](OpResult r) -> Value { return r; }),
508591 generatedSlices};
509592 }
@@ -557,27 +640,19 @@ struct LinalgOpPartialReductionInterface
557640 }
558641
559642 LogicalResult getPartialResultTilePosition (
560- Operation *op, OpBuilder &b, unsigned resultNumber,
561- ArrayRef<OpFoldResult> offsets , ArrayRef<OpFoldResult> sizes ,
562- const SetVector<unsigned > &reductionDims,
643+ Operation *op, OpBuilder &b, unsigned resultNumber, ValueRange ivs,
644+ ReductionTilingStrategy tilingStrategy , ArrayRef<OpFoldResult> offsets ,
645+ ArrayRef<OpFoldResult> sizes, const SetVector<unsigned > &reductionDims,
563646 SmallVector<OpFoldResult> &resultOffsets,
564647 SmallVector<OpFoldResult> &resultSizes) const {
565648 auto linalgOp = cast<LinalgOp>(op);
566649 SmallVector<AffineMap> partialReductionMaps =
567650 getPartialResultAffineMaps (linalgOp, reductionDims);
568-
569- for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults ()) {
570- unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition ();
571- resultSizes.push_back (sizes[dim]);
572-
573- if (llvm::is_contained (reductionDims, dim)) {
574- // Reduction dims are reduced, and are always outputed in the same
575- // place. So use offset 0 for them.
576- resultOffsets.push_back (b.getIndexAttr (0 ));
577- } else {
578- resultOffsets.push_back (offsets[dim]);
579- }
580- }
651+ InitSliceInfo sliceInfo =
652+ getInitSliceInfo (b.getContext (), tilingStrategy, ivs, offsets, sizes,
653+ reductionDims, partialReductionMaps[resultNumber]);
654+ std::swap (resultOffsets, sliceInfo.offsets );
655+ std::swap (resultSizes, sliceInfo.sizes );
581656
582657 return success ();
583658 }
0 commit comments