diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index b75fc5e806afb..50b69b8f8d833 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
       /*defaultImplementation=*/[{
        return failure();
      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Method to return the position of the partial result tile computed by
+        the tiled operation. This is the same as
+        TilingInterface::getResultTilePosition, but determines the result
+        tile position for a partial reduction.
+      }],
+      /*retType=*/"::llvm::LogicalResult",
+      /*methodName=*/"getPartialResultTilePosition",
+      /*args=*/(ins
+          "::mlir::OpBuilder &":$b,
+          "unsigned":$resultNumber,
+          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
+          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+          "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
+          "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
+          "::mlir::ArrayRef<int>":$reductionDims),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return failure();
+      }]
     >
   ];
}
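Note (illustrative, not part of the patch): to see what the new hook reports, take a simple row reduction whose output map drops the reduced dimension, tiled with a reduction tile size of 5. The partial result carries the reduction tile as a trailing dimension, so the tile position of partial result 0 reuses the parallel offsets/sizes and pins the appended dimension to offset 0, size 5. SSA names below are hypothetical.

  %red = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<?x?xf32>) outs(%out : tensor<?xf32>) {
    ^bb0(%in: f32, %acc: f32):
      %s = arith.addf %in, %acc : f32
      linalg.yield %s : f32
  } -> tensor<?xf32>
  // Tiled with tile_sizes = [0, 5]: the partial result type is
  // tensor<?x5xf32>. For a result tile at offsets [%off0] with sizes
  // [%sz0], getPartialResultTilePosition reports
  // resultOffsets = [%off0, 0] and resultSizes = [%sz0, 5].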
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 5bf2f91c2c7bc..92cfba2549a3f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
 static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
                       ArrayRef<MeshSharding> resultShardings,
                       SymbolTableCollection &symbolTable) {
-  for (const MeshSharding& sharding : operandShardings) {
+  for (const MeshSharding &sharding : operandShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
   }
 
-  for (const MeshSharding& sharding : resultShardings) {
+  for (const MeshSharding &sharding : resultShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
 // the original operand.
 // The other processes would use the reduction operation neutral tensor.
 static Value createDestinationPassingStyleInitOperand(
-    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
-    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+    LinalgOp op, int operandNumber, Value spmdizedOperand,
+    ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
+    ImplicitLocOpBuilder &builder) {
   Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
       meshOp.getSymName(), reductionMeshAxes, builder);
   Value zero = builder.create<arith::ConstantIndexOp>(0);
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
     builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
     SmallVector<OpFoldResult> shape =
         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
-    PartialReductionOpInterface partialReductionIface =
-        llvm::cast<PartialReductionOpInterface>(op.getOperation());
-    assert(op->getNumResults() == 1 && "Multiple results not supported.");
-    FailureOr<SmallVector<Value>> reductionNeutralTensor =
-        partialReductionIface.generateInitialTensorForPartialReduction(
-            builder, builder.getLoc(), shape, {});
-    assert(succeeded(reductionNeutralTensor));
-    builder.create<scf::YieldOp>(reductionNeutralTensor.value());
+
+    SmallVector<Operation *> combinerOps;
+    matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
+    assert(combinerOps.size() == 1);
+    std::optional<TypedAttr> neutralEl =
+        arith::getNeutralElement(combinerOps[0]);
+
+    Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
+                                                 neutralEl.value().getType());
+    Value constant =
+        builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
+    Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
+                     .getResult(0);
+
+    builder.create<scf::YieldOp>(fill);
   }
   return ifOp.getResult(0);
 }
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
   Value spmdizedInitOperand =
       spmdizationMap.lookup(op->getOperands()[operandIdx]);
   newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
-      op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+      op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
   return newOperands;
 }
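Note: the else-branch now materializes the neutral tensor inline instead of going through PartialReductionOpInterface. A minimal sketch of the IR this produces, assuming an arith.addf combiner (so the neutral element is 0.0; SSA names are illustrative):

  %is_lead = arith.cmpi eq, %process_idx, %c0 : index
  %init = scf.if %is_lead -> (tensor<?xf32>) {
    scf.yield %spmdized_operand : tensor<?xf32>
  } else {
    %cst = arith.constant 0.0 : f32    // neutral element of the combiner
    %empty = tensor.empty(%dim) : tensor<?xf32>
    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?xf32>) -> tensor<?xf32>
    scf.yield %fill : tensor<?xf32>
  }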
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f86715a94b268..b7764da26a7f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -324,7 +324,27 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// External model implementation of PartialReductionInterface for LinalgOps.
+/// Return an AffineMap for a partial result for the given result number,
+/// assuming the partial tiling strategy is outer-reduction loop +
+/// inner-parallel tile. The returned AffineMap can be used as the replacement
+/// AffineMap for the inner-parallel tile linalg op for the given result
+/// number.
+///
+/// The new AffineMap is the old AffineMap with the reduction dimensions
+/// appended at the end.
+static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
+                                           ArrayRef<int> reductionDims,
+                                           unsigned resultNumber) {
+  AffineMap map =
+      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
+  for (int redPos : reductionDims) {
+    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                           map.getNumResults());
+  }
+  return map;
+}
+
+/// External model implementation of PartialReductionInterface for
+/// LinalgOps.
 template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +358,24 @@ struct LinalgOpPartialReductionInterface
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
 
+    // LinalgOp implements TilingInterface.
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    SmallVector<OpFoldResult> shape =
+        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
+                            [](Range x) { return x.size; });
+
+    SmallVector<OpFoldResult> tiledShape;
+    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
+      if (isZeroIndex(tileSize)) {
+        tiledShape.push_back(dimSize);
+      } else {
+        tiledShape.push_back(tileSize);
+      }
+    }
+
     SmallVector<Value> inits;
     for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
          ++initIdx) {
-      // Insert the new parallel dimension based on the index of the reduction
-      // loops. This could be controlled by user for more flexibility.
       SmallVector<Operation *, 4> combinerOps;
       if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                           combinerOps) ||
@@ -355,33 +388,19 @@ struct LinalgOpPartialReductionInterface
         return op->emitOpError(
             "Failed to get an identity value for the reduction operation.");
 
-      ArrayRef<int64_t> oldShape =
-          linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
-
-      // Calculate the new shape, we insert the new dimensions based on the
-      // index of the reduction dimensions.
-      SmallVector<int64_t> newOutputShape;
-      SmallVector<Value> dynamicDims;
-      int64_t currReductionDims = 0;
-      DenseSet<int> reductionDimsSet(reductionDims.begin(),
-                                     reductionDims.end());
-      for (int64_t idx :
-           llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
-        if (reductionDimsSet.contains(idx)) {
-          dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
-          currReductionDims++;
-          continue;
-        }
-        int64_t oldIdx = idx - currReductionDims;
-        int64_t dim = oldShape[oldIdx];
-        newOutputShape.push_back(dim);
-        if (ShapedType::isDynamic(dim))
-          dynamicDims.push_back(b.create<tensor::DimOp>(
-              loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
+      // Append the new partial result dimensions.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
+      SmallVector<OpFoldResult> partialResultShape;
+      for (AffineExpr dimExpr : partialMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        partialResultShape.push_back(tiledShape[dim.getPosition()]);
       }
-      Value emptyTensor = b.create<tensor::EmptyOp>(
-          loc, newOutputShape,
-          linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
+
+      Type elType =
+          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+      Value emptyTensor =
+          b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
       Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
       auto identityTensor =
           b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
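Note: to see what getPartialResultAffineMap and the shape computation above produce, consider a hypothetical generic op (not from the patch) whose init is indexed by (d0, d1, d2) -> (d2, d0), with reduction dimension d1 tiled by 5. The reduction dimension is appended to the init map, and the partial accumulator's shape is tiledShape read through that map:

  // Init map of the untiled op:
  affine_map<(d0, d1, d2) -> (d2, d0)>
  // Partial result map returned by getPartialResultAffineMap:
  affine_map<(d0, d1, d2) -> (d2, d0, d1)>
  // With iteration-domain sizes [%d0, %d1, %d2] and tile sizes [0, 5, 0],
  // tiledShape = [%d0, 5, %d2], so partialResultShape = [%d2, %d0, 5]:
  %empty = tensor.empty(%d2, %d0) : tensor<?x?x5xf32>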
@@ -407,11 +426,7 @@ struct LinalgOpPartialReductionInterface
       // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
       // this with a for range loop when we have it.
       AffineMap newMap =
-          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-      for (int redPos : reductionDims) {
-        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
-                                     newMap.getNumResults());
-      }
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
       newInitMaps.push_back(newMap);
     }
@@ -476,29 +491,75 @@ struct LinalgOpPartialReductionInterface
                      Location loc, ValueRange partialReduce,
                      ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    SmallVector<int64_t> reductionDimsInt64(reductionDims);
-    auto reduction = b.create<linalg::ReduceOp>(
-        loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
-        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
-          int64_t numInits = linalgOp.getNumDpsInits();
-          SmallVector<Value> yieldedValues;
-          for (int idx : llvm::seq<int>(0, numInits)) {
+
+    // Permute the reduction dims as permuted by the partial result map.
+
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<Operation *> mergeOperations;
+    SmallVector<Value> replacements;
+    for (int idx : llvm::seq<int>(numInits)) {
+      // linalg.reduce's iteration space is the tiled result's iteration space
+      // (and not the tiled operation's iteration space). To account for this,
+      // permute the reduction dimensions based on the partial result map of
+      // the tiled result.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
+      SmallVector<int64_t> partialReductionDims;
+      for (auto [resultNum, dimExpr] :
+           llvm::enumerate(partialMap.getResults())) {
+        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+        if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+          partialReductionDims.push_back(resultNum);
+        }
+      }
+
+      Value partialResult = partialReduce[idx];
+      Value init = linalgOp.getDpsInits()[idx];
+
+      auto reduction = b.create<linalg::ReduceOp>(
+          loc, partialResult, init, partialReductionDims,
+          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
             // Get the combiner op.
             SmallVector<Operation *, 4> combinerOps;
             matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
             Operation *clonedReductionOp = b.clone(*combinerOps[0]);
             // Combine the input at idx and output at numInits + idx.
-            clonedReductionOp->setOperand(0, inputs[idx]);
-            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
-            // Yield.
-            yieldedValues.push_back(clonedReductionOp->getResult(0));
-          }
-          b.create<linalg::YieldOp>(loc, yieldedValues);
-        });
-    return MergeResult{
-        {reduction.getOperation()},
-        llvm::map_to_vector(reduction->getResults(),
-                            [](OpResult r) -> Value { return r; })};
+            clonedReductionOp->setOperand(0, inputs[0]);
+            clonedReductionOp->setOperand(1, inputs[1]);
+            b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+          });
+
+      mergeOperations.push_back(reduction);
+      replacements.push_back(reduction->getResult(0));
+    }
+
+    return MergeResult{mergeOperations, replacements};
+  }
+
+  LogicalResult getPartialResultTilePosition(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &resultOffsets,
+      SmallVector<OpFoldResult> &resultSizes,
+      ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    AffineMap partialMap =
+        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
+    for (AffineExpr dimExpr : partialMap.getResults()) {
+      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+      resultSizes.push_back(sizes[dim]);
+
+      if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+        // Reduction dims are reduced, and are always output in the same
+        // place. So use offset 0 for them.
+        resultOffsets.push_back(b.getIndexAttr(0));
+      } else {
+        resultOffsets.push_back(offsets[dim]);
+      }
+    }
+
+    return success();
   }
 };
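Note: continuing the earlier hypothetical example through the merge step, the partial map (d0, d1, d2) -> (d2, d0, d1) places the reduction dimension at result position 2, so the per-init linalg.reduce reduces dimension 2 of the partial accumulator into the original init, and getPartialResultTilePosition pins that position to offset 0. A sketch assuming an arith.addf combiner (SSA names are illustrative):

  %merged = linalg.reduce { arith.addf }
      ins(%partial : tensor<?x?x5xf32>)
      outs(%init : tensor<?x?xf32>)
      dimensions = [2]
  // For iteration-space offsets [%o0, %o1, %o2] and sizes [%s0, %s1, %s2],
  // the partial tile of result 0 sits at resultOffsets = [%o2, %o0, 0]
  // with resultSizes = [%s2, %s0, %s1].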
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2277989bf8411..b548f8ce8b560 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -657,21 +657,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                                             resultOffset, resultSize);
   case scf::SCFTilingOptions::ReductionTilingStrategy::
       PartialReductionOuterReduction: {
-    // TODO: This does not work for non identity accesses to the result tile.
-    // The proper fix is to add a getPartialResultTilePosition method to
-    // PartialReductionOpInterface.
-    resultOffset =
-        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
-    for (size_t i = 0; i < offsets.size(); i++) {
-      resultSize.push_back(
-          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations implementing PartialReductionOpInterface");
     }
-    return success();
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+                                              resultOffset, resultSize,
+                                              reductionDims);
+  }
   default:
     return rewriter.notifyMatchFailure(op,
                                        "unhandled reduction tiling strategy");
   }
-  }
 }
 
 static FailureOr<MergeResult>
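Note: before handing off to the interface, the driver derives reductionDims purely from the iterator types. A small worked instance (an assumed matmul-style op, not taken from the tests): iterator_types = ["parallel", "parallel", "reduction"] gives reductionDims = [2], appending d2 to the init map yields the partial map, and the hook then reports offset 0 along that trailing dimension:

  // Init map of the untiled op over the (d0, d1, d2) iteration space:
  affine_map<(d0, d1, d2) -> (d0, d1)>
  // Partial result map with the reduction dimension d2 appended:
  affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  // For iteration offsets [%i, %j, %k] and sizes [%si, %sj, %sk]:
  //   resultOffsets = [%i, %j, 0], resultSizes = [%si, %sj, %sk]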
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index cce4b4efa61c8..9d34c80822d0e 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -32,8 +32,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
 // CHECK: %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
@@ -81,13 +80,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: func @reduction_tile_transpose
-// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
-// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
+// CHECK: tensor.empty(%{{.*}}) : tensor<?x5xf32>
+// CHECK: linalg.fill {{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: scf.for
-// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
+// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
 // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
-// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
-// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
+// CHECK: scf.yield {{.*}} : tensor<?x5xf32>
 // CHECK: }
 // CHECK: linalg.reduce
 // CHECK: return
@@ -129,8 +128,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -183,9 +181,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 // CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]], %[[D2]]) : tensor<?x?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
 // CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
 // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -243,8 +239,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
 // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 // CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
@@ -422,8 +417,8 @@ func.func @reduction_tile_multiple_results(%arg0: tensor<?x?xf32>, %out: tensor<
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %12, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
-      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %12, %2, %3, %4, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -444,4 +439,44 @@ module attributes {transform.with_named_sequence} {
 // CHECK: scf.yield %[[INSERT1]], %[[INSERT1]]
 // CHECK: linalg.reduce
 // CHECK: arith.addf
+// CHECK: linalg.reduce
 // CHECK: arith.maximumf
+
+// -----
+
+func.func @reduction_tile_multi_dim_transpose(%arg0: tensor<?x?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d2, d0)>],
+  iterator_types = ["parallel", "reduction", "parallel"]}
+  ins(%arg0 : tensor<?x?x?xf32>)
+  outs(%out : tensor<?x?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+    %42 = arith.addf %arg7, %arg9 : f32
+    linalg.yield %42 : f32
+  } -> tensor<?x?xf32>
+  return %red : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+// CHECK: func @reduction_tile_multi_dim_transpose
+// CHECK: tensor.empty(%{{.*}}) : tensor<?x?x5xf32>
+// CHECK: linalg.fill {{.*}} : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
+// CHECK: scf.for
+// CHECK: %[[K:.*]] = affine.min
+// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0, 0] [%[[D2:.*]], %[[D0:.*]], %[[K]]] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?x?xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?x?xf32>) outs(%[[EXT]] : tensor<?x?x?xf32>)
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0, 0] [%[[D2]], %[[D0]], %[[K]]] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x5xf32>
+// CHECK: scf.yield {{.*}} : tensor<?x?x5xf32>
+// CHECK: }
+// CHECK: linalg.reduce
+// CHECK: return