Skip to content

Commit d70e7b1

Browse files
justin-ngo-armvivekkhandelwal1
authored andcommitted
[TOSA] Update tosa.slice's start and size to !tosa.shape type
* In TOSA 1.0, tosa.slice's `start` and `size` are !tosa.shape types. Update tosa.slice in Torch to TOSA in alignment with that. * Update LIT tests. Signed-off-by: Justin Ngo <[email protected]> Change-Id: Icf878ea4dc43ec1af3bd498b5ae96f514fe0f04a
1 parent 06899a8 commit d70e7b1

File tree

2 files changed

+109
-72
lines changed

2 files changed

+109
-72
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4010,8 +4010,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
40104010

40114011
rewriter.replaceOpWithNewOp<tosa::SliceOp>(
40124012
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
4013-
rewriter.getDenseI64ArrayAttr(startSlice),
4014-
rewriter.getDenseI64ArrayAttr(sizeSlice));
4013+
tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice),
4014+
tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice));
40154015

40164016
return success();
40174017
}
@@ -7143,8 +7143,8 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
71437143
startSlice[targetDim1] = std::abs(offset);
71447144
diagonalTensor = rewriter.create<tosa::SliceOp>(
71457145
op->getLoc(), transposedInputType, diagonalTensor,
7146-
rewriter.getDenseI64ArrayAttr(startSlice),
7147-
rewriter.getDenseI64ArrayAttr(sizeSlice));
7146+
tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice),
7147+
tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice));
71487148
}
71497149

71507150
// Apply Reduce Sum to get the result
@@ -7669,8 +7669,8 @@ Value reflectionPadAlongAxis(Value input, ArrayRef<int64_t> unpaddedShape,
76697669
auto leftPadType = RankedTensorType::get(leftPadShape, inputElemTy);
76707670

76717671
auto leftPadSlice = rewriter.create<tosa::SliceOp>(
7672-
loc, leftPadType, input, rewriter.getDenseI64ArrayAttr(leftStartSlice),
7673-
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
7672+
loc, leftPadType, input, tosa::getTosaConstShape(rewriter, loc, leftStartSlice),
7673+
tosa::getTosaConstShape(rewriter, loc, leftSizeSlice));
76747674

76757675
auto leftPad = rewriter.create<tosa::ReverseOp>(
76767676
loc, leftPadType, leftPadSlice.getResult(), static_cast<int32_t>(axis));
@@ -7702,8 +7702,8 @@ Value reflectionPadAlongAxis(Value input, ArrayRef<int64_t> unpaddedShape,
77027702

77037703
auto rightPadSlice = rewriter.create<tosa::SliceOp>(
77047704
loc, rightPadType, input,
7705-
rewriter.getDenseI64ArrayAttr(rightStartSlice),
7706-
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
7705+
tosa::getTosaConstShape(rewriter, loc, rightStartSlice),
7706+
tosa::getTosaConstShape(rewriter, loc, rightSizeSlice));
77077707

77087708
auto rightPad = rewriter.create<tosa::ReverseOp>(
77097709
loc, rightPadType, rightPadSlice.getResult(),
@@ -7949,8 +7949,8 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(
79497949

79507950
auto leftPadSlice = rewriter.create<tosa::SliceOp>(
79517951
op->getLoc(), leftPadSliceType, self,
7952-
rewriter.getDenseI64ArrayAttr(leftStartSlice),
7953-
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
7952+
tosa::getTosaConstShape(rewriter, op->getLoc(), leftStartSlice),
7953+
tosa::getTosaConstShape(rewriter, op->getLoc(), leftSizeSlice));
79547954

79557955
for (int64_t i = 0; i < paddingLeft; i++)
79567956
sideTensors.push_back(leftPadSlice.getResult());
@@ -7974,8 +7974,8 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(
79747974

79757975
auto rightPadSlice = rewriter.create<tosa::SliceOp>(
79767976
op->getLoc(), rightPadSliceType, self,
7977-
rewriter.getDenseI64ArrayAttr(rightStartSlice),
7978-
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
7977+
tosa::getTosaConstShape(rewriter, op->getLoc(), rightStartSlice),
7978+
tosa::getTosaConstShape(rewriter, op->getLoc(), rightSizeSlice));
79797979

79807980
for (int64_t i = 0; i < paddingRight; i++)
79817981
sideTensors.push_back(rightPadSlice.getResult());
@@ -8009,8 +8009,8 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(
80098009

80108010
auto topPadSlice = rewriter.create<tosa::SliceOp>(
80118011
op->getLoc(), topPadSliceType, selfSidePadded,
8012-
rewriter.getDenseI64ArrayAttr(topStartSlice),
8013-
rewriter.getDenseI64ArrayAttr(topSizeSlice));
8012+
tosa::getTosaConstShape(rewriter, op->getLoc(), topStartSlice),
8013+
tosa::getTosaConstShape(rewriter, op->getLoc(), topSizeSlice));
80148014

80158015
for (int64_t i = 0; i < paddingTop; i++)
80168016
resultTensors.push_back(topPadSlice.getResult());
@@ -8037,8 +8037,8 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(
80378037

80388038
auto bottomPadSlice = rewriter.create<tosa::SliceOp>(
80398039
op->getLoc(), bottomPadSliceType, selfSidePadded,
8040-
rewriter.getDenseI64ArrayAttr(bottomStartSlice),
8041-
rewriter.getDenseI64ArrayAttr(bottomSizeSlice));
8040+
tosa::getTosaConstShape(rewriter, op->getLoc(), bottomStartSlice),
8041+
tosa::getTosaConstShape(rewriter, op->getLoc(), bottomSizeSlice));
80428042

80438043
for (int64_t i = 0; i < paddingBottom; i++)
80448044
resultTensors.push_back(bottomPadSlice.getResult());

0 commit comments

Comments
 (0)