From 030c1b2366845b240a103202a4c3b677c0bba10d Mon Sep 17 00:00:00 2001
From: MaheshRavishankar
Date: Sun, 8 Jun 2025 16:22:00 -0700
Subject: [PATCH] [mlir][PartialReductionTilingInterface] Add support for
 `ReductionTilingStrategy::PartialReductionOuterParallel` in `tileUsingSCF`.

Following up from https://github.com/llvm/llvm-project/pull/143467, this PR
adds support for `ReductionTilingStrategy::PartialReductionOuterParallel` to
`tileUsingSCF`. The implementation of `PartialReductionTilingInterface` for
`Linalg` ops has been updated to support this strategy as well. This brings
`tileUsingSCF` on par with `linalg::tileReductionUsingForall`, which will be
deprecated subsequently.

Changes summary:
- `PartialReductionTilingInterface` changes:
  - The `tileToPartialReduction` method now takes the induction variables of
    the generated tile loops. This is needed to keep the generated code
    similar to `linalg::tileReductionUsingForall`, specifically to create a
    simplified access for slicing the intermediate partial results tensor
    when tiled in `num_threads` mode.
  - The `getPartialResultTilePosition` method needs the induction variables
    of the generated tile loops for the same reason as above, and also needs
    the `tilingStrategy` to be passed in to generate correct code.

The tests in `transform-tile-reduction.mlir` that exercised
`linalg::tileReductionUsingForall` have been moved over to test
`scf::tileUsingSCF` with the
`ReductionTilingStrategy::PartialReductionOuterParallel` strategy. Some of the
tests that performed a further cyclic distribution of the code produced by
tiling have been removed; those seem like two separate transformations that
were merged into one. Ideally that distribution would happen when resolving
the `scf.forall`, rather than during tiling.

Signed-off-by: MaheshRavishankar
---
 .../Linalg/TransformOps/LinalgTransformOps.td |   6 +-
 .../mlir/Dialect/Utils/StaticValueUtils.h     |   2 +-
 .../mlir/Interfaces/TilingInterface.td        |  28 +-
 .../TransformOps/LinalgTransformOps.cpp       |  37 ++-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 196 ++++++++----
 .../SCF/Transforms/TileUsingInterface.cpp     | 229 ++++++++++----
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  14 +-
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp   |   2 +-
 .../Linalg/transform-tile-reduction.mlir      | 286 ++++++++++--------
 9 files changed, 535 insertions(+), 265 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 38c8734c47381..9d6ce653e285c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2019,6 +2019,7 @@ def TileReductionUsingForallOp :
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
+                   DefaultValuedAttr:$reduction_dims,
                    DefaultValuedAttr:$num_threads,
                    DefaultValuedAttr:$tile_sizes,
                    OptionalAttr:$mapping);
@@ -2036,10 +2037,11 @@
   let assemblyFormat = [{
     $target
+    (`reduction_dims` `=` $reduction_dims^)?
     `by`
     (`num_threads` `=` $num_threads^)?
-    (`,` `tile_sizes` `=` $tile_sizes^)?
-    (`,` `mapping` `=` $mapping^)?
+    (`tile_sizes` `=` $tile_sizes^)?
+    (`mapping` `=` $mapping^)?
attr-dict `:` functional-type(operands, results) }]; diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index b37fb55b67931..77c376fb9973a 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -156,7 +156,7 @@ SmallVector getMixedValues(ArrayRef staticValues, /// corresponding pair of arrays. This is the inverse function of /// `getMixedValues`. std::pair, SmallVector> -decomposeMixedValues(const SmallVectorImpl &mixedValues); +decomposeMixedValues(ArrayRef mixedValues); /// Helper to sort `values` according to matching `keys`. SmallVector diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 43a27e1cb6cdf..0de37338c95e4 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -367,15 +367,20 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface", [TilingInterface]> { let description = [{ Interface for allowing operations to expose information needed to - tile reductions using partial reduction followed by merge. This is - complementary to TilingInterface to tile reductions. + tile reductions using partial reduction followed by merge. This + extends the `TilingInterface` to allow splitting a reduction + dimension into a parallel dimension and reduction dimension. + The materialized inter-tile loop could either be the reduction dimension + (i.e. `ReductionTilingStrategy::PartialReductionOuterReduction`) or + the parallel dimension (i.e + `ReductionTilingStrategy::PartialReductionOuterReduction`). }]; let cppNamespace = "::mlir"; let methods = [ InterfaceMethod< /*desc=*/[{ Method to generate a tensor initalized with the identity value of the - operation reduction. The tensor shape is equal to operation result + reduction operator. The tensor shape is equal to operation result shape with new dimension for each non zero tile size. }], /*retType=*/"::mlir::FailureOr>", @@ -383,7 +388,7 @@ def PartialReductionOpInterface : /*args=*/(ins "::mlir::OpBuilder &":$b, "Location":$loc, - "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$tileSizes, "const ::mlir::SetVector &":$reductionDims), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -396,6 +401,11 @@ def PartialReductionOpInterface : reduction dimension are converted to parallel dimensions with a size less or equal to the tile size. This is meant to be used with `mergeReductions` method which will combine the partial reductions. + The method recieves the `offset` and `sizes` for all iteration space + dimensions, as well as the iteration number of the tiled reduction + dimensions (which is the induction variable of the inter-tile loop + for the reduction dimension divided by the step of the loop) in + `splitReductionIvs`. }], /*retType=*/"::mlir::FailureOr", /*methodName=*/"tileToPartialReduction", @@ -406,7 +416,8 @@ def PartialReductionOpInterface : "ValueRange":$init, "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets, "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes, - "const ::llvm::SetVector &":$reductionDims), + "const ::llvm::SetVector &":$reductionDims, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs), /*methodBody=*/"", /*defaultImplementation=*/[{ return failure(); @@ -436,15 +447,22 @@ def PartialReductionOpInterface : the tiled operation. 
This is same as TilingInterface:::getResultTilePosition, but determines the result tile position for partial reduction. + The method recieves the `offset` and `sizes` for all iteration space + dimensions, as well as the iteration number of the tiled reduction + dimensions (which is the induction variable of the inter-tile loop + for the reduction dimension divided by the tile size specified) in + `splitReductionIvs`. }], /*retType=*/"::llvm::LogicalResult", /*methodName=*/"getPartialResultTilePosition", /*args=*/(ins "::mlir::OpBuilder &":$b, "unsigned":$resultNumber, + "ReductionTilingStrategy":$tilingStrategy, "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets, "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes, "const ::mlir::SetVector &":$reductionDims, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs, "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets, "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes), /*methodBody=*/"", diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index f2b7b34256847..2355edea2df6c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3022,6 +3022,7 @@ void transform::TileReductionUsingForallOp::build( build(builder, result, /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, /*target=*/target, + /*reduction_dims=*/{}, /*num_threads=*/staticNumThreadsAttr, /*tile_sizes=*/staticTileSizesAttr, /*mapping=*/mapping); @@ -3036,23 +3037,45 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); SmallVector tileSizes = getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())); - FailureOr result = - linalg::tileReductionUsingForall( - rewriter, cast(target.getOperation()), - numThreads, tileSizes, getMapping()); + + scf::SCFTilingOptions options; + options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); + options.setReductionTilingStrategy( + ReductionTilingStrategy::PartialReductionOuterParallel); + if (!getNumThreads().empty()) { + options.setNumThreads(numThreads); + } else { + options.setTileSizes(tileSizes); + } + if (auto mapping = getMapping()) { + options.setMapping(mapping.value().getValue()); + } + SmallVector reductionDims = + extractFromIntegerArrayAttr(getReductionDims()); + if (reductionDims.empty()) { + for (auto [idx, iteratorType] : + llvm::enumerate(target.getIteratorTypesArray())) { + if (iteratorType == utils::IteratorType::reduction) + reductionDims.push_back(idx); + } + } + options.setReductionDims(reductionDims); + FailureOr result = scf::tileUsingSCF( + rewriter, cast(target.getOperation()), options); if (failed(result)) { auto diag = emitSilenceableError() << "could not tile reduction"; - diag.attachNote(target.getLoc()) << "target operation"; return diag; } + rewriter.replaceOp(target, result->replacements); + for (Value initValue : result->initialValues) results.push_back(initValue.getDefiningOp()); - for (auto parallelTiledOp : result->parallelTiledOps) + for (auto parallelTiledOp : result->tiledOps) results.push_back(parallelTiledOp); for (auto mergeOp : result->mergeOps) results.push_back(mergeOp); - results.push_back(result->loops); + results.push_back(result->loops.front()); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp 
b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index f649bc49a8fbd..19d484a3bb701 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -328,6 +328,17 @@ struct LinalgOpTilingInterface // External Model for implementing `PartialReductionInterface` for `LinalgOp`s. //===----------------------------------------------------------------------===// +/// In a given set vector, get the position of a particular element. +std::optional getPositionIn(const llvm::SetVector &reductionDims, + unsigned value) { + for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) { + if (reductionDim == value) { + return index; + } + } + return std::nullopt; +} + /// Return an AffineMaps to use for the `outs` operands of the linalg op /// generated for partial results. The new AffineMap is the AffineMap of the /// untiled op with reduction dimensions appended at end in order in which they @@ -348,28 +359,86 @@ getPartialResultAffineMaps(LinalgOp linalgOp, return partialReductionMaps; } -/// Return the slice of the `initValue` to use as input to the partial reduction -/// op generated. -static Operation *getInitSliceForOuterReduction( - OpBuilder &b, Location loc, Value initValue, ArrayRef offsets, +struct InitSliceInfo { + SmallVector resultShape; + SmallVector offsets; + SmallVector sizes; + SmallVector strides; +}; + +/// Return the result shape, offsets, sizes and strides of the slice of the +/// `initValue` to use as the destination of the partial reduction op generated +/// with outer reduction strategy. +static InitSliceInfo getInitSliceInfoForOuterReduction( + MLIRContext *context, ArrayRef offsets, ArrayRef sizes, const SetVector &reductionDims, - AffineMap partialReductionMap) { + ArrayRef splitReductionIvs, AffineMap partialReductionMap) { int64_t initRank = partialReductionMap.getNumResults(); SmallVector initOffsets, initSizes; - SmallVector initStrides(initRank, b.getIndexAttr(1)); + Attribute zero = IntegerAttr::get(IndexType::get(context), 0); + Attribute one = IntegerAttr::get(IndexType::get(context), 1); + SmallVector initStrides(initRank, one); for (AffineExpr dimExpr : partialReductionMap.getResults()) { unsigned dim = cast(dimExpr).getPosition(); if (reductionDims.contains(dim)) { - initOffsets.push_back(b.getIndexAttr(0)); + initOffsets.push_back(zero); } else { initOffsets.push_back(offsets[dim]); } initSizes.push_back(sizes[dim]); } - // TODO: Use SubsetExtractOpInterface here once available. - auto extractSlice = b.create( - loc, initValue, initOffsets, initSizes, initStrides); - return extractSlice; + SmallVector resultShape; + std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes); + return {resultShape, initOffsets, initSizes, initStrides}; +} + +/// Return the result shape, offsets, sizes and strides of the slice of the +/// `initValue` to use as destination of the partial reduction op generated with +/// outer parallel strategy. 
+static InitSliceInfo getInitSliceInfoForOuterParallel( + MLIRContext *context, ArrayRef offsets, + ArrayRef sizes, const SetVector &reductionDims, + ArrayRef splitReductionIvs, AffineMap partialReductionMap) { + int64_t initRank = partialReductionMap.getNumResults(); + SmallVector initOffsets, initSizes; + Attribute one = IntegerAttr::get(IndexType::get(context), 1); + SmallVector initStrides(initRank, one); + SmallVector resultShape; + for (AffineExpr dimExpr : partialReductionMap.getResults()) { + unsigned dim = cast(dimExpr).getPosition(); + if (std::optional dimPos = getPositionIn(reductionDims, dim)) { + initOffsets.push_back(splitReductionIvs[dimPos.value()]); + initSizes.push_back(one); + } else { + initOffsets.push_back(offsets[dim]); + initSizes.push_back(sizes[dim]); + resultShape.push_back(sizes[dim]); + } + } + SmallVector staticShapes; + std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape); + return {staticShapes, initOffsets, initSizes, initStrides}; +} + +/// Return the result shape, offsets, sizes and strides of the slice of the +/// `initValue` to use as destination of the partial reduction op. +static InitSliceInfo getInitSliceInfo(MLIRContext *context, + ReductionTilingStrategy strategy, + ArrayRef offsets, + ArrayRef sizes, + const SetVector &reductionDims, + ArrayRef splitReductionIvs, + AffineMap partialReductionMap) { + if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) { + return getInitSliceInfoForOuterReduction(context, offsets, sizes, + reductionDims, splitReductionIvs, + partialReductionMap); + } + assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel && + "unexpected ReductionTilingStrategy"); + return getInitSliceInfoForOuterParallel(context, offsets, sizes, + reductionDims, splitReductionIvs, + partialReductionMap); } /// External model implementation of PartialReductionInterface for @@ -390,21 +459,6 @@ struct LinalgOpPartialReductionInterface SmallVector partialResultMaps = getPartialResultAffineMaps(linalgOp, reductionDims); - // LinalgOp implements TilingInterface. - auto tilingInterfaceOp = cast(linalgOp.getOperation()); - SmallVector shape = - llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b), - [](Range x) { return x.size; }); - - SmallVector tiledShape; - for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) { - if (isZeroInteger(tileSize)) { - tiledShape.push_back(dimSize); - } else { - tiledShape.push_back(tileSize); - } - } - SmallVector inits; for (auto [initIdx, result, partialMap] : llvm::enumerate(linalgOp->getResults(), partialResultMaps)) { @@ -424,7 +478,7 @@ struct LinalgOpPartialReductionInterface SmallVector partialResultShape; for (AffineExpr dimExpr : partialMap.getResults()) { auto dim = cast(dimExpr); - partialResultShape.push_back(tiledShape[dim.getPosition()]); + partialResultShape.push_back(sizes[dim.getPosition()]); } Type elType = getElementTypeOrSelf(result.getType()); @@ -444,13 +498,8 @@ struct LinalgOpPartialReductionInterface ReductionTilingStrategy tilingStrategy, ValueRange init, ArrayRef offsets, ArrayRef sizes, - const SetVector &reductionDims) const { - if (tilingStrategy != - ReductionTilingStrategy::PartialReductionOuterReduction) { - // TODO: Add support for `PartialReductionOuterParallel` strategy. 
- return op->emitOpError("unsupported partial reduction tiling with " - "`PartialReductionOuterParallel` strategy"); - } + const SetVector &reductionDims, + ArrayRef splitReductionIvs) const { OpBuilder::InsertionGuard guard(b); auto linalgOp = cast(op); @@ -459,7 +508,16 @@ struct LinalgOpPartialReductionInterface // Step 1. Extend init maps to have reduction dimension dims, since we // are converting them to parallel dimensions. - SmallVector newInitMaps = partialReductionMaps; + SmallVector newInitMaps; + if (tilingStrategy == + ReductionTilingStrategy::PartialReductionOuterReduction) { + newInitMaps = llvm::to_vector(partialReductionMaps); + } else { + newInitMaps = llvm::map_to_vector( + linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) { + return linalgOp.getMatchingIndexingMap(&opOperand); + }); + } // Step 2a: Extract a slice of the input operands. SmallVector tiledInputs = makeTiledShapes( @@ -473,10 +531,17 @@ struct LinalgOpPartialReductionInterface SmallVector tiledInits; for (auto [partialReductionMap, valueToTile] : llvm::zip_equal(partialReductionMaps, init)) { - Operation *sliceOp = - getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes, - reductionDims, partialReductionMap); - tiledInits.push_back(sliceOp->getResult(0)); + InitSliceInfo sliceInfo = getInitSliceInfo( + b.getContext(), tilingStrategy, offsets, sizes, reductionDims, + splitReductionIvs, partialReductionMap); + auto valueToTileType = cast(valueToTile.getType()); + RankedTensorType sliceResultType = RankedTensorType::get( + sliceInfo.resultShape, valueToTileType.getElementType(), + valueToTileType.getEncoding()); + auto sliceOp = b.create( + loc, sliceResultType, valueToTile, sliceInfo.offsets, sliceInfo.sizes, + sliceInfo.strides); + tiledInits.push_back(sliceOp.getResult()); generatedSlices.push_back(sliceOp); } @@ -491,19 +556,31 @@ struct LinalgOpPartialReductionInterface // Step 3. Change the reduction dim iterator types. SmallVector newIteratorTypes = linalgOp.getIteratorTypesArray(); - for (int dim : reductionDims) - newIteratorTypes[dim] = utils::IteratorType::parallel; + if (tilingStrategy == + ReductionTilingStrategy::PartialReductionOuterReduction) { + for (int dim : reductionDims) + newIteratorTypes[dim] = utils::IteratorType::parallel; + } // Step 4. Create the new generic op. 
+ Operation *partialReductionOp; auto resultTypes = ValueRange(tiledInits).getTypes(); - auto genericOp = b.create(loc, resultTypes, tiledInputs, - tiledInits, newMaps, newIteratorTypes); - IRMapping mapping; - op->getRegion(0).cloneInto(&genericOp.getRegion(), - genericOp.getRegion().begin(), mapping); + if (tilingStrategy == + ReductionTilingStrategy::PartialReductionOuterReduction) { + auto genericOp = b.create( + loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes); + IRMapping mapping; + op->getRegion(0).cloneInto(&genericOp.getRegion(), + genericOp.getRegion().begin(), mapping); + partialReductionOp = genericOp.getOperation(); + } else { + SmallVector operands = std::move(tiledInputs); + llvm::append_range(operands, tiledInits); + partialReductionOp = mlir::clone(b, op, resultTypes, operands); + } return TilingResult{ - {genericOp.getOperation()}, - llvm::map_to_vector(genericOp->getResults(), + {partialReductionOp}, + llvm::map_to_vector(partialReductionOp->getResults(), [](OpResult r) -> Value { return r; }), generatedSlices}; } @@ -558,26 +635,19 @@ struct LinalgOpPartialReductionInterface LogicalResult getPartialResultTilePosition( Operation *op, OpBuilder &b, unsigned resultNumber, - ArrayRef offsets, ArrayRef sizes, - const SetVector &reductionDims, + ReductionTilingStrategy tilingStrategy, ArrayRef offsets, + ArrayRef sizes, const SetVector &reductionDims, + ArrayRef splitReductionIvs, SmallVector &resultOffsets, SmallVector &resultSizes) const { auto linalgOp = cast(op); SmallVector partialReductionMaps = getPartialResultAffineMaps(linalgOp, reductionDims); - - for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) { - unsigned dim = cast(dimExpr).getPosition(); - resultSizes.push_back(sizes[dim]); - - if (llvm::is_contained(reductionDims, dim)) { - // Reduction dims are reduced, and are always outputed in the same - // place. So use offset 0 for them. 
- resultOffsets.push_back(b.getIndexAttr(0)); - } else { - resultOffsets.push_back(offsets[dim]); - } - } + InitSliceInfo sliceInfo = getInitSliceInfo( + b.getContext(), tilingStrategy, offsets, sizes, reductionDims, + splitReductionIvs, partialReductionMaps[resultNumber]); + std::swap(resultOffsets, sliceInfo.offsets); + std::swap(resultSizes, sliceInfo.sizes); return success(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index e7c076024e67b..ddcae8481a5b4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -166,12 +166,11 @@ static LogicalResult checkTileSizes(TilingInterface op, assert((numThreads.empty() || (numThreads.size() == iterators.size())) && "when specified, expected number of threads to use for each loop"); - bool isParallelTiling = false, isReductionTiling = false; + bool isParallelTiling = false; for (auto [index, iterator, tileSize] : llvm::enumerate(iterators, tileSizes)) { if (!isConstantIntValue(tileSize, 0)) { isParallelTiling |= iterator == utils::IteratorType::parallel; - isReductionTiling |= iterator == utils::IteratorType::reduction; } if (loopType == scf::SCFTilingOptions::LoopType::ForallOp && @@ -199,15 +198,29 @@ static LogicalResult checkTileSizes(TilingInterface op, } } - if (isParallelTiling && isReductionTiling && - reductionStrategy != ReductionTilingStrategy::FullReduction) { - return op->emitOpError( - "combined parallel and reduction tiling is not supported with partial " - "reduction tiling strategies"); + if (reductionStrategy != ReductionTilingStrategy::FullReduction) { + if (isParallelTiling) { + return op->emitOpError("tiling parallel dimensions is not supported with " + "partial reduction tiling strategies"); + } } return success(); } +/// Get the reduction dims that are tiled. This accounts for reduction dims +/// that are specified as tiled, but the tile size is 0. +static SetVector +getSanitizedReductionDims(ArrayRef tileSizes, + const scf::SCFTilingOptions &options) { + SetVector reductionDims; + for (auto dim : options.reductionDims) { + if (isConstantIntValue(tileSizes[dim], 0)) + continue; + reductionDims.insert(dim); + } + return reductionDims; +} + /// Check if `stride` evenly divides the trip count `size - offset`. static bool tileDividesIterationDomain(Range loopRange) { std::optional offsetAsInt = getConstantIntValue(loopRange.offset); @@ -264,10 +277,12 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, /// `offset`s and `size`s of the tile of the iteration space that the /// innermost loop body of the generated tiled loops corresponds to. 
static std::tuple, SmallVector> -getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, +getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, + ReductionTilingStrategy strategy, ValueRange ivs, ArrayRef iterationDomain, ArrayRef tileSizes, - ArrayRef numThreads) { + ArrayRef numThreads, + const llvm::SetVector &reductionDims) { SmallVector offsets, sizes; int materializedLoopNum = 0; @@ -279,8 +294,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, offsetExpr = d0 + d1 * s0; residualTileSizeExpr = s1 - (d0 + d1 * s0); - for (auto [nt, tileSize, loopRange] : - llvm::zip_equal(numThreads, tileSizes, iterationDomain)) { + for (auto [index, nt, tileSize, loopRange] : + llvm::enumerate(numThreads, tileSizes, iterationDomain)) { // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. @@ -564,9 +579,10 @@ static LogicalResult generateLoopNestUsingForallOp( /// - `loops` is an in-out parameter into which the generated loops are /// populated. static LogicalResult generateLoopNest( - RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, - ArrayRef loopRanges, ArrayRef tileSizes, - ArrayRef numThreads, ValueRange destinationTensors, + RewriterBase &rewriter, Location loc, + scf::SCFTilingOptions::LoopType loopType, ArrayRef loopRanges, + ArrayRef tileSizes, ArrayRef numThreads, + ValueRange destinationTensors, ArrayRef mappingVector, YieldTiledValuesFn tiledBodyFn, SmallVector &loops) { // If the tile sizes are all zero, no loops are generated. Just call the // callback function to handle untiled case. @@ -576,25 +592,26 @@ static LogicalResult generateLoopNest( return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, tiledResults, resultOffsets, resultSizes); } - if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { + if (loopType == scf::SCFTilingOptions::LoopType::ForOp) { return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, destinationTensors, tiledBodyFn, loops); } - if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { + if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) { return generateLoopNestUsingForallOp( - rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector, + rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector, destinationTensors, tiledBodyFn, loops); } return rewriter.notifyMatchFailure(loc, "unhandled loop type"); } -static FailureOr> -createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, - ArrayRef tileSizes, - const scf::SCFTilingOptions &options) { +static FailureOr> createInitialTensorsForTiling( + RewriterBase &rewriter, TilingInterface op, + ReductionTilingStrategy reductionStrategy, ArrayRef iterationDomain, + ArrayRef numThreads, ArrayRef tileSizes, + const SetVector &reductionDims) { SmallVector initTensors; Location loc = op->getLoc(); - if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) { + if (reductionStrategy == ReductionTilingStrategy::FullReduction) { if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors))) return failure(); return initTensors; @@ -602,20 +619,94 @@ createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, auto redOp = dyn_cast(op.getOperation()); if (!redOp) { - return rewriter.notifyMatchFailure( - op, "PartialReductionOuterReduction tiling strategy is only supported" - "for operations implementing PartialReductionOpInterface"); + return op->emitOpError( + 
"PartialReductionOuterReduction tiling strategy is only supported for " + "operations implementing PartialReductionOpInterface"); + } + SmallVector sizes(iterationDomain.size()); + AffineExpr s0, s1, s2; + bindSymbols(rewriter.getContext(), s0, s1, s2); + AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2)); + AffineExpr divExpr = s0.ceilDiv(s1); + for (auto [index, domain, tileSize] : + llvm::enumerate(iterationDomain, tileSizes)) { + if (!numThreads.empty()) { + // Untiled case. + if (isConstantIntValue(numThreads[index], 0)) { + sizes[index] = affine::makeComposedFoldedAffineApply( + rewriter, op.getLoc(), sizeExpr, + {domain.size, domain.offset, domain.stride}); + continue; + } + sizes[index] = numThreads[index]; + continue; + } + + // Non reduction dimensions/non-tiled dimensions. + if (!reductionDims.contains(index) || isConstantIntValue(tileSize, 0)) { + sizes[index] = affine::makeComposedFoldedAffineApply( + rewriter, op.getLoc(), sizeExpr, + {domain.size, domain.offset, domain.stride}); + continue; + } + + if (reductionStrategy == + ReductionTilingStrategy::PartialReductionOuterReduction) { + sizes[index] = tileSize; + continue; + } + + assert(reductionStrategy == + ReductionTilingStrategy::PartialReductionOuterParallel); + OpFoldResult normalizedRange = affine::makeComposedFoldedAffineApply( + rewriter, op.getLoc(), sizeExpr, + {domain.size, domain.offset, domain.stride}); + sizes[index] = affine::makeComposedFoldedAffineApply( + rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize}); + } + return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes, + reductionDims); +} + +/// For the case of `ReductionTilingStrategy::PartialReductionOuterParallel` +/// the `PartialReductionOpInterface` methods need the index of the parallel +/// split reduction being executed. 
+static SmallVector +getSplitReductionIvs(RewriterBase &rewriter, Location loc, + ReductionTilingStrategy reductionStrategy, ValueRange ivs, + ArrayRef numThreads, + ArrayRef tileSizes, + const SetVector &reductionDims) { + SmallVector splitReductionIvs; + splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0)); + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + AffineExpr divExpr = s0.ceilDiv(s1); + int ivIndex = 0; + if (reductionStrategy == + ReductionTilingStrategy::PartialReductionOuterParallel) { + for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) { + if (!numThreads.empty()) { + splitReductionIvs[index] = ivs[ivIndex++]; + continue; + } + splitReductionIvs[index] = affine::makeComposedFoldedAffineApply( + rewriter, loc, divExpr, + ArrayRef{ivs[ivIndex++], tileSizes[reductionDim]}); + } } - return redOp.generateInitialTensorForPartialReduction( - rewriter, loc, tileSizes, options.reductionDims); + return splitReductionIvs; } static FailureOr getTiledImplementation(RewriterBase &rewriter, TilingInterface op, + ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef offsets, - ArrayRef sizes, - const scf::SCFTilingOptions &options) { - if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) { + ArrayRef sizes, ValueRange ivs, + ArrayRef numThreads, + ArrayRef tileSizes, + const SetVector &reductionDims) { + if (reductionStrategy == ReductionTilingStrategy::FullReduction) { return op.getTiledImplementation(rewriter, offsets, sizes); } @@ -626,20 +717,25 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op, "supported for operations " "implementing PartialReductionOpInterface"); } - return redOp.tileToPartialReduction(rewriter, op.getLoc(), - options.reductionStrategy, regionIterArg, - offsets, sizes, options.reductionDims); + + SmallVector splitReductionIvs = + getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs, + numThreads, tileSizes, reductionDims); + return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy, + regionIterArg, offsets, sizes, + reductionDims, splitReductionIvs); } -static LogicalResult -getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, - TilingInterface op, ArrayRef offsets, - ArrayRef sizes, - SmallVector &resultOffset, - SmallVector &resultSize, - const scf::SCFTilingOptions &options) { +static LogicalResult getResultTilePosition( + RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, + int64_t index, Value tiledResult, TilingInterface op, + ArrayRef offsets, ArrayRef sizes, + ValueRange ivs, ArrayRef numThreads, + ArrayRef tileSizes, const SetVector &reductionDims, + SmallVector &resultOffset, + SmallVector &resultSize) { - if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) { + if (reductionStrategy == ReductionTilingStrategy::FullReduction) { return op.getResultTilePosition(rewriter, index, offsets, sizes, resultOffset, resultSize); } @@ -649,16 +745,20 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, op, "PartialReductionOuterReduction tiling strategy is only supported" "for operations implementing PartialReductionOpInterface"); } - return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes, - options.reductionDims, resultOffset, - resultSize); + SmallVector splitReductionIvs = + getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs, + numThreads, tileSizes, reductionDims); + return 
redOp.getPartialResultTilePosition( + rewriter, index, reductionStrategy, offsets, sizes, reductionDims, + splitReductionIvs, resultOffset, resultSize); } static FailureOr mergeTilingResults(RewriterBase &rewriter, TilingInterface op, - ValueRange partialResults, - const scf::SCFTilingOptions &options) { - assert(options.reductionStrategy != ReductionTilingStrategy::FullReduction && + ReductionTilingStrategy reductionStrategy, + const SetVector &reductionDims, + ValueRange partialResults) { + assert(reductionStrategy != ReductionTilingStrategy::FullReduction && "expected merge to be called for only partial reduction cases"); auto redOp = dyn_cast(op.getOperation()); @@ -669,7 +769,7 @@ mergeTilingResults(RewriterBase &rewriter, TilingInterface op, "implementing PartialReductionOpInterface"); } return redOp.mergeReductions(rewriter, op.getLoc(), partialResults, - options.reductionDims); + reductionDims); } /// Append the specified additional `newInitOperands` operands to the @@ -911,6 +1011,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, return failure(); } + // Get the reduction dims + SetVector reductionDims = + getSanitizedReductionDims(tileSizes, options); + // 3. If there is an interchange specified, permute the iteration domain and // the tile sizes. SmallVector interchangeVector; @@ -938,7 +1042,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 4a. Compute the `offsets` and `sizes` to use for tiling. SmallVector offsets, sizes; std::tie(offsets, sizes) = getTileOffsetAndSizes( - rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); + rewriter, loc, options.reductionStrategy, ivs, iterationDomain, + tileSizes, numThreads, reductionDims); // 4b. If interchange was provided, apply inverse of the interchange // to get back the offsets/sizes in the order to be specified. @@ -966,8 +1071,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, } // 5c. Tile the cloned operation. - tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs, - offsets, sizes, options); + tilingResult = getTiledImplementation( + rewriter, clonedOp, options.reductionStrategy, regionIterArgs, offsets, + sizes, ivs, numThreads, tileSizes, reductionDims); if (failed(tilingResult)) { rewriter.eraseOp(clonedOp); return op.emitOpError("faild to tile operation"); @@ -982,9 +1088,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, llvm::enumerate(tilingResult->tiledValues)) { tiledResults.push_back(tiledValue); SmallVector resultOffset, resultSize; - if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets, - sizes, resultOffset, resultSize, - options))) { + if (failed(getResultTilePosition( + rewriter, options.reductionStrategy, index, tiledValue, op, + offsets, sizes, ivs, numThreads, tileSizes, reductionDims, + resultOffset, resultSize))) { for (auto op : tilingResult->tiledOps) { rewriter.eraseOp(op); } @@ -999,8 +1106,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, }; // 6. Find the destination tensors to use for the operation. 
- FailureOr> maybeInits = - createInitialTensorsForTiling(rewriter, op, tileSizes, options); + FailureOr> maybeInits = createInitialTensorsForTiling( + rewriter, op, options.reductionStrategy, iterationDomain, numThreads, + tileSizes, reductionDims); if (failed(maybeInits)) { return rewriter.notifyMatchFailure( op, "unable to create initial tensors for tiling"); @@ -1009,8 +1117,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 7. Generate the tiled loops nest using the callback defined above. SmallVector loops; - if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, - tileSizes, numThreads, initTensors, + if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType, + iterationDomain, tileSizes, numThreads, + initTensors, options.mappingVector, innerYieldTiledValuesFn, loops))) return op.emitOpError("failed to generate tiling loops"); assert(succeeded(tilingResult) && @@ -1038,8 +1147,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, } // The results of the loop needs to be merged. - FailureOr mergeResult = - mergeTilingResults(rewriter, op, loopResults, options); + FailureOr mergeResult = mergeTilingResults( + rewriter, op, options.reductionStrategy, reductionDims, loopResults); if (failed(mergeResult)) { return rewriter.notifyMatchFailure( op, "Failed to merge partial results from tiling"); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 04242cad9ecb6..72144ec71c5d2 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2315,13 +2315,13 @@ RankedTensorType ExtractSliceOp::inferResultType( RankedTensorType ExtractSliceOp::inferResultType( RankedTensorType sourceTensorType, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { - SmallVector staticOffsets, staticSizes, staticStrides; - SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); - dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets, - staticSizes, staticStrides); + SmallVector staticSizes; + std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes); + assert(static_cast(staticSizes.size()) == + sourceTensorType.getRank() && + "unexpected staticSizes not equal to rank of source"); + return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(), + sourceTensorType.getEncoding()); } /// If the rank is reduced (i.e. the desiredResultRank is smaller than the diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 8e3f796af54df..be01ff2fa3781 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -208,7 +208,7 @@ SmallVector getMixedValues(ArrayRef staticValues, /// Decompose a vector of mixed static or dynamic values into the corresponding /// pair of arrays. This is the inverse function of `getMixedValues`. 
std::pair, SmallVector> -decomposeMixedValues(const SmallVectorImpl &mixedValues) { +decomposeMixedValues(ArrayRef mixedValues) { SmallVector staticValues; SmallVector dynamicValues; for (const auto &it : mixedValues) { diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir index 009ab17786696..075d02ab75ad1 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir @@ -112,7 +112,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [0, 5], tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + by num_threads = [0, 5] tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } } @@ -134,10 +134,9 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]] // CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]]) // CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor to tensor -// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] -// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor -// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor to tensor -// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor) outs(%[[TEMPEXT]] : tensor) { +// CHECK-DAG: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] +// CHECK-DAG: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor +// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor) outs(%[[ET]] : tensor) { // CHECK: arith.mulf // CHECK: arith.addf // CHECK: linalg.yield @@ -166,7 +165,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [0, 0, 5], tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + by num_threads = [0, 0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } } @@ -187,11 +186,10 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]] // CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]]) // CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor to tensor -// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] -// CHECK: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor -// 
CHECK: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor to tensor -// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0, 0] [%[[D0]], %[[D2]]] [1, 1] : tensor to tensor -// CHECK: %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor, tensor) outs(%[[TEMPEXT]] : tensor) -> tensor +// CHECK-DAG: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] +// CHECK-DAG: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor +// CHECK-DAG: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor to tensor +// CHECK: %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor, tensor) outs(%[[ET]] : tensor) -> tensor // CHECK: scf.forall.in_parallel { // CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor into tensor // CHECK: } @@ -204,113 +202,9 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @reduction_tile_parallel_cyclic_dist( - %arg0: tensor, %out: tensor) -> tensor { - %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor) - outs(%out : tensor) { - ^bb0(%arg7: f32, %arg9: f32): - %1 = arith.mulf %arg7, %arg7 : f32 - %2 = arith.addf %1, %arg9 : f32 - linalg.yield %2 : f32 - } -> tensor - return %red : tensor -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - transform.yield - } -} - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0)> - -// CHECK: func @reduction_tile_parallel_cyclic_dist(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor -// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index -// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor -// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor) -> tensor -// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor) { -// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor to tensor -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor -// CHECK: %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]] -// CHECK: %[[CARRY:.+]] = scf.for %[[IV1:.+]] = %[[LB]] to %[[D1]] step %[[C15]] iter_args(%[[ACC:.+]] = %[[ET]]) -> (tensor) { -// CHECK: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[D1]]] -// CHECK: %[[D3:.+]] = tensor.dim %[[ACC]], %[[C0]] : tensor -// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [%[[D0]], %[[TS0]]] [1, 1] : tensor to 
tensor -// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ACC]][0] [%[[D3]]] [1] : tensor to tensor -// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor) outs(%[[TEMPEXT]] : tensor) { -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: linalg.yield -// CHECK: } -> tensor -// CHECK: %[[INS:.+]] = tensor.insert_slice %[[PARTIAL]] into %[[ACC]][0] [%[[D3]]] [1] : tensor into tensor -// CHECK: scf.yield %[[INS]] : tensor -// CHECK: } -// CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor into tensor -// CHECK: } -// CHECK: } -// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor) outs(%[[ARG1]] : tensor) dimensions = [1] -// CHECK: arith.addf -// CHECK: linalg.yield -// CHECK: } -// CHECK: return %[[R]] : tensor - -// ----- - -func.func @reduction_tile_parallel_cyclic_dist( - %arg0: tensor, %out: tensor) -> tensor { - %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor) - outs(%out : tensor) { - ^bb0(%arg7: f32, %arg9: f32): - %1 = arith.mulf %arg7, %arg7 : f32 - %2 = arith.addf %1, %arg9 : f32 - linalg.yield %2 : f32 - } -> tensor - return %red : tensor -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - // CHECK: expecting fill - // CHECK-NEXT: linalg.fill - transform.print %1 {name = "expecting fill"} : !transform.any_op - // CHECK: expecting parallel reduction - // CHECK-NEXT: linalg.generic - // CHECK: iterator_types = ["parallel", "reduction"] - transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op - // CHECK: expecting parallel reduction - // CHECK-NEXT: linalg.reduce - // CHECK: iterator_types = ["parallel", "reduction"] - transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op - transform.yield - } -} - -// ----- - func.func @reduction_untiled_forall( %arg0: tensor, %out: tensor) -> tensor { - // expected-note @below {{target operation}} + // expected-error @below {{tiling parallel dimensions is not supported with partial reduction tiling strategies}} %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} @@ -329,9 +223,8 @@ module attributes {transform.with_named_sequence} { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{could not tile reduction}} %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 - by num_threads = [5], tile_sizes = [3], mapping = [#gpu.thread] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - - transform.yield + by num_threads = [5] tile_sizes = [3] mapping = [#gpu.thread] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield } } @@ 
-643,3 +536,158 @@ module { // CHECK-SAME: outs(%[[INIT]] : // CHECK-SAME: dimensions = [1, 2] // CHECK: return %[[REDUCE]] + +// ----- + +func.func @reduction_tile_parallel_using_tile_sizes( + %arg0: tensor, %out: tensor) -> tensor { + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.mulf %arg7, %arg7 : f32 + %2 = arith.addf %1, %arg9 : f32 + linalg.yield %2 : f32 + } -> tensor + return %red : tensor +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 5)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)> +// CHECK: func @reduction_tile_parallel_using_tile_sizes(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK-DAG: %[[PARALLEL_DIM:.+]] = affine.apply #[[MAP0]]()[%[[D1]]] +// CHECK: %[[E:.*]] = tensor.empty(%[[D0]], %[[PARALLEL_DIM]]) : tensor +// CHECK: %[[F:.*]] = linalg.fill +// CHECK-SAME: outs(%[[E]] : +// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (%[[D1]]) step (5) shared_outs(%[[ARG3:.+]] = %[[F]]) +// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[D1]]] +// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]] +// CHECK-DAG: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [%[[D0]], %[[TS0]]] [1, 1] +// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [%[[D0]], 1] [1, 1] +// CHECK: %[[PARTIAL:.+]] = linalg.generic +// CHECK-SAME: ins(%[[INCHUNK]] : +// CHECK-SAME: outs(%[[ET]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[INIT_OFFSET]]] [%[[D0]], 1] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] +// CHECK-SAME: outs(%[[ARG1]] : +// CHECK: return %[[R]] : tensor +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 + by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// Check that only one of the reduction dimension can be tiled (in this case inner). 
+ +#map = affine_map<(d0, d1, d2) -> (d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0)> +module { + func.func @reduction_using_forall_tile_single_of_multiple_reduction_inner( + %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> { + %0 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "reduction"]} + ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %1 = arith.mulf %in, %in_0 : f32 + %2 = arith.addf %1, %out : f32 + linalg.yield %2 : f32 + } -> tensor<4096xf32> + return %0 : tensor<4096xf32> + } + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %fill_op, %split_linalg_op, %combining_linalg_op, %for_op = + transform.structured.tile_reduction_using_forall %0 reduction_dims = [2] by tile_sizes = [0, 0, 64] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } + } +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)> +// CHECK: func @reduction_using_forall_tile_single_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>) +// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32> +// CHECK: %[[F:.*]] = linalg.fill +// CHECK-SAME: outs(%[[E]] : +// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (128) step (64) shared_outs(%[[ARG3:.+]] = %[[F]]) +// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]] +// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [86, 64] [1, 1] +// CHECK-DAG: %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, 0, %[[IV]]] [4096, 86, 64] [1, 1, 1] +// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1] +// CHECK: %[[PARTIAL:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : +// CHECK-SAME: outs(%[[ET]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] +// CHECK-SAME: outs(%[[ARG2]] : +// CHECK: return %[[R]] + +// ----- + +// Check that specifying both reduction dimensions, but setting tile size to 0 for one of them behaves consistent with specifying single reduction dimension. 
+ +#map = affine_map<(d0, d1, d2) -> (d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0)> +module { + func.func @reduction_using_forall_tilesize_0_of_multiple_reduction_inner( + %arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> { + %0 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "reduction"]} + ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %1 = arith.mulf %in, %in_0 : f32 + %2 = arith.addf %1, %out : f32 + linalg.yield %2 : f32 + } -> tensor<4096xf32> + return %0 : tensor<4096xf32> + } + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %fill_op, %split_linalg_op, %combining_linalg_op, %for_op = + transform.structured.tile_reduction_using_forall %0 reduction_dims = [1, 2] by tile_sizes = [0, 0, 64] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } + } +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)> +// CHECK: func @reduction_using_forall_tilesize_0_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>) +// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32> +// CHECK: %[[F:.*]] = linalg.fill +// CHECK-SAME: outs(%[[E]] : +// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (128) step (64) shared_outs(%[[ARG3:.+]] = %[[F]]) +// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]] +// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [86, 64] [1, 1] +// CHECK-DAG: %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, 0, %[[IV]]] [4096, 86, 64] [1, 1, 1] +// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1] +// CHECK: %[[PARTIAL:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : +// CHECK-SAME: outs(%[[ET]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] +// CHECK-SAME: outs(%[[ARG2]] : +// CHECK: return %[[R]]
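
For readers who want to try the updated transform op outside the lit tests above, here is a minimal, self-contained sketch of the new comma-free assembly format with the optional `reduction_dims` clause. The payload function `@sum_of_squares`, the handle names, and the tile size of 5 are illustrative, not taken from the patch; the structure mirrors the `reduction_tile_parallel_using_tile_sizes` test. With `tile_sizes`, the op now routes through `scf::tileUsingSCF` using `ReductionTilingStrategy::PartialReductionOuterParallel`; the `num_threads` form is written analogously, e.g. `by num_threads = [0, 5] tile_sizes = []`.

```mlir
// Payload: a 2-D -> 1-D sum-of-squares reduction, same shape as the
// reductions used in transform-tile-reduction.mlir.
func.func @sum_of_squares(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
    ins(%arg0 : tensor<?x?xf32>) outs(%out : tensor<?xf32>) {
  ^bb0(%in: f32, %acc: f32):
    %0 = arith.mulf %in, %in : f32
    %1 = arith.addf %0, %acc : f32
    linalg.yield %1 : f32
  } -> tensor<?xf32>
  return %red : tensor<?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    // New syntax: no commas between clauses, and an optional `reduction_dims`
    // list selecting which reduction dimensions get tiled.
    %fill, %partial, %merge, %loop =
      transform.structured.tile_reduction_using_forall %0
        reduction_dims = [1] by tile_sizes = [0, 5]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
                                  !transform.any_op, !transform.any_op)
    transform.yield
  }
}
```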
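
For the same payload, the sketch below shows the expected structure of the transformed IR under `PartialReductionOuterParallel` with `tile_sizes`, hand-written to match the CHECK lines of `@reduction_tile_parallel_using_tile_sizes` (SSA names and the inlined affine maps are illustrative, not verbatim compiler output). The partial-result tensor gains one extra dimension of extent `ceildiv(trip count, tile size)`, the `scf.forall` strides over the reduction dimension by the tile size, each iteration writes its partial result at column `iv ceildiv tileSize` (the `splitReductionIvs` computation), and a trailing `linalg.reduce` merges the partials.

```mlir
// Sketch of the result of tiling @sum_of_squares with tile_sizes = [0, 5].
func.func @sum_of_squares_tiled(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  // One partial-result column per tile of the reduction dimension.
  %num_tiles = affine.apply affine_map<()[s0] -> (s0 ceildiv 5)>()[%d1]
  %empty = tensor.empty(%d0, %num_tiles) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
  %partial = scf.forall (%iv) = (0) to (%d1) step (5)
      shared_outs(%acc = %fill) -> (tensor<?x?xf32>) {
    %ts = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 5)>(%iv)[%d1]
    // Column of this tile in the partial-result tensor: iv ceildiv tile size.
    %col = affine.apply affine_map<(d0) -> (d0 ceildiv 5)>(%iv)
    %in_slice = tensor.extract_slice %arg0[0, %iv] [%d0, %ts] [1, 1]
        : tensor<?x?xf32> to tensor<?x?xf32>
    %acc_slice = tensor.extract_slice %acc[0, %col] [%d0, 1] [1, 1]
        : tensor<?x?xf32> to tensor<?xf32>
    %p = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
      ins(%in_slice : tensor<?x?xf32>) outs(%acc_slice : tensor<?xf32>) {
    ^bb0(%in: f32, %a: f32):
      %m = arith.mulf %in, %in : f32
      %s = arith.addf %m, %a : f32
      linalg.yield %s : f32
    } -> tensor<?xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %p into %acc[0, %col] [%d0, 1] [1, 1]
          : tensor<?xf32> into tensor<?x?xf32>
    }
  }
  // Merge the partial results along the extra (partial) dimension.
  %merged = linalg.reduce ins(%partial : tensor<?x?xf32>)
      outs(%out : tensor<?xf32>) dimensions = [1]
    (%in: f32, %init: f32) {
      %s = arith.addf %in, %init : f32
      linalg.yield %s : f32
    }
  return %merged : tensor<?xf32>
}
```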