Commit 3a2e997

[AutoBump] Merge with fixes of 956c070 (Jan 29)
AMD: Move some tosa.slice folding patterns to canonicalization
2 parents: 05e7f0e + 956c070

17 files changed (+580, -276 lines)
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 2 deletions

@@ -1738,8 +1738,8 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {

   let arguments = (ins
     Tosa_Tensor:$input1,
-    DenseI64ArrayAttr:$start,
-    DenseI64ArrayAttr:$size
+    Tosa_Shape:$start,
+    Tosa_Shape:$size
   );

   let results = (outs
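
The net effect of this TableGen change: start and size are no longer DenseI64ArrayAttr attributes but !tosa.shape operands, normally produced by tosa.const_shape ops. A minimal before/after sketch of the textual IR (shapes and values are illustrative, not taken from this commit's tests):

// Before: start/size carried as attributes.
%0 = tosa.slice %arg0 {start = array<i64: 0, 1>, size = array<i64: 2, 3>} : (tensor<4x8xf32>) -> tensor<2x3xf32>

// After: start/size carried as shape operands.
%start = tosa.const_shape {value = dense<[0, 1]> : tensor<2xindex>} : !tosa.shape<2>
%size = tosa.const_shape {value = dense<[2, 3]> : tensor<2xindex>} : !tosa.shape<2>
%0 = tosa.slice %arg0, %start, %size : (tensor<4x8xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x3xf32>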

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 30 additions & 4 deletions

@@ -268,12 +268,28 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
     ShapedType resultType = cast<ShapedType>(sliceOp.getType());
     if (llvm::isa<UnrankedTensorType>(resultType))
       return failure();
+
+    ElementsAttr startElems;
+    ElementsAttr sizeElems;
+
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
     SmallVector<int64_t> strides, sizes;
-    ArrayRef<int64_t> starts = sliceOp.getStart();
     strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

     SmallVector<Value> dynSizes;
-    for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
+    for (const auto &i : llvm::enumerate(sliceSizes)) {
       int64_t size = i.value();
       size_t index = i.index();
       sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
@@ -282,17 +298,27 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {

       auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
       auto offset = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIndexAttr(starts[index]));
+          loc, rewriter.getIndexAttr(sliceStarts[index]));
       dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
     }

     auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
         sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
-        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
+        ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
         rewriter.getDenseI64ArrayAttr(sizes),
         rewriter.getDenseI64ArrayAttr(strides));

     rewriter.replaceOp(sliceOp, newSliceOp.getResult());
+
+    // Erase the defining const_shape ops once they no longer have uses.
+    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
+    if (startConstShape->getResult(0).hasOneUse())
+      rewriter.eraseOp(startConstShape);
+
+    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
+    if (sizeConstShape->getResult(0).hasOneUse())
+      rewriter.eraseOp(sizeConstShape);
+
     return success();
   }
 };
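
With constant shape operands, the converter still produces a fully static tensor.extract_slice; the const_shape ops are erased once the rewritten slice no longer uses them. A hedged sketch of the lowering (illustrative types, assuming the usual --tosa-to-tensor pass):

// Input to the SliceConverter pattern.
%start = tosa.const_shape {value = dense<[0, 1]> : tensor<2xindex>} : !tosa.shape<2>
%size = tosa.const_shape {value = dense<[2, 3]> : tensor<2xindex>} : !tosa.shape<2>
%0 = tosa.slice %arg0, %start, %size : (tensor<4x8xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x3xf32>

// Output: offsets, sizes, and strides are all static.
%0 = tensor.extract_slice %arg0[0, 1] [2, 3] [1, 1] : tensor<4x8xf32> to tensor<2x3xf32>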

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 81 additions & 31 deletions

@@ -763,8 +763,19 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
           sliceOp, "slice input must be a static ranked tensor");
     int32_t axis = concatOp.getAxis();

-    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
-    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
+    llvm::SmallVector<int64_t> sliceStart;
+    if (!tosa::getConstShapeValue(sliceOp.getStart().getDefiningOp(),
+                                  sliceStart)) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice start must be a constant shape");
+    }
+
+    llvm::SmallVector<int64_t> sliceSize;
+    if (!tosa::getConstShapeValue(sliceOp.getSize().getDefiningOp(),
+                                  sliceSize)) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice size must be a constant shape");
+    }
     llvm::SmallVector<Value> requiredConcatInputs;
     int64_t processedOriginalConcatInputSize = 0;
     int64_t droppedConcatInputSize = 0;
@@ -803,8 +814,8 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
         concatOp->getLoc(), requiredConcatInputs, axis);
     auto newSlice = rewriter.create<tosa::SliceOp>(
         sliceOp->getLoc(), sliceOp.getType(), newConcat,
-        rewriter.getDenseI64ArrayAttr(sliceStart),
-        rewriter.getDenseI64ArrayAttr(sliceSize));
+        getTosaConstShape(rewriter, sliceOp.getStart().getLoc(), sliceStart),
+        getTosaConstShape(rewriter, sliceOp.getSize().getLoc(), sliceSize));
     rewriter.replaceOp(sliceOp, newSlice);
     return success();
   }
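
The pattern itself is unchanged in spirit: when the slice window falls entirely inside a subset of the concat operands, the others are dropped. A hedged example (illustrative shapes; a single-operand concat is assumed to fold away afterwards):

// Before: the slice reads rows [0, 2), which lie entirely in %a.
%cat = tosa.concat %a, %b {axis = 0 : i32} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32>
%start = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : !tosa.shape<2>
%size = tosa.const_shape {value = dense<[2, 4]> : tensor<2xindex>} : !tosa.shape<2>
%0 = tosa.slice %cat, %start, %size : (tensor<4x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x4xf32>

// After: %b no longer feeds the computation.
%0 = tosa.slice %a, %start, %size : (tensor<2x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x4xf32>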
@@ -839,8 +850,21 @@ struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
     SmallVector<int64_t> tileMultiplies;
     const LogicalResult tileHasConstantMultiplies =
         tileOp.getConstantMultiples(tileMultiplies);
+    llvm::SmallVector<int64_t> sliceStartShape;
+    if (!tosa::getConstShapeValue(sliceOp.getStart().getDefiningOp(),
+                                  sliceStartShape)) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice start must be a constant shape");
+    }
+
+    llvm::SmallVector<int64_t> sliceSizeShape;
+    if (!tosa::getConstShapeValue(sliceOp.getSize().getDefiningOp(),
+                                  sliceSizeShape)) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice size must be a constant shape");
+    }
     for (auto [axis, sliceStart, sliceSize] :
-         llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) {
+         llvm::enumerate(sliceStartShape, sliceSizeShape)) {
       if (sliceSize <= 0) {
         return rewriter.notifyMatchFailure(
             sliceOp, "degenerate slice with zero sized dim");
@@ -878,16 +902,61 @@ struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
         tileOp->getOperand(0), constantShapeValue);
     auto newSlice = rewriter.create<tosa::SliceOp>(
         sliceOp->getLoc(), sliceOp.getType(), newTile,
-        rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr());
+        getTosaConstShape(rewriter, sliceOp.getStart().getLoc(), newTileStarts),
+        sliceOp.getSize());
     rewriter.replaceOp(sliceOp, newSlice);
     return success();
   }
 };

+// This pattern fuses consecutive slice operations into a single slice.
+struct SliceSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto precedingSliceOp = sliceOp.getInput1().getDefiningOp<SliceOp>();
+    if (!precedingSliceOp)
+      return failure();
+    SmallVector<int64_t> precedingSliceStart;
+    if (!tosa::getConstShapeValue(precedingSliceOp.getStart().getDefiningOp(),
+                                  precedingSliceStart)) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "preceding slice start must be a constant shape");
+    }
+    SmallVector<int64_t> thisSliceStart;
+    if (!tosa::getConstShapeValue(sliceOp.getStart().getDefiningOp(),
+                                  thisSliceStart)) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice start must be a constant shape");
+    }
+    SmallVector<int64_t> newSliceStart;
+    newSliceStart.reserve(precedingSliceStart.size());
+    for (auto [startPreceding, startThis] :
+         llvm::zip_equal(precedingSliceStart, thisSliceStart)) {
+      newSliceStart.push_back(startPreceding + startThis);
+    }
+    Value newStartConst = getTosaConstShape(
+        rewriter,
+        rewriter.getFusedLoc({sliceOp.getStart().getLoc(),
+                              precedingSliceOp.getStart().getLoc()}),
+        newSliceStart);
+    rewriter.modifyOpInPlace(sliceOp, [&]() {
+      sliceOp.getInput1Mutable().assign(precedingSliceOp.getInput1());
+      sliceOp.getStartMutable().assign(newStartConst);
+      sliceOp->setLoc(rewriter.getFusedLoc(
+          {precedingSliceOp->getLoc(), sliceOp->getLoc()}));
+    });
+
+    return success();
+  }
+};
+
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<ConcatSliceOptimization>(context);
   results.add<TileSliceOptimization>(context);
+  results.add<SliceSliceOptimization>(context);
 }

 struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
@@ -1525,30 +1594,6 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
 }

 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
-  const auto tryFoldWithPrecedingSlice = [this](FoldAdaptor adaptor) {
-    auto precedingSliceOp = getInput1().getDefiningOp<SliceOp>();
-    if (!precedingSliceOp)
-      return failure();
-    const auto precedingSliceStart = precedingSliceOp.getStart();
-    const auto thisSliceStart = getStart();
-    SmallVector<int64_t> newSliceStart;
-    newSliceStart.reserve(precedingSliceStart.size());
-    for (auto [startPreceding, startThis] :
-         llvm::zip_equal(precedingSliceStart, thisSliceStart)) {
-      newSliceStart.push_back(startPreceding + startThis);
-    }
-    setOperand(precedingSliceOp->getOperand(0));
-    setStart(newSliceStart);
-    getOperation()->setLoc(
-        FusedLoc::get(getContext(), {precedingSliceOp->getLoc(), getLoc()}));
-    return success();
-  };
-
-  // First try folding the preceding slice, this also works if the shapes are
-  // dynamic
-  if (succeeded(tryFoldWithPrecedingSlice(adaptor)))
-    return getResult();
-
   auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

@@ -1573,7 +1618,12 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {

   if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
       outputTy.getNumElements() == 1) {
-    llvm::SmallVector<uint64_t> indices(getStart());
+    DenseElementsAttr startElems;
+    if (!matchPattern(getStart(), m_Constant(&startElems)))
+      return {};
+
+    llvm::SmallVector<uint64_t> indices =
+        llvm::to_vector(startElems.getValues<uint64_t>());
     auto value = operand.getValues<Attribute>()[indices];
     return SplatElementsAttr::get(outputTy, value);
   }
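
The removed fold is the same logic that now lives in SliceSliceOptimization above: both add the two start vectors elementwise and re-root the slice on the original input. Since start is now an operand rather than an attribute, updating it requires materializing a new tosa.const_shape op, which a fold cannot do (folders may not create new ops); that is presumably why the rewrite moved to a canonicalization pattern. A hedged IR sketch (illustrative shapes):

// Before: two consecutive slices.
%s0 = tosa.const_shape {value = dense<[1, 0]> : tensor<2xindex>} : !tosa.shape<2>
%z0 = tosa.const_shape {value = dense<[4, 4]> : tensor<2xindex>} : !tosa.shape<2>
%t = tosa.slice %arg0, %s0, %z0 : (tensor<6x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<4x4xf32>
%s1 = tosa.const_shape {value = dense<[2, 1]> : tensor<2xindex>} : !tosa.shape<2>
%z1 = tosa.const_shape {value = dense<[2, 3]> : tensor<2xindex>} : !tosa.shape<2>
%r = tosa.slice %t, %s1, %z1 : (tensor<4x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x3xf32>

// After canonicalization: one slice with start [1+2, 0+1] = [3, 1].
%s = tosa.const_shape {value = dense<[3, 1]> : tensor<2xindex>} : !tosa.shape<2>
%r = tosa.slice %arg0, %s, %z1 : (tensor<6x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x3xf32>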

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 45 additions & 19 deletions

@@ -926,8 +926,18 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  auto start = adaptor.getStart();
-  auto size = adaptor.getSize();
+
+  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+  SmallVector<int64_t> start;
+  SmallVector<int64_t> size;
+
+  if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
+      !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
+    auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  }

   // if size[i] is -1, all remaining elements in dimension i are included
   // in the slice, similar to TF.
@@ -965,8 +975,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(

 LogicalResult tosa::SliceOp::verify() {
   auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+  if (!inputType)
+    return success();
+
+  auto startShapeRank =
+      llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
+  if (inputType.getRank() != startShapeRank)
+    return emitOpError(
+        "length of start attribute is not equal rank of input shape");
+
+  auto sizeShapeRank =
+      llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
+  if (inputType.getRank() != sizeShapeRank)
+    return emitOpError(
+        "length of size attribute is not equal rank of input shape");
+
   auto outputType = llvm::dyn_cast<RankedTensorType>(getType());
-  if (!inputType || !outputType)
+  if (!outputType)
     return success();

   if (inputType.getRank() != outputType.getRank()) {
@@ -975,30 +1000,31 @@ LogicalResult tosa::SliceOp::verify() {
            << ") must match";
   }

-  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
-    return emitOpError(
-        "length of start attribute is not equal rank of input shape");
-
-  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
-    return emitOpError(
-        "length of size attribute is not equal rank of input shape");
+  SmallVector<int64_t> size;
+  if (!tosa::getConstShapeValue(getSize().getDefiningOp(), size)) {
+    return success();
+  }

   for (int64_t dim = 0; dim < outputType.getRank(); ++dim) {
-    if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) &&
-        getSize()[dim] != outputType.getShape()[dim]) {
-      return emitOpError() << "size attribute (" << getSize()[dim]
+    if (size[dim] != -1 && !outputType.isDynamicDim(dim) &&
+        size[dim] != outputType.getShape()[dim]) {
+      return emitOpError() << "size (" << size[dim]
                            << ") does not match output type ("
                            << outputType.getShape()[dim] << ") in dimension "
                            << dim;
     }
   }

+  SmallVector<int64_t> start;
+  if (!tosa::getConstShapeValue(getStart().getDefiningOp(), start)) {
+    return success();
+  }
+
   for (int i = 0; i < inputType.getRank(); ++i) {
-    if (getSize()[i] != -1 && !inputType.isDynamicDim(i) &&
-        getStart()[i] + getSize()[i] > inputType.getShape()[i]) {
-      return emitOpError() << "start (" << getStart()[i] << ") plus size ("
-                           << getSize()[i]
-                           << ") goes out of bounds of input size ("
+    if (size[i] != -1 && !inputType.isDynamicDim(i) &&
+        start[i] + size[i] > inputType.getShape()[i]) {
+      return emitOpError() << "start (" << start[i] << ") plus size ("
+                           << size[i] << ") goes out of bounds of input size ("
                            << inputType.getShape()[i] << ") in dimension " << i;
     }
   }
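
Because the rank of a !tosa.shape type is known statically, the verifier can now check start/size rank against the input even when their values are not compile-time constants; the value-dependent checks simply bail out when getConstShapeValue fails. A hedged example of IR the new rank check rejects (illustrative shapes):

// Rank-2 input with rank-1 start/size fails with
// "length of start attribute is not equal rank of input shape".
%start = tosa.const_shape {value = dense<[0]> : tensor<1xindex>} : !tosa.shape<1>
%size = tosa.const_shape {value = dense<[2]> : tensor<1xindex>} : !tosa.shape<1>
%0 = tosa.slice %arg0, %start, %size : (tensor<4x8xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2x8xf32>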
@@ -1265,7 +1291,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     return emitOpError() << "cannot reshape " << inputElementsNum
                          << " elements into " << outputElementsNum;
   }
-
+
   if ((int64_t)getNewShape().size() != outputType.getRank()) {
     return emitOpError()
            << "rank of newShape (" << getNewShape().size()

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 2 additions & 2 deletions

@@ -302,8 +302,8 @@ class TransposeConvStridedConverter

     auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                      rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
-                     rewriter.getDenseI64ArrayAttr(sliceBegin),
-                     rewriter.getDenseI64ArrayAttr(sliceSize))
+                     getTosaConstShape(rewriter, loc, sliceBegin),
+                     getTosaConstShape(rewriter, loc, sliceSize))
                      .getResult();

     llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
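
As elsewhere in this commit, getTosaConstShape materializes an integer vector as a tosa.const_shape op feeding the new slice. For a hypothetical sliceBegin of [0, 2, 2, 0], the emitted op would look roughly like:

%begin = tosa.const_shape {value = dense<[0, 2, 2, 0]> : tensor<4xindex>} : !tosa.shape<4>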
