[mlir][tosa] Change 'shape' of RESHAPE from attribute to input shape … #125789
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-linalg

Author: TatWai Chong (tatwaichong)

Changes: The shape operand is changed from an attribute to an input of shape type, as required since TOSA v1.0.

Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110

Patch is 115.81 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/125789.diff

25 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8ede271cc56a8a..869ab913a715ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1621,7 +1621,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
let arguments = (ins
Tosa_Tensor:$input1,
- DenseI64ArrayAttr:$new_shape
+ Tosa_Shape:$shape
);
let results = (outs
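With the definition above, tosa.reshape now takes its target shape as a second operand of Tosa_Shape type instead of a new_shape attribute. A minimal sketch of the new builder-level call, using the patch's own helpers; the wrapper function itself is illustrative, not part of the patch:

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;

// Build a tosa.reshape whose target shape is materialized as a
// tosa.const_shape value rather than attached as an attribute.
static Value buildReshape(PatternRewriter &rewriter, Location loc,
                          Value input, RankedTensorType resultTy) {
  // convertFromMlirShape maps dynamic dims (ShapedType::kDynamic) to -1.
  Value shape = tosa::getTosaConstShape(
      rewriter, loc, tosa::convertFromMlirShape(resultTy.getShape()));
  return rewriter.create<tosa::ReshapeOp>(loc, resultTy, input, shape);
}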
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 78a8828855437e..88c21629286525 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -230,8 +230,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
}
// Computes shape value using tosa const_shape op.
+Value getTosaConstShape(ImplicitLocOpBuilder &builder,
+ llvm::ArrayRef<int64_t> shape);
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape);
+
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
bool getConstShapeValue(Operation *op,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..edb04010d53fd9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1952,9 +1952,10 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
});
+ auto shapeValue = getTosaConstShape(
+ rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultTy, genericOp.getResult(0),
- rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+ op, resultTy, genericOp.getResult(0), shapeValue);
return success();
}
};
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..fdb8b1e1471a73 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -235,7 +236,12 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
return rewriter.notifyMatchFailure(reshape.getLoc(),
"expected input type to be tensor");
}
- auto newShape = reshape.getNewShape();
+
+ llvm::SmallVector<int64_t> newShape;
+ if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
+ newShape)) {
+ return failure();
+ }
// Infer all intermediate types
auto inputType = inferReshapeInputType(input, newShape);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..229719f5ef84d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, op.getType(), op.getInput1(),
- rewriter.getDenseI64ArrayAttr(newShape));
+ getTosaConstShape(rewriter, op.getLoc(), newShape));
return success();
}
};
@@ -948,8 +948,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (!getInput1().hasOneUse())
return {};
+ llvm::SmallVector<int64_t> shapeVec;
+ if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
+ return {};
+
return operand.reshape(
- llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+ llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
}
return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..f88c6df8e2b458 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1309,8 +1309,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
- llvm::SmallVector<int64_t> newShapeValue =
- convertToMlirShape(adaptor.getNewShape());
+ llvm::SmallVector<int64_t> newShapeValue;
+ if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
+ newShapeValue)) {
+ auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
+ SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+ inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+ return success();
+ } else {
+ newShapeValue = convertToMlirShape(newShapeValue);
+ }
// We cannot infer from the total number of elements so we must take the
// shape attribute as exact.
@@ -1346,13 +1354,19 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();
- if ((int64_t)getNewShape().size() != outputType.getRank())
+ SmallVector<int64_t> shapeValues;
+ if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
+ // skip following checks if shape is not constant
+ return mlir::success();
+ }
+
+ if ((int64_t)shapeValues.size() != outputType.getRank())
return emitOpError() << "new shape does not match result rank";
for (auto [newShapeDim, outputShapeDim] :
- zip(getNewShape(), outputType.getShape())) {
- if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
- newShapeDim != outputShapeDim)
+ zip(shapeValues, outputType.getShape())) {
+ if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
+ outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
return emitOpError() << "new shape is inconsistent with result shape";
if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
@@ -1371,10 +1385,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
}
int64_t newShapeElementsNum = std::accumulate(
- getNewShape().begin(), getNewShape().end(), 1LL,
+ shapeValues.begin(), shapeValues.end(), 1LL,
[](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
bool isStaticNewShape =
- llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+ llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
return emitOpError() << "cannot reshape " << inputElementsNum
@@ -1382,7 +1396,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
}
}
- int missingDims = llvm::count(getNewShape(), -1);
+ int missingDims = llvm::count(shapeValues, -1);
if (missingDims > 1)
return emitOpError() << "expected at most one target dimension to be -1";
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae3330afe..04e8ad31cf2e2e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -20,12 +20,6 @@ using namespace mlir::tosa;
namespace {
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
- return to_vector(llvm::map_range(shape, [](int64_t dim) {
- return ShapedType::isDynamic(dim) ? -1 : dim;
- }));
-}
-
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
explicit Conv2DIsFullyConnected(MLIRContext *context)
: OpRewritePattern(context) {}
@@ -98,12 +92,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
auto revisedInputShapeType =
RankedTensorType::get(revisedInputShape, inputType.getElementType());
- auto reshapedInput = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedInputShapeType, input,
- rewriter.getDenseI64ArrayAttr(
- convertFromMlirShape(revisedInputShape)))
- .getResult();
+ auto revisedInputShapeValue = getTosaConstShape(
+ rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape));
+ auto reshapedInput =
+ rewriter
+ .create<tosa::ReshapeOp>(op.getLoc(), revisedInputShapeType, input,
+ revisedInputShapeValue)
+ .getResult();
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
@@ -111,12 +106,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
dyn_cast<RankedTensorType>(weight.getType()).getElementType());
- auto reshapedWeight = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedWeightShapeType, weight,
- rewriter.getDenseI64ArrayAttr(
- convertFromMlirShape(revisedWeightShape)))
- .getResult();
+ auto revisedWeightShapeValue = getTosaConstShape(
+ rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape));
+ auto reshapedWeight =
+ rewriter
+ .create<tosa::ReshapeOp>(op.getLoc(), revisedWeightShapeType,
+ weight, revisedWeightShapeValue)
+ .getResult();
// Perform a fully connected network over the reshaped input and weight.
llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
@@ -149,9 +145,10 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
// Reshape output to [N, IH, IW, OC].
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
inputShape[2], weightShape[0]};
+ auto outputShapeValue = getTosaConstShape(
+ rewriter, op.getLoc(), convertFromMlirShape(outputShape));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultType, fullyConnectedValue,
- rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
+ op, resultType, fullyConnectedValue, outputShapeValue);
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index ee857f1998a54d..b26397d0e3ed7a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -55,10 +55,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
inputType = RankedTensorType::get(
revisedInputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
+ auto revisedInputShapeValue =
+ getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
input = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), inputType, input,
- rewriter.getDenseI64ArrayAttr(revisedInputShape))
+ .create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
+ revisedInputShapeValue)
.getResult();
Type inputETy = inputType.getElementType();
@@ -153,9 +154,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
auto outputShapeType = RankedTensorType::get(
outputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
+ auto outputShapeValue =
+ getTosaConstShape(rewriter, op->getLoc(), outputShape);
Value outputValue = rewriter.create<tosa::ReshapeOp>(
- op.getLoc(), outputShapeType, mulValue,
- rewriter.getDenseI64ArrayAttr(outputShape));
+ op.getLoc(), outputShapeType, mulValue, outputShapeValue);
Value bias = op.getBias();
if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae224671e304f2..69a66c98307e94 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -160,9 +160,11 @@ class TransposeConvStridedConverter
outputChannels, weightHeight / stride[0],
stride[0], weightWidth / stride[1],
stride[1], inputChannels};
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
- rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
+ builder, UnrankedTensorType::get(weightETy), weight,
+ getTosaConstShape(rewriter, loc, weightReshapeDims0));
// Transpose the factored-out stride to the output channels.
Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
@@ -174,12 +176,13 @@ class TransposeConvStridedConverter
transposeWeightVal);
// Collapse the strides and output channels into a single dimension.
- llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
+ llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
weightWidth / stride[1], inputChannels};
+
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
+ getTosaConstShape(rewriter, loc, weightReshapeDims1));
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
weight = CreateOpAndInferShape<tosa::ReverseOp>(
@@ -258,9 +261,13 @@ class TransposeConvStridedConverter
// Factor striding out of the convolution result.
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+
+ auto convReshapeDims0Value =
+ getTosaConstShape(rewriter, loc, convReshapeDims0);
+
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
- rewriter.getDenseI64ArrayAttr(convReshapeDims0));
+ convReshapeDims0Value);
// Transpose the factored-out stride to the output channels.
Value transposeConvVal = rewriter.create<tosa::ConstOp>(
@@ -274,9 +281,13 @@ class TransposeConvStridedConverter
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+
+ auto convReshapeDims1Value =
+ getTosaConstShape(rewriter, loc, convReshapeDims1);
+
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
- rewriter.getDenseI64ArrayAttr(convReshapeDims1));
+ convReshapeDims1Value);
// Determine the amount to slice / pad from the result start.
int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 520f283a3ba888..281f0529a5c081 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -402,13 +402,20 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
return std::nullopt;
// Do not insert a TransposeOp, instead we fold the reshape and its attribute.
+ llvm::SmallVector<int64_t> newShape;
+ if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
+ newShape)) {
+ // this mean shape is not constant
+ return std::nullopt;
+ }
+ ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
auto foldedReshape = rewriter.create<ReshapeOp>(
reshapeOp.getLoc(),
RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
reshapeOutputType.getElementType()),
reshapeOp.getInput1(),
- rewriter.getDenseI64ArrayAttr(
- applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
+ getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
+ hoistedPerms)));
return foldedReshape->getResult(0);
}
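The target shape values are permuted the same way the old new_shape attribute was. Since applyTOSAPermutation's full signature is truncated above, the sketch below reimplements the mapping locally, under the assumption that out[i] = shape[perms[i]]:

#include <cstdint>
#include <vector>

// Apply perms to shape: out[i] = shape[perms[i]], mirroring how the
// folded reshape's target shape is permuted above.
static std::vector<int64_t> permuteShape(const std::vector<int64_t> &shape,
                                         const std::vector<int32_t> &perms) {
  std::vector<int64_t> out;
  out.reserve(perms.size());
  for (int32_t p : perms)
    out.push_back(shape[p]);
  return out;
}

int main() {
  // Example: permuting shape [1, 2, 3] by perms {2, 0, 1} yields [3, 1, 2].
  auto out = permuteShape({1, 2, 3}, {2, 0, 1});
  return (out == std::vector<int64_t>{3, 1, 2}) ? 0 : 1;
}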
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 62b0bc1857e395..8ab12d038849f4 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -145,10 +145,10 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
llvm::cast<RankedTensorType>(lowerTensorValue.getType());
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+ auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape);
auto reshapeLower = builder.create<tosa::ReshapeOp>(
- reshapeOutputType, lowerTensorValue,
- builder.getDenseI64ArrayAttr(reshapeOutputShape));
+ reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue);
if (input1Rank > input2Rank) {
input1 = higherTensorValue;
@@ -161,15 +161,20 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
return success();
}
-Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
llvm::ArrayRef<int64_t> shape) {
- auto attr = rewriter.getIndexTensorAttr(shape);
- auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
- mlir::Operation *mlir_op =
- rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+ auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+ auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size());
+ mlir::Operation *mlir_op = builder.create<tosa::ConstShapeOp>(type, attr);
return mlir_op->getResult(0);
}
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+ llvm::ArrayRef<int64_t> shape) {
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ return getTosaConstShape(builder, shape);
+}
+
SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
return to_vector(llvm::map_range(shape, [](int64_t dim) {
return ShapedType::isDynamic(dim) ? -1 : dim;
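Callers can now pick either overload; a short usage sketch (the enclosing function is illustrative):

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

using namespace mlir;

static void buildShapes(PatternRewriter &rewriter, Location loc) {
  // Explicit-location overload, as used by the rewrite patterns above.
  Value s0 = tosa::getTosaConstShape(rewriter, loc, {2, 3});
  // ImplicitLocOpBuilder overload, convenient in utilities such as
  // EqualizeRanks that already carry a location.
  ImplicitLocOpBuilder builder(loc, rewriter);
  Value s1 = tosa::getTosaConstShape(builder, {2, 3});
  (void)s0;
  (void)s1;
}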
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 75b48f2b06d899..460e207d62de6a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -24,7 +24,8 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
%reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
%1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
- %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<...
[truncated]
…type

Co-authored-by: TatWai Chong <[email protected]>
Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110
Force-pushed 7e89690 to 8aaa01d
GeorgeARM left a comment:

Minor comments.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
    batch, convHeight * stride[0], convWidth * stride[1], outputChannels};

auto convReshapeDims1Value =
Do you need the new variable? You can put it directly in the function below?
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
    batch, convHeight, convWidth, stride[0], stride[1], outputChannels};

auto convReshapeDims0Value =
Directly to the function below?
auto outputShapeType = RankedTensorType::get(
    outputShape,
    dyn_cast<RankedTensorType>(input.getType()).getElementType());
auto outputShapeValue =
directly in the function below?
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
llvm::SmallVector<int64_t> newShapeValue =
    convertToMlirShape(adaptor.getNewShape());
llvm::SmallVector<int64_t> newShapeValue;
I find this two-stage setup for the shape obscure. I would prefer an optional, but that's probably worth a separate patch.
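For reference, the optional-returning variant suggested here could be a thin wrapper over the existing helper; getConstShapeValues is a hypothetical name, not part of this patch:

#include <cstdint>
#include <optional>

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical wrapper: returns std::nullopt when the shape operand is
// not produced by a constant, and the extracted values otherwise.
static std::optional<llvm::SmallVector<int64_t>>
getConstShapeValues(mlir::Operation *op) {
  llvm::SmallVector<int64_t> values;
  if (!op || !mlir::tosa::getConstShapeValue(op, values))
    return std::nullopt;
  return values;
}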
…t shape … (llvm#125789)"

This reverts commit 571a987.

We had previously cherry-picked llvm/llvm-project@73f11ac in #19939. Now we're integrating up to that commit, so it's no longer a cherry-pick. Reverting llvm/llvm-project#125789 because it breaks TorchToTosa in torch-mlir. We will need to wait for this to be resolved in torch-mlir, then simultaneously bump torch-mlir and drop the revert. Cherry-pick a Bazel fix: llvm/llvm-project@4df287a

---------
Signed-off-by: Benoit Jacob <[email protected]>
…t shape … (llvm#125789)"

This reverts commit 571a987.

Carrying the existing revert of llvm/llvm-project#125789 because it breaks TorchToTosa in torch-mlir. We will need to wait for this to be resolved in torch-mlir, then simultaneously bump torch-mlir and drop the revert.

Signed-off-by: Benoit Jacob <[email protected]>
…t shape … (llvm#125789)"

This reverts commit 571a987.

Integrate at llvm/llvm-project@001ba42f. Carrying the existing revert of llvm/llvm-project#125789 because it breaks TorchToTosa in torch-mlir. We will need to wait for this to be resolved in torch-mlir, then simultaneously bump torch-mlir and drop the revert.

Signed-off-by: Benoit Jacob <[email protected]>
llvm#125789)

The shape operand is changed from an attribute to an input of shape type, as required since TOSA v1.0.

Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110
Co-authored-by: Won Jeon <[email protected]>