Skip to content

Commit e895262

Browse files
committed
Convert towards a proper fix
1 parent 3241384 commit e895262

File tree

3 files changed

+9
-17
lines changed

3 files changed

+9
-17
lines changed

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -798,15 +798,8 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
798798
// transformations such as padding and bufferization since the
799799
// extract/insert slice pairs make the accessed iteration argument
800800
// subdomains explicit.
801-
SmallVector<int64_t> domainSizes;
802-
// FIXME: tileToPartialReduction adds the new init tensor to the output
803-
// but doesn't update the indexing type of the index map causing a crash.
804-
// isAllParallelLoops
805-
if (linalgOp.getNumParallelLoops() == linalgOp.getNumLoops()) {
806-
domainSizes = linalgOp.getStaticLoopRanges();
807-
}
808801
Type operandType = opOperand.get().getType();
809-
if (!isTiled(map, tileSizes, domainSizes) &&
802+
if (!isTiled(map, tileSizes, linalgOp.getStaticLoopRanges()) &&
810803
!(isa<RankedTensorType>(operandType) &&
811804
linalgOp.isDpsInit(&opOperand))) {
812805
allSliceParams.push_back(std::nullopt);
@@ -818,7 +811,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
818811

819812
allSliceParams.push_back(computeSliceParameters(
820813
builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
821-
omitPartialTileCheck, domainSizes));
814+
omitPartialTileCheck, linalgOp.getStaticLoopRanges()));
822815
}
823816

824817
return allSliceParams;

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -984,8 +984,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
984984

985985
// 4a. Clone the operation.
986986
{
987-
auto clonedOp = cast<PartialReductionOpInterface>(
988-
cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
987+
auto clonedOp = cast<PartialReductionOpInterface>(rewriter.clone(*op));
989988

990989
// 4b. Tile the cloned operation.
991990
FailureOr<TilingResult> partialTilingResult =

mlir/test/Dialect/Linalg/tile-tensors.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,16 +170,16 @@ module attributes {transform.with_named_sequence} {
170170

171171
// -----
172172

173-
174-
// CHECK-LABEL: func @generic_op_tensors
175-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
176-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
173+
// CHECK-LABEL: func @non_monotonic_affine_expr
174+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<7xf32>
177175
func.func @non_monotonic_affine_expr(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
178176
%c0 = arith.constant 0 : index
179177
%0 = tensor.dim %arg0, %c0 : tensor<7xf32>
180178
%empty = tensor.empty() : tensor<7xf32>
181-
// FIXME: Do proper testing
182-
// CHECK: tensor.extract_slice %arg2[0] [7] [1] : tensor<7xf32> to tensor<7xf32>
179+
180+
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<7xf32>
181+
// CHECK: scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) {
182+
// CHECK: tensor.extract_slice %[[TC0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32>
183183
%generic = linalg.generic
184184
{indexing_maps = [affine_map<(d0) -> (d0 mod 4)>,
185185
affine_map<(d0) -> (d0)>],

0 commit comments

Comments
 (0)