Skip to content

Commit e0c8eb6

Browse files
tatwaichong
authored and wonjeon committed
[mlir][tosa] Change 'shape' of RESHAPE from attribute to input shape … (llvm#125789)
The shape operand is changed to input shape type since V1.0 Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110 Co-authored-by: Won Jeon <[email protected]>
1 parent a5fb2e9 commit e0c8eb6

25 files changed

+449
-246
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1625,7 +1625,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
16251625

16261626
let arguments = (ins
16271627
Tosa_Tensor:$input1,
1628-
DenseI64ArrayAttr:$new_shape
1628+
Tosa_Shape:$shape
16291629
);
16301630

16311631
let results = (outs

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
230230
}
231231

232232
// Computes shape value using tosa const_shape op.
233+
Value getTosaConstShape(ImplicitLocOpBuilder &builder,
234+
llvm::ArrayRef<int64_t> shape);
233235
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
234236
llvm::ArrayRef<int64_t> shape);
237+
235238
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
236239

237240
bool getConstShapeValue(Operation *op,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,9 +1954,10 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
19541954
nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
19551955
});
19561956

1957+
auto shapeValue = getTosaConstShape(
1958+
rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
19571959
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1958-
op, resultTy, genericOp.getResult(0),
1959-
rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
1960+
op, resultTy, genericOp.getResult(0), shapeValue);
19601961
return success();
19611962
}
19621963
};

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1717
#include "mlir/Dialect/Tensor/Utils/Utils.h"
1818
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
19+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1920
#include "mlir/IR/PatternMatch.h"
2021
#include "mlir/Transforms/DialectConversion.h"
2122

@@ -235,7 +236,12 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
235236
return rewriter.notifyMatchFailure(reshape.getLoc(),
236237
"expected input type to be tensor");
237238
}
238-
auto newShape = reshape.getNewShape();
239+
240+
llvm::SmallVector<int64_t> newShape;
241+
if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
242+
newShape)) {
243+
return failure();
244+
}
239245

240246
// Infer all intermediate types
241247
auto inputType = inferReshapeInputType(input, newShape);

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
180180

181181
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
182182
op, op.getType(), op.getInput1(),
183-
rewriter.getDenseI64ArrayAttr(newShape));
183+
getTosaConstShape(rewriter, op.getLoc(), newShape));
184184
return success();
185185
}
186186
};
@@ -948,8 +948,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
948948
if (!getInput1().hasOneUse())
949949
return {};
950950

951+
llvm::SmallVector<int64_t> shapeVec;
952+
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
953+
return {};
954+
951955
return operand.reshape(
952-
llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
956+
llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
953957
}
954958

955959
return {};

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,8 +1335,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
13351335
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
13361336
ShapeAdaptor inputShape(adaptor.getInput1().getType());
13371337
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1338-
llvm::SmallVector<int64_t> newShapeValue =
1339-
convertToMlirShape(adaptor.getNewShape());
1338+
llvm::SmallVector<int64_t> newShapeValue;
1339+
if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
1340+
newShapeValue)) {
1341+
auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
1342+
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1343+
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1344+
return success();
1345+
} else {
1346+
newShapeValue = convertToMlirShape(newShapeValue);
1347+
}
13401348

13411349
// We cannot infer from the total number of elements so we must take the
13421350
// shape attribute as exact.
@@ -1372,13 +1380,19 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
13721380
TensorType inputType = getInput1().getType();
13731381
RankedTensorType outputType = getType();
13741382

1375-
if ((int64_t)getNewShape().size() != outputType.getRank())
1383+
SmallVector<int64_t> shapeValues;
1384+
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
1385+
// skip following checks if shape is not constant
1386+
return mlir::success();
1387+
}
1388+
1389+
if ((int64_t)shapeValues.size() != outputType.getRank())
13761390
return emitOpError() << "new shape does not match result rank";
13771391

13781392
for (auto [newShapeDim, outputShapeDim] :
1379-
zip(getNewShape(), outputType.getShape())) {
1380-
if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
1381-
newShapeDim != outputShapeDim)
1393+
zip(shapeValues, outputType.getShape())) {
1394+
if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
1395+
outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
13821396
return emitOpError() << "new shape is inconsistent with result shape";
13831397

13841398
if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
@@ -1397,18 +1411,18 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
13971411
}
13981412

13991413
int64_t newShapeElementsNum = std::accumulate(
1400-
getNewShape().begin(), getNewShape().end(), 1LL,
1414+
shapeValues.begin(), shapeValues.end(), 1LL,
14011415
[](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
14021416
bool isStaticNewShape =
1403-
llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
1417+
llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
14041418
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
14051419
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
14061420
return emitOpError() << "cannot reshape " << inputElementsNum
14071421
<< " elements into " << newShapeElementsNum;
14081422
}
14091423
}
14101424

1411-
int missingDims = llvm::count(getNewShape(), -1);
1425+
int missingDims = llvm::count(shapeValues, -1);
14121426
if (missingDims > 1)
14131427
return emitOpError() << "expected at most one target dimension to be -1";
14141428

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,6 @@ using namespace mlir::tosa;
2020

2121
namespace {
2222

23-
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
24-
return to_vector(llvm::map_range(shape, [](int64_t dim) {
25-
return ShapedType::isDynamic(dim) ? -1 : dim;
26-
}));
27-
}
28-
2923
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
3024
explicit Conv2DIsFullyConnected(MLIRContext *context)
3125
: OpRewritePattern(context) {}
@@ -98,25 +92,27 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
9892
llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
9993
auto revisedInputShapeType =
10094
RankedTensorType::get(revisedInputShape, inputType.getElementType());
101-
auto reshapedInput = rewriter
102-
.create<tosa::ReshapeOp>(
103-
op.getLoc(), revisedInputShapeType, input,
104-
rewriter.getDenseI64ArrayAttr(
105-
convertFromMlirShape(revisedInputShape)))
106-
.getResult();
95+
auto revisedInputShapeValue = getTosaConstShape(
96+
rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape));
97+
auto reshapedInput =
98+
rewriter
99+
.create<tosa::ReshapeOp>(op.getLoc(), revisedInputShapeType, input,
100+
revisedInputShapeValue)
101+
.getResult();
107102

108103
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
109104
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
110105
weightShape[3]};
111106
auto revisedWeightShapeType = RankedTensorType::get(
112107
revisedWeightShape,
113108
dyn_cast<RankedTensorType>(weight.getType()).getElementType());
114-
auto reshapedWeight = rewriter
115-
.create<tosa::ReshapeOp>(
116-
op.getLoc(), revisedWeightShapeType, weight,
117-
rewriter.getDenseI64ArrayAttr(
118-
convertFromMlirShape(revisedWeightShape)))
119-
.getResult();
109+
auto revisedWeightShapeValue = getTosaConstShape(
110+
rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape));
111+
auto reshapedWeight =
112+
rewriter
113+
.create<tosa::ReshapeOp>(op.getLoc(), revisedWeightShapeType,
114+
weight, revisedWeightShapeValue)
115+
.getResult();
120116

121117
// Perform a fully connected network over the reshaped input and weight.
122118
llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
@@ -149,9 +145,10 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
149145
// Reshape output to [N, IH, IW, OC].
150146
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
151147
inputShape[2], weightShape[0]};
148+
auto outputShapeValue = getTosaConstShape(
149+
rewriter, op.getLoc(), convertFromMlirShape(outputShape));
152150
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
153-
op, resultType, fullyConnectedValue,
154-
rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
151+
op, resultType, fullyConnectedValue, outputShapeValue);
155152
return success();
156153
}
157154
};

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
5555
inputType = RankedTensorType::get(
5656
revisedInputShape,
5757
dyn_cast<RankedTensorType>(input.getType()).getElementType());
58+
auto revisedInputShapeValue =
59+
getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
5860
input = rewriter
59-
.create<tosa::ReshapeOp>(
60-
op.getLoc(), inputType, input,
61-
rewriter.getDenseI64ArrayAttr(revisedInputShape))
61+
.create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
62+
revisedInputShapeValue)
6263
.getResult();
6364

6465
Type inputETy = inputType.getElementType();
@@ -153,9 +154,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
153154
auto outputShapeType = RankedTensorType::get(
154155
outputShape,
155156
dyn_cast<RankedTensorType>(input.getType()).getElementType());
157+
auto outputShapeValue =
158+
getTosaConstShape(rewriter, op->getLoc(), outputShape);
156159
Value outputValue = rewriter.create<tosa::ReshapeOp>(
157-
op.getLoc(), outputShapeType, mulValue,
158-
rewriter.getDenseI64ArrayAttr(outputShape));
160+
op.getLoc(), outputShapeType, mulValue, outputShapeValue);
159161

160162
Value bias = op.getBias();
161163
if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,11 @@ class TransposeConvStridedConverter
159159
outputChannels, weightHeight / stride[0],
160160
stride[0], weightWidth / stride[1],
161161
stride[1], inputChannels};
162+
163+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
162164
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
163-
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
164-
rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
165+
builder, UnrankedTensorType::get(weightETy), weight,
166+
getTosaConstShape(rewriter, loc, weightReshapeDims0));
165167

166168
// Transpose the factored-out stride to the output channels.
167169
Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
@@ -173,12 +175,13 @@ class TransposeConvStridedConverter
173175
transposeWeightVal);
174176

175177
// Collapse the strides and output channels into a single dimension.
176-
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
178+
llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
177179
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
178180
weightWidth / stride[1], inputChannels};
181+
179182
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
180183
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
181-
rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
184+
getTosaConstShape(rewriter, loc, weightReshapeDims1));
182185
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
183186

184187
weight = CreateOpAndInferShape<tosa::ReverseOp>(
@@ -257,9 +260,13 @@ class TransposeConvStridedConverter
257260
// Factor striding out of the convolution result.
258261
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
259262
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
263+
264+
auto convReshapeDims0Value =
265+
getTosaConstShape(rewriter, loc, convReshapeDims0);
266+
260267
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
261268
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
262-
rewriter.getDenseI64ArrayAttr(convReshapeDims0));
269+
convReshapeDims0Value);
263270

264271
// Transpose the factored-out stride to the output channels.
265272
Value transposeConvVal = rewriter.create<tosa::ConstOp>(
@@ -273,9 +280,13 @@ class TransposeConvStridedConverter
273280
// Fuse striding behavior back into width / height.
274281
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
275282
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
283+
284+
auto convReshapeDims1Value =
285+
getTosaConstShape(rewriter, loc, convReshapeDims1);
286+
276287
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
277288
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
278-
rewriter.getDenseI64ArrayAttr(convReshapeDims1));
289+
convReshapeDims1Value);
279290

280291
// Determine the amount to slice / pad from the result start.
281292
int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);

mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,13 +402,20 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
402402
return std::nullopt;
403403

404404
// Do not insert a TransposeOp, instead we fold the reshape and its attribute.
405+
llvm::SmallVector<int64_t> newShape;
406+
if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
407+
newShape)) {
408+
// this mean shape is not constant
409+
return std::nullopt;
410+
}
411+
ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
405412
auto foldedReshape = rewriter.create<ReshapeOp>(
406413
reshapeOp.getLoc(),
407414
RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
408415
reshapeOutputType.getElementType()),
409416
reshapeOp.getInput1(),
410-
rewriter.getDenseI64ArrayAttr(
411-
applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
417+
getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
418+
hoistedPerms)));
412419
return foldedReshape->getResult(0);
413420
}
414421

0 commit comments

Comments (0)