2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1625,7 +1625,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {

let arguments = (ins
Tosa_Tensor:$input1,
DenseI64ArrayAttr:$new_shape
Tosa_Shape:$shape
);

let results = (outs
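The TableGen change above replaces the DenseI64ArrayAttr new_shape attribute with a !tosa.shape operand, so every builder of tosa.reshape must now materialize the shape as an SSA value. A minimal sketch of what that looks like in a rewrite pattern, assuming the getTosaConstShape and convertFromMlirShape helpers added later in this patch (the function name buildReshape and the locals input/resultTy are illustrative only):

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"

static mlir::Value buildReshape(mlir::PatternRewriter &rewriter,
                                mlir::Location loc, mlir::Value input,
                                mlir::RankedTensorType resultTy) {
  // Materialize the target shape as a !tosa.shape value; dynamic dims are
  // encoded as -1 inside the tosa.const_shape.
  mlir::Value shape = mlir::tosa::getTosaConstShape(
      rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
  // The shape is now an operand rather than a DenseI64ArrayAttr attribute.
  return rewriter
      .create<mlir::tosa::ReshapeOp>(loc, resultTy, input, shape)
      .getResult();
}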
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -230,8 +230,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
}

// Computes shape value using tosa const_shape op.
Value getTosaConstShape(ImplicitLocOpBuilder &builder,
llvm::ArrayRef<int64_t> shape);
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape);

SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);

bool getConstShapeValue(Operation *op,
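These declarations pair the writer side (getTosaConstShape, now with both PatternRewriter and ImplicitLocOpBuilder overloads) with the reader getConstShapeValue. A hedged sketch of the read-back direction, assuming getConstShapeValue takes an output SmallVector as its second parameter, as its callers elsewhere in this patch suggest (getConstReshapeDims is an illustrative name):

#include <optional>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

std::optional<llvm::SmallVector<int64_t>>
getConstReshapeDims(mlir::tosa::ReshapeOp reshape) {
  llvm::SmallVector<int64_t> dims;
  // Succeeds only when the shape operand is produced by a constant op
  // (e.g. tosa.const_shape); otherwise callers must bail out or fall back
  // to dynamic dimensions, as the verifier and shape inference below do.
  if (!mlir::tosa::getConstShapeValue(reshape.getShape().getDefiningOp(), dims))
    return std::nullopt;
  return dims;
}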
5 changes: 3 additions & 2 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1954,9 +1954,10 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
});

auto shapeValue = getTosaConstShape(
rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, resultTy, genericOp.getResult(0),
rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
op, resultTy, genericOp.getResult(0), shapeValue);
return success();
}
};
8 changes: 7 additions & 1 deletion mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

@@ -235,7 +236,12 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
return rewriter.notifyMatchFailure(reshape.getLoc(),
"expected input type to be tensor");
}
auto newShape = reshape.getNewShape();

llvm::SmallVector<int64_t> newShape;
if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
newShape)) {
return failure();
}

// Infer all intermediate types
auto inputType = inferReshapeInputType(input, newShape);
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {

rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, op.getType(), op.getInput1(),
rewriter.getDenseI64ArrayAttr(newShape));
getTosaConstShape(rewriter, op.getLoc(), newShape));
return success();
}
};
@@ -948,8 +948,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (!getInput1().hasOneUse())
return {};

llvm::SmallVector<int64_t> shapeVec;
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
return {};

return operand.reshape(
llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
}

return {};
32 changes: 23 additions & 9 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1335,8 +1335,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
llvm::SmallVector<int64_t> newShapeValue =
convertToMlirShape(adaptor.getNewShape());
llvm::SmallVector<int64_t> newShapeValue;
[Review comment (Contributor)] I find this two-stage setup for the shape obscure. Would prefer an optional but probably worth a separate patch.
if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
newShapeValue)) {
auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
return success();
} else {
newShapeValue = convertToMlirShape(newShapeValue);
}

// We cannot infer from the total number of elements so we must take the
// shape attribute as exact.
@@ -1372,13 +1380,19 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();

if ((int64_t)getNewShape().size() != outputType.getRank())
SmallVector<int64_t> shapeValues;
if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
// skip following checks if shape is not constant
return mlir::success();
}

if ((int64_t)shapeValues.size() != outputType.getRank())
return emitOpError() << "new shape does not match result rank";

for (auto [newShapeDim, outputShapeDim] :
zip(getNewShape(), outputType.getShape())) {
if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
newShapeDim != outputShapeDim)
zip(shapeValues, outputType.getShape())) {
if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
return emitOpError() << "new shape is inconsistent with result shape";

if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
@@ -1397,18 +1411,18 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
}

int64_t newShapeElementsNum = std::accumulate(
getNewShape().begin(), getNewShape().end(), 1LL,
shapeValues.begin(), shapeValues.end(), 1LL,
[](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
bool isStaticNewShape =
llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
return emitOpError() << "cannot reshape " << inputElementsNum
<< " elements into " << newShapeElementsNum;
}
}

int missingDims = llvm::count(getNewShape(), -1);
int missingDims = llvm::count(shapeValues, -1);
if (missingDims > 1)
return emitOpError() << "expected at most one target dimension to be -1";

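The verifier now pulls the dimensions out of the constant shape operand before running the same checks as before: when every dim is positive, the element counts must match exactly, and at most one -1 placeholder is allowed, whose value is implied by the remaining dims. A standalone sketch of that arithmetic (plain C++, not the TOSA API), e.g. 60 input elements reshaped to {3, -1, 5} implies 4 for the placeholder:

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Product of the known (positive) dims, as in the verifier's accumulate call.
int64_t inferMissingDim(int64_t inputElements,
                        const std::vector<int64_t> &newShape) {
  int64_t known = std::accumulate(
      newShape.begin(), newShape.end(), int64_t{1},
      [](int64_t acc, int64_t d) { return d > 0 ? acc * d : acc; });
  // A single -1 is filled so that the total element count is preserved.
  return inputElements / known;
}

int main() {
  assert(inferMissingDim(60, {3, -1, 5}) == 4);
  return 0;
}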
37 changes: 17 additions & 20 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -20,12 +20,6 @@ using namespace mlir::tosa;

namespace {

SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
return to_vector(llvm::map_range(shape, [](int64_t dim) {
return ShapedType::isDynamic(dim) ? -1 : dim;
}));
}

struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
explicit Conv2DIsFullyConnected(MLIRContext *context)
: OpRewritePattern(context) {}
@@ -98,25 +92,27 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
auto revisedInputShapeType =
RankedTensorType::get(revisedInputShape, inputType.getElementType());
auto reshapedInput = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedInputShapeType, input,
rewriter.getDenseI64ArrayAttr(
convertFromMlirShape(revisedInputShape)))
.getResult();
auto revisedInputShapeValue = getTosaConstShape(
rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape));
auto reshapedInput =
rewriter
.create<tosa::ReshapeOp>(op.getLoc(), revisedInputShapeType, input,
revisedInputShapeValue)
.getResult();

// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
weightShape[3]};
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
dyn_cast<RankedTensorType>(weight.getType()).getElementType());
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,
rewriter.getDenseI64ArrayAttr(
convertFromMlirShape(revisedWeightShape)))
.getResult();
auto revisedWeightShapeValue = getTosaConstShape(
rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape));
auto reshapedWeight =
rewriter
.create<tosa::ReshapeOp>(op.getLoc(), revisedWeightShapeType,
weight, revisedWeightShapeValue)
.getResult();

// Perform a fully connected network over the reshaped input and weight.
llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
@@ -149,9 +145,10 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
// Reshape output to [N, IH, IW, OC].
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
inputShape[2], weightShape[0]};
auto outputShapeValue = getTosaConstShape(
rewriter, op.getLoc(), convertFromMlirShape(outputShape));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, resultType, fullyConnectedValue,
rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
op, resultType, fullyConnectedValue, outputShapeValue);
return success();
}
};
12 changes: 7 additions & 5 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -55,10 +55,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
inputType = RankedTensorType::get(
revisedInputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
auto revisedInputShapeValue =
getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
input = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), inputType, input,
rewriter.getDenseI64ArrayAttr(revisedInputShape))
.create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
revisedInputShapeValue)
.getResult();

Type inputETy = inputType.getElementType();
@@ -153,9 +154,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
auto outputShapeType = RankedTensorType::get(
outputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
auto outputShapeValue =
getTosaConstShape(rewriter, op->getLoc(), outputShape);
[Review comment (Contributor)] directly in the function below?
Value outputValue = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), outputShapeType, mulValue,
rewriter.getDenseI64ArrayAttr(outputShape));
op.getLoc(), outputShapeType, mulValue, outputShapeValue);

Value bias = op.getBias();
if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
23 changes: 17 additions & 6 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -159,9 +159,11 @@ class TransposeConvStridedConverter
outputChannels, weightHeight / stride[0],
stride[0], weightWidth / stride[1],
stride[1], inputChannels};

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
builder, UnrankedTensorType::get(weightETy), weight,
getTosaConstShape(rewriter, loc, weightReshapeDims0));

// Transpose the factored-out stride to the output channels.
Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
@@ -173,12 +175,13 @@ class TransposeConvStridedConverter
transposeWeightVal);

// Collapse the strides and output channels into a single dimension.
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
weightWidth / stride[1], inputChannels};

weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
getTosaConstShape(rewriter, loc, weightReshapeDims1));
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

weight = CreateOpAndInferShape<tosa::ReverseOp>(
@@ -257,9 +260,13 @@ class TransposeConvStridedConverter
// Factor striding out of the convolution result.
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};

auto convReshapeDims0Value =
getTosaConstShape(rewriter, loc, convReshapeDims0);
[Review comment (Contributor)] Directly to the function below?

conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(convReshapeDims0));
convReshapeDims0Value);

// Transpose the factored-out stride to the output channels.
Value transposeConvVal = rewriter.create<tosa::ConstOp>(
Expand All @@ -273,9 +280,13 @@ class TransposeConvStridedConverter
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};

auto convReshapeDims1Value =
getTosaConstShape(rewriter, loc, convReshapeDims1);
[Review comment (Contributor)] Do you need the new variable? You can put it directly in the function below?

conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(convReshapeDims1));
convReshapeDims1Value);

// Determine the amount to slice / pad from the result start.
int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
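The review comments on this file (and on the depthwise decomposition above) ask whether the intermediate *Value locals are needed; the helper call can indeed be passed straight to CreateOpAndInferShape, exactly as the weight reshapes earlier in this pattern already do. A sketch of that inlined form, reusing the pattern's in-scope values (rewriter, loc, conv2d, resultETy, convReshapeDims1), with behavior expected to be unchanged:

// Inlined variant suggested by the review comments: no convReshapeDims1Value
// local, the const shape is built directly in the argument list.
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
    rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
    getTosaConstShape(rewriter, loc, convReshapeDims1));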
11 changes: 9 additions & 2 deletions mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -402,13 +402,20 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
return std::nullopt;

// Do not insert a TransposeOp, instead we fold the reshape and its attribute.
llvm::SmallVector<int64_t> newShape;
if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
newShape)) {
// this means the shape is not constant
return std::nullopt;
}
ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
auto foldedReshape = rewriter.create<ReshapeOp>(
reshapeOp.getLoc(),
RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
reshapeOutputType.getElementType()),
reshapeOp.getInput1(),
rewriter.getDenseI64ArrayAttr(
applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
hoistedPerms)));
return foldedReshape->getResult(0);
}

19 changes: 12 additions & 7 deletions mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -145,10 +145,10 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
llvm::cast<RankedTensorType>(lowerTensorValue.getType());
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape);

auto reshapeLower = builder.create<tosa::ReshapeOp>(
reshapeOutputType, lowerTensorValue,
builder.getDenseI64ArrayAttr(reshapeOutputShape));
reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue);

if (input1Rank > input2Rank) {
input1 = higherTensorValue;
Expand All @@ -161,15 +161,20 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
return success();
}

Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
llvm::ArrayRef<int64_t> shape) {
auto attr = rewriter.getIndexTensorAttr(shape);
auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
mlir::Operation *mlir_op =
rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size());
mlir::Operation *mlir_op = builder.create<tosa::ConstShapeOp>(type, attr);
return mlir_op->getResult(0);
}

Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape) {
ImplicitLocOpBuilder builder(loc, rewriter);
return getTosaConstShape(builder, shape);
}

SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
return to_vector(llvm::map_range(shape, [](int64_t dim) {
return ShapedType::isDynamic(dim) ? -1 : dim;
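For reference, a tiny sanity sketch of the convertFromMlirShape contract relied on throughout this patch: MLIR's ShapedType::kDynamic sentinel becomes TOSA's -1, and static dims pass through unchanged. Purely illustrative, using only the function declared in ConversionUtils.h:

#include <cassert>

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/BuiltinTypes.h"

void checkConvertFromMlirShape() {
  llvm::SmallVector<int64_t> out = mlir::tosa::convertFromMlirShape(
      {2, mlir::ShapedType::kDynamic, 8});
  // Dynamic dims map to -1; everything else is left untouched.
  assert(out[0] == 2 && out[1] == -1 && out[2] == 8);
}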
@@ -24,7 +24,8 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
%reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
%1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
%2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<*xf32>) -> tensor<10x10xf32>
%s = tosa.const_shape {value = dense<[10, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
%2 = tosa.reshape %0, %s : (tensor<*xf32>, !tosa.shape<2>) -> tensor<10x10xf32>
return %2 : tensor<10x10xf32>
}
