Skip to content

Commit a3cba7e

Browse files
committed
Revert "[mlir][tosa] Change 'shape' of RESHAPE from attribute to input shape … (llvm#125789)"
This reverts commit 571a987.
1 parent 73f11ac commit a3cba7e

25 files changed

+246
-449
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1625,7 +1625,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
16251625

16261626
let arguments = (ins
16271627
Tosa_Tensor:$input1,
1628-
Tosa_Shape:$shape
1628+
DenseI64ArrayAttr:$new_shape
16291629
);
16301630

16311631
let results = (outs

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,11 +230,8 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
230230
}
231231

232232
// Computes shape value using tosa const_shape op.
233-
Value getTosaConstShape(ImplicitLocOpBuilder &builder,
234-
llvm::ArrayRef<int64_t> shape);
235233
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
236234
llvm::ArrayRef<int64_t> shape);
237-
238235
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
239236

240237
bool getConstShapeValue(Operation *op,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,10 +1954,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
19541954
nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
19551955
});
19561956

1957-
auto shapeValue = getTosaConstShape(
1958-
rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
19591957
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1960-
op, resultTy, genericOp.getResult(0), shapeValue);
1958+
op, resultTy, genericOp.getResult(0),
1959+
rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
19611960
return success();
19621961
}
19631962
};

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1717
#include "mlir/Dialect/Tensor/Utils/Utils.h"
1818
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
19-
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
2019
#include "mlir/IR/PatternMatch.h"
2120
#include "mlir/Transforms/DialectConversion.h"
2221

@@ -236,12 +235,7 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
236235
return rewriter.notifyMatchFailure(reshape.getLoc(),
237236
"expected input type to be tensor");
238237
}
239-
240-
llvm::SmallVector<int64_t> newShape;
241-
if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
242-
newShape)) {
243-
return failure();
244-
}
238+
auto newShape = reshape.getNewShape();
245239

246240
// Infer all intermediate types
247241
auto inputType = inferReshapeInputType(input, newShape);

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
180180

181181
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
182182
op, op.getType(), op.getInput1(),
183-
getTosaConstShape(rewriter, op.getLoc(), newShape));
183+
rewriter.getDenseI64ArrayAttr(newShape));
184184
return success();
185185
}
186186
};
@@ -948,12 +948,8 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
948948
if (!getInput1().hasOneUse())
949949
return {};
950950

951-
llvm::SmallVector<int64_t> shapeVec;
952-
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
953-
return {};
954-
955951
return operand.reshape(
956-
llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
952+
llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
957953
}
958954

959955
return {};

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,16 +1335,8 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
13351335
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
13361336
ShapeAdaptor inputShape(adaptor.getInput1().getType());
13371337
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1338-
llvm::SmallVector<int64_t> newShapeValue;
1339-
if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
1340-
newShapeValue)) {
1341-
auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
1342-
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1343-
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1344-
return success();
1345-
} else {
1346-
newShapeValue = convertToMlirShape(newShapeValue);
1347-
}
1338+
llvm::SmallVector<int64_t> newShapeValue =
1339+
convertToMlirShape(adaptor.getNewShape());
13481340

13491341
// We cannot infer from the total number of elements so we must take the
13501342
// shape attribute as exact.
@@ -1380,19 +1372,13 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
13801372
TensorType inputType = getInput1().getType();
13811373
RankedTensorType outputType = getType();
13821374

1383-
SmallVector<int64_t> shapeValues;
1384-
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
1385-
// skip following checks if shape is not constant
1386-
return mlir::success();
1387-
}
1388-
1389-
if ((int64_t)shapeValues.size() != outputType.getRank())
1375+
if ((int64_t)getNewShape().size() != outputType.getRank())
13901376
return emitOpError() << "new shape does not match result rank";
13911377

13921378
for (auto [newShapeDim, outputShapeDim] :
1393-
zip(shapeValues, outputType.getShape())) {
1394-
if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
1395-
outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
1379+
zip(getNewShape(), outputType.getShape())) {
1380+
if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
1381+
newShapeDim != outputShapeDim)
13961382
return emitOpError() << "new shape is inconsistent with result shape";
13971383

13981384
if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
@@ -1411,18 +1397,18 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
14111397
}
14121398

14131399
int64_t newShapeElementsNum = std::accumulate(
1414-
shapeValues.begin(), shapeValues.end(), 1LL,
1400+
getNewShape().begin(), getNewShape().end(), 1LL,
14151401
[](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
14161402
bool isStaticNewShape =
1417-
llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
1403+
llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
14181404
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
14191405
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
14201406
return emitOpError() << "cannot reshape " << inputElementsNum
14211407
<< " elements into " << newShapeElementsNum;
14221408
}
14231409
}
14241410

1425-
int missingDims = llvm::count(shapeValues, -1);
1411+
int missingDims = llvm::count(getNewShape(), -1);
14261412
if (missingDims > 1)
14271413
return emitOpError() << "expected at most one target dimension to be -1";
14281414

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ using namespace mlir::tosa;
2020

2121
namespace {
2222

23+
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
24+
return to_vector(llvm::map_range(shape, [](int64_t dim) {
25+
return ShapedType::isDynamic(dim) ? -1 : dim;
26+
}));
27+
}
28+
2329
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
2430
explicit Conv2DIsFullyConnected(MLIRContext *context)
2531
: OpRewritePattern(context) {}
@@ -92,27 +98,25 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
9298
llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
9399
auto revisedInputShapeType =
94100
RankedTensorType::get(revisedInputShape, inputType.getElementType());
95-
auto revisedInputShapeValue = getTosaConstShape(
96-
rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape));
97-
auto reshapedInput =
98-
rewriter
99-
.create<tosa::ReshapeOp>(op.getLoc(), revisedInputShapeType, input,
100-
revisedInputShapeValue)
101-
.getResult();
101+
auto reshapedInput = rewriter
102+
.create<tosa::ReshapeOp>(
103+
op.getLoc(), revisedInputShapeType, input,
104+
rewriter.getDenseI64ArrayAttr(
105+
convertFromMlirShape(revisedInputShape)))
106+
.getResult();
102107

103108
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
104109
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
105110
weightShape[3]};
106111
auto revisedWeightShapeType = RankedTensorType::get(
107112
revisedWeightShape,
108113
dyn_cast<RankedTensorType>(weight.getType()).getElementType());
109-
auto revisedWeightShapeValue = getTosaConstShape(
110-
rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape));
111-
auto reshapedWeight =
112-
rewriter
113-
.create<tosa::ReshapeOp>(op.getLoc(), revisedWeightShapeType,
114-
weight, revisedWeightShapeValue)
115-
.getResult();
114+
auto reshapedWeight = rewriter
115+
.create<tosa::ReshapeOp>(
116+
op.getLoc(), revisedWeightShapeType, weight,
117+
rewriter.getDenseI64ArrayAttr(
118+
convertFromMlirShape(revisedWeightShape)))
119+
.getResult();
116120

117121
// Perform a fully connected network over the reshaped input and weight.
118122
llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
@@ -145,10 +149,9 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
145149
// Reshape output to [N, IH, IW, OC].
146150
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
147151
inputShape[2], weightShape[0]};
148-
auto outputShapeValue = getTosaConstShape(
149-
rewriter, op.getLoc(), convertFromMlirShape(outputShape));
150152
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
151-
op, resultType, fullyConnectedValue, outputShapeValue);
153+
op, resultType, fullyConnectedValue,
154+
rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
152155
return success();
153156
}
154157
};

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
5555
inputType = RankedTensorType::get(
5656
revisedInputShape,
5757
dyn_cast<RankedTensorType>(input.getType()).getElementType());
58-
auto revisedInputShapeValue =
59-
getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
6058
input = rewriter
61-
.create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
62-
revisedInputShapeValue)
59+
.create<tosa::ReshapeOp>(
60+
op.getLoc(), inputType, input,
61+
rewriter.getDenseI64ArrayAttr(revisedInputShape))
6362
.getResult();
6463

6564
Type inputETy = inputType.getElementType();
@@ -154,10 +153,9 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
154153
auto outputShapeType = RankedTensorType::get(
155154
outputShape,
156155
dyn_cast<RankedTensorType>(input.getType()).getElementType());
157-
auto outputShapeValue =
158-
getTosaConstShape(rewriter, op->getLoc(), outputShape);
159156
Value outputValue = rewriter.create<tosa::ReshapeOp>(
160-
op.getLoc(), outputShapeType, mulValue, outputShapeValue);
157+
op.getLoc(), outputShapeType, mulValue,
158+
rewriter.getDenseI64ArrayAttr(outputShape));
161159

162160
Value bias = op.getBias();
163161
if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,9 @@ class TransposeConvStridedConverter
159159
outputChannels, weightHeight / stride[0],
160160
stride[0], weightWidth / stride[1],
161161
stride[1], inputChannels};
162-
163-
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
164162
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
165-
builder, UnrankedTensorType::get(weightETy), weight,
166-
getTosaConstShape(rewriter, loc, weightReshapeDims0));
163+
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
164+
rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
167165

168166
// Transpose the factored-out stride to the output channels.
169167
Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
@@ -175,13 +173,12 @@ class TransposeConvStridedConverter
175173
transposeWeightVal);
176174

177175
// Collapse the strides and output channels into a single dimension.
178-
llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
176+
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
179177
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
180178
weightWidth / stride[1], inputChannels};
181-
182179
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
183180
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
184-
getTosaConstShape(rewriter, loc, weightReshapeDims1));
181+
rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
185182
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
186183

187184
weight = CreateOpAndInferShape<tosa::ReverseOp>(
@@ -260,13 +257,9 @@ class TransposeConvStridedConverter
260257
// Factor striding out of the convolution result.
261258
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
262259
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
263-
264-
auto convReshapeDims0Value =
265-
getTosaConstShape(rewriter, loc, convReshapeDims0);
266-
267260
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
268261
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
269-
convReshapeDims0Value);
262+
rewriter.getDenseI64ArrayAttr(convReshapeDims0));
270263

271264
// Transpose the factored-out stride to the output channels.
272265
Value transposeConvVal = rewriter.create<tosa::ConstOp>(
@@ -280,13 +273,9 @@ class TransposeConvStridedConverter
280273
// Fuse striding behavior back into width / height.
281274
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
282275
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
283-
284-
auto convReshapeDims1Value =
285-
getTosaConstShape(rewriter, loc, convReshapeDims1);
286-
287276
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
288277
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
289-
convReshapeDims1Value);
278+
rewriter.getDenseI64ArrayAttr(convReshapeDims1));
290279

291280
// Determine the amount to slice / pad from the result start.
292281
int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);

mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -402,20 +402,13 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
402402
return std::nullopt;
403403

404404
// Do not insert a TransposeOp, instead we fold the reshape and its attribute.
405-
llvm::SmallVector<int64_t> newShape;
406-
if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
407-
newShape)) {
408-
// this mean shape is not constant
409-
return std::nullopt;
410-
}
411-
ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
412405
auto foldedReshape = rewriter.create<ReshapeOp>(
413406
reshapeOp.getLoc(),
414407
RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
415408
reshapeOutputType.getElementType()),
416409
reshapeOp.getInput1(),
417-
getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
418-
hoistedPerms)));
410+
rewriter.getDenseI64ArrayAttr(
411+
applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
419412
return foldedReshape->getResult(0);
420413
}
421414

0 commit comments

Comments
 (0)