Skip to content

Commit d59e6b9

Browse files
committed
Fix problem where the shape of the insert shape was calculated incorrectly
1 parent 479c8d6 commit d59e6b9

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,15 @@ struct LinalgOpTilingInterface
218218
}));
219219

220220
OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
221+
SmallVector<OpFoldResult> allShapeSizes =
222+
linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc());
223+
SmallVector<OpFoldResult> sizeBounds =
224+
mlir::affine::makeComposedFoldedMultiResultAffineApply(
225+
b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes);
221226
SliceParameters sliceParams = computeSliceParameters(
222227
b, loc, outOperand->get(), sizes,
223228
linalgOp.getMatchingIndexingMap(outOperand), offsets,
224-
/*ubs*/ {}, subShapeSizes, true);
229+
/*ubs*/ sizeBounds, subShapeSizes, true);
225230
resultOffsets = sliceParams.offsets;
226231
resultSizes = sliceParams.sizes;
227232
return success();

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,9 +1905,10 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
19051905
SmallVector<SmallVector<OpFoldResult>> resultSizes(
19061906
totalNumResultsOfConsumer);
19071907
for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
1908-
if (failed(tiledConsumerOp.getResultTilePosition(
1909-
rewriter, idx, iterDomainOffsets, iterDomainSizes,
1910-
resultOffsets[idx], resultSizes[idx]))) {
1908+
if (failed(cast<TilingInterface>(clonedConsumerOp)
1909+
.getResultTilePosition(rewriter, idx, iterDomainOffsets,
1910+
iterDomainSizes, resultOffsets[idx],
1911+
resultSizes[idx]))) {
19111912
return rewriter.notifyMatchFailure(
19121913
tiledConsumerOp,
19131914
"can't get result domain position from iter domain position");

0 commit comments

Comments
 (0)