Skip to content

Commit 3241384

Browse files
committed
Fix some testing
1 parent 55ae384 commit 3241384

File tree

3 files changed

+13
-396
lines changed

3 files changed

+13
-396
lines changed

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -798,9 +798,15 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
798798
// transformations such as padding and bufferization since the
799799
// extract/insert slice pairs make the accessed iteration argument
800800
// subdomains explicit.
801-
801+
SmallVector<int64_t> domainSizes;
802+
// FIXME: tileToPartialReduction adds the new init tensor to the output
803+
// but doesn't update the indexing type of the index map causing a crash.
804+
// isAllParallelLoops
805+
if (linalgOp.getNumParallelLoops() == linalgOp.getNumLoops()) {
806+
domainSizes = linalgOp.getStaticLoopRanges();
807+
}
802808
Type operandType = opOperand.get().getType();
803-
if (!isTiled(map, tileSizes, linalgOp.getStaticLoopRanges()) &&
809+
if (!isTiled(map, tileSizes, domainSizes) &&
804810
!(isa<RankedTensorType>(operandType) &&
805811
linalgOp.isDpsInit(&opOperand))) {
806812
allSliceParams.push_back(std::nullopt);
@@ -812,7 +818,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
812818

813819
allSliceParams.push_back(computeSliceParameters(
814820
builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
815-
omitPartialTileCheck, linalgOp.getStaticLoopRanges()));
821+
omitPartialTileCheck, domainSizes));
816822
}
817823

818824
return allSliceParams;

mlir/test/Dialect/Linalg/tile-tensors.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ module attributes {transform.with_named_sequence} {
170170

171171
// -----
172172

173+
174+
// CHECK-LABEL: func @generic_op_tensors
175+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
176+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
173177
func.func @non_monotonic_affine_expr(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
174178
%c0 = arith.constant 0 : index
175179
%0 = tensor.dim %arg0, %c0 : tensor<7xf32>

0 commit comments

Comments
 (0)