From 8aaa01d6d8668cae8f6b2ec715c831943c922123 Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Wed, 6 Dec 2023 22:11:25 +0000 Subject: [PATCH] [mlir][tosa] Change 'shape' of RESHAPE from attribute to input shape type Co-authored-by: TatWai Chong Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 +- .../mlir/Dialect/Tosa/Utils/ConversionUtils.h | 3 + .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 5 +- .../Conversion/TosaToTensor/TosaToTensor.cpp | 8 +- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 8 +- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 32 +++-- .../Tosa/Transforms/TosaDecomposeConv2D.cpp | 37 +++--- .../Transforms/TosaDecomposeDepthwise.cpp | 12 +- .../Transforms/TosaDecomposeTransposeConv.cpp | 23 +++- .../Tosa/Transforms/TosaReduceTransposes.cpp | 11 +- .../Dialect/Tosa/Utils/ConversionUtils.cpp | 19 +-- .../TosaToLinalg/tosa-to-linalg-invalid.mlir | 3 +- .../TosaToLinalg/tosa-to-linalg.mlir | 15 ++- .../TosaToTensor/tosa-to-tensor.mlir | 110 ++++++++++++------ mlir/test/Dialect/Tosa/canonicalize.mlir | 57 +++++---- mlir/test/Dialect/Tosa/constant-op-fold.mlir | 3 +- mlir/test/Dialect/Tosa/inlining.mlir | 3 +- mlir/test/Dialect/Tosa/invalid.mlir | 45 +++++-- mlir/test/Dialect/Tosa/level_check.mlir | 3 +- mlir/test/Dialect/Tosa/ops.mlir | 24 ++-- .../Dialect/Tosa/tosa-decompose-conv2d.mlir | 42 ++++--- .../Tosa/tosa-decompose-depthwise.mlir | 44 ++++--- .../Tosa/tosa-decompose-transpose-conv.mlir | 68 ++++++----- mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 41 ++++--- .../Dialect/Tosa/tosa-reduce-transposes.mlir | 77 ++++++------ 25 files changed, 449 insertions(+), 246 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 98bcbca3b02fa..840558a81493f 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1625,7 +1625,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> { let arguments = (ins Tosa_Tensor:$input1, - DenseI64ArrayAttr:$new_shape + Tosa_Shape:$shape ); let results = (outs diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h index 78a8828855437..88c2162928652 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -230,8 +230,11 @@ SmallVector applyTOSAPermutation(ArrayRef input, } // Computes shape value using tosa const_shape op. +Value getTosaConstShape(ImplicitLocOpBuilder &builder, + llvm::ArrayRef shape); Value getTosaConstShape(PatternRewriter &rewriter, Location loc, llvm::ArrayRef shape); + SmallVector convertFromMlirShape(ArrayRef shape); bool getConstShapeValue(Operation *op, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 67218cee518d5..e4f055ea2f5c4 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1954,9 +1954,10 @@ struct TileConverter : public OpConversionPattern { nestedBuilder.create(op.getLoc(), *args.begin()); }); + auto shapeValue = getTosaConstShape( + rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape())); rewriter.replaceOpWithNewOp( - op, resultTy, genericOp.getResult(0), - rewriter.getDenseI64ArrayAttr(resultTy.getShape())); + op, resultTy, genericOp.getResult(0), shapeValue); return success(); } }; diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 2a9b4d111bdfa..7f029d56e2582 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -235,7 +236,12 @@ class ReshapeConverter : public OpConversionPattern { return rewriter.notifyMatchFailure(reshape.getLoc(), "expected input type to be tensor"); } - auto newShape = reshape.getNewShape(); + + llvm::SmallVector newShape; + if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(), + newShape)) { + return failure(); + } // Infer all intermediate types auto inputType = inferReshapeInputType(input, newShape); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 8e22c879753a3..a9a65ac271b3c 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, op.getType(), op.getInput1(), - rewriter.getDenseI64ArrayAttr(newShape)); + getTosaConstShape(rewriter, op.getLoc(), newShape)); return success(); } }; @@ -948,8 +948,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { if (!getInput1().hasOneUse()) return {}; + llvm::SmallVector shapeVec; + if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec)) + return {}; + return operand.reshape( - llvm::cast(operand.getType()).clone(getNewShape())); + llvm::cast(operand.getType()).clone(shapeVec)); } return {}; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 031c279ff09e2..955021abdd67b 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1335,8 +1335,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); - llvm::SmallVector newShapeValue = - convertToMlirShape(adaptor.getNewShape()); + llvm::SmallVector newShapeValue; + if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(), + newShapeValue)) { + auto rank = cast(adaptor.getShape().getType()).getRank(); + SmallVector fallback(rank, ShapedType::kDynamic); + inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType)); + return success(); + } else { + newShapeValue = convertToMlirShape(newShapeValue); + } // We cannot infer from the total number of elements so we must take the // shape attribute as exact. @@ -1372,13 +1380,19 @@ llvm::LogicalResult tosa::ReshapeOp::verify() { TensorType inputType = getInput1().getType(); RankedTensorType outputType = getType(); - if ((int64_t)getNewShape().size() != outputType.getRank()) + SmallVector shapeValues; + if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) { + // skip following checks if shape is not constant + return mlir::success(); + } + + if ((int64_t)shapeValues.size() != outputType.getRank()) return emitOpError() << "new shape does not match result rank"; for (auto [newShapeDim, outputShapeDim] : - zip(getNewShape(), outputType.getShape())) { - if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic && - newShapeDim != outputShapeDim) + zip(shapeValues, outputType.getShape())) { + if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic && + outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim) return emitOpError() << "new shape is inconsistent with result shape"; if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1) @@ -1397,10 +1411,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() { } int64_t newShapeElementsNum = std::accumulate( - getNewShape().begin(), getNewShape().end(), 1LL, + shapeValues.begin(), shapeValues.end(), 1LL, [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; }); bool isStaticNewShape = - llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; }); + llvm::all_of(shapeValues, [](int64_t s) { return s > 0; }); if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) || (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) { return emitOpError() << "cannot reshape " << inputElementsNum @@ -1408,7 +1422,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() { } } - int missingDims = llvm::count(getNewShape(), -1); + int missingDims = llvm::count(shapeValues, -1); if (missingDims > 1) return emitOpError() << "expected at most one target dimension to be -1"; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp index 4eba89b59bbd7..617a59bc87c9f 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -20,12 +20,6 @@ using namespace mlir::tosa; namespace { -SmallVector convertFromMlirShape(ArrayRef shape) { - return to_vector(llvm::map_range(shape, [](int64_t dim) { - return ShapedType::isDynamic(dim) ? -1 : dim; - })); -} - struct Conv2DIsFullyConnected : public OpRewritePattern { explicit Conv2DIsFullyConnected(MLIRContext *context) : OpRewritePattern(context) {} @@ -98,12 +92,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern { llvm::SmallVector revisedInputShape{combined, inputShape[3]}; auto revisedInputShapeType = RankedTensorType::get(revisedInputShape, inputType.getElementType()); - auto reshapedInput = rewriter - .create( - op.getLoc(), revisedInputShapeType, input, - rewriter.getDenseI64ArrayAttr( - convertFromMlirShape(revisedInputShape))) - .getResult(); + auto revisedInputShapeValue = getTosaConstShape( + rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape)); + auto reshapedInput = + rewriter + .create(op.getLoc(), revisedInputShapeType, input, + revisedInputShapeValue) + .getResult(); // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. llvm::SmallVector revisedWeightShape{weightShape[0], @@ -111,12 +106,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern { auto revisedWeightShapeType = RankedTensorType::get( revisedWeightShape, dyn_cast(weight.getType()).getElementType()); - auto reshapedWeight = rewriter - .create( - op.getLoc(), revisedWeightShapeType, weight, - rewriter.getDenseI64ArrayAttr( - convertFromMlirShape(revisedWeightShape))) - .getResult(); + auto revisedWeightShapeValue = getTosaConstShape( + rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape)); + auto reshapedWeight = + rewriter + .create(op.getLoc(), revisedWeightShapeType, + weight, revisedWeightShapeValue) + .getResult(); // Perform a fully connected network over the reshaped input and weight. llvm::SmallVector fullyConnectedShape{combined, weightShape[0]}; @@ -149,9 +145,10 @@ struct Conv2DIsFullyConnected : public OpRewritePattern { // Reshape output to [N, IH, IW, OC]. llvm::SmallVector outputShape{inputShape[0], inputShape[1], inputShape[2], weightShape[0]}; + auto outputShapeValue = getTosaConstShape( + rewriter, op.getLoc(), convertFromMlirShape(outputShape)); rewriter.replaceOpWithNewOp( - op, resultType, fullyConnectedValue, - rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape))); + op, resultType, fullyConnectedValue, outputShapeValue); return success(); } }; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index ee857f1998a54..b26397d0e3ed7 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -55,10 +55,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { inputType = RankedTensorType::get( revisedInputShape, dyn_cast(input.getType()).getElementType()); + auto revisedInputShapeValue = + getTosaConstShape(rewriter, op.getLoc(), revisedInputShape); input = rewriter - .create( - op.getLoc(), inputType, input, - rewriter.getDenseI64ArrayAttr(revisedInputShape)) + .create(op.getLoc(), inputType, input, + revisedInputShapeValue) .getResult(); Type inputETy = inputType.getElementType(); @@ -153,9 +154,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { auto outputShapeType = RankedTensorType::get( outputShape, dyn_cast(input.getType()).getElementType()); + auto outputShapeValue = + getTosaConstShape(rewriter, op->getLoc(), outputShape); Value outputValue = rewriter.create( - op.getLoc(), outputShapeType, mulValue, - rewriter.getDenseI64ArrayAttr(outputShape)); + op.getLoc(), outputShapeType, mulValue, outputShapeValue); Value bias = op.getBias(); if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) { diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index b5b3e9d76c47e..26baddcf1dd15 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -159,9 +159,11 @@ class TransposeConvStridedConverter outputChannels, weightHeight / stride[0], stride[0], weightWidth / stride[1], stride[1], inputChannels}; + + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); weight = CreateOpAndInferShape( - rewriter, loc, UnrankedTensorType::get(weightETy), weight, - rewriter.getDenseI64ArrayAttr(weightReshapeDims0)); + builder, UnrankedTensorType::get(weightETy), weight, + getTosaConstShape(rewriter, loc, weightReshapeDims0)); // Transpose the factored-out stride to the output channels. Value transposeWeightVal = rewriter.create( @@ -173,12 +175,13 @@ class TransposeConvStridedConverter transposeWeightVal); // Collapse the strides and output channels into a single dimension. - llvm::SmallVector weightReshapeDims1 = { + llvm::SmallVector weightReshapeDims1 = { outputChannels * stride[0] * stride[1], weightHeight / stride[0], weightWidth / stride[1], inputChannels}; + weight = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(weightETy), weight, - rewriter.getDenseI64ArrayAttr(weightReshapeDims1)); + getTosaConstShape(rewriter, loc, weightReshapeDims1)); ShapedType restridedWeightTy = cast(weight.getType()); weight = CreateOpAndInferShape( @@ -257,9 +260,13 @@ class TransposeConvStridedConverter // Factor striding out of the convolution result. llvm::SmallVector convReshapeDims0 = { batch, convHeight, convWidth, stride[0], stride[1], outputChannels}; + + auto convReshapeDims0Value = + getTosaConstShape(rewriter, loc, convReshapeDims0); + conv2d = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, - rewriter.getDenseI64ArrayAttr(convReshapeDims0)); + convReshapeDims0Value); // Transpose the factored-out stride to the output channels. Value transposeConvVal = rewriter.create( @@ -273,9 +280,13 @@ class TransposeConvStridedConverter // Fuse striding behavior back into width / height. llvm::SmallVector convReshapeDims1 = { batch, convHeight * stride[0], convWidth * stride[1], outputChannels}; + + auto convReshapeDims1Value = + getTosaConstShape(rewriter, loc, convReshapeDims1); + conv2d = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, - rewriter.getDenseI64ArrayAttr(convReshapeDims1)); + convReshapeDims1Value); // Determine the amount to slice / pad from the result start. int64_t resultSliceTop = std::max(0, -pad[0]); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp index 520f283a3ba88..281f0529a5c08 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp @@ -402,13 +402,20 @@ std::optional TosaReduceTransposes::buildMappedToValue( return std::nullopt; // Do not insert a TransposeOp, instead we fold the reshape and its attribute. + llvm::SmallVector newShape; + if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(), + newShape)) { + // this mean shape is not constant + return std::nullopt; + } + ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter); auto foldedReshape = rewriter.create( reshapeOp.getLoc(), RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms), reshapeOutputType.getElementType()), reshapeOp.getInput1(), - rewriter.getDenseI64ArrayAttr( - applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms))); + getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape), + hoistedPerms))); return foldedReshape->getResult(0); } diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index 62b0bc1857e39..8ab12d038849f 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -145,10 +145,10 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder, llvm::cast(lowerTensorValue.getType()); auto reshapeOutputType = RankedTensorType::get( ArrayRef(reshapeOutputShape), reshapeInputType.getElementType()); + auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape); auto reshapeLower = builder.create( - reshapeOutputType, lowerTensorValue, - builder.getDenseI64ArrayAttr(reshapeOutputShape)); + reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue); if (input1Rank > input2Rank) { input1 = higherTensorValue; @@ -161,15 +161,20 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder, return success(); } -Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc, +Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef shape) { - auto attr = rewriter.getIndexTensorAttr(shape); - auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size()); - mlir::Operation *mlir_op = - rewriter.create(loc, type, attr); + auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape)); + auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size()); + mlir::Operation *mlir_op = builder.create(type, attr); return mlir_op->getResult(0); } +Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc, + llvm::ArrayRef shape) { + ImplicitLocOpBuilder builder(loc, rewriter); + return getTosaConstShape(builder, shape); +} + SmallVector mlir::tosa::convertFromMlirShape(ArrayRef shape) { return to_vector(llvm::map_range(shape, [](int64_t dim) { return ShapedType::isDynamic(dim) ? -1 : dim; diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir index 75b48f2b06d89..460e207d62de6 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir @@ -24,7 +24,8 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, % %reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32> %1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> %0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32> - %2 = tosa.reshape %0 {new_shape = array} : (tensor<*xf32>) -> tensor<10x10xf32> + %s = tosa.const_shape {value = dense<[10, 10]> : tensor<2xindex>} : () -> !tosa.shape<2> + %2 = tosa.reshape %0, %s : (tensor<*xf32>, !tosa.shape<2>) -> tensor<10x10xf32> return %2 : tensor<10x10xf32> } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 6e8501aaaf2af..3031434e6d4ba 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1387,7 +1387,8 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () { // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: tosa.reshape [[GENERIC]] {new_shape = array} + // CHECK: [[CONST3:%.+]] = tosa.const_shape {value = dense<[4, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: tosa.reshape [[GENERIC]], [[CONST3]] %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2> %0 = tosa.tile %arg0, %cst21: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<4x3xi8> @@ -1395,7 +1396,8 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () { // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: tosa.reshape [[GENERIC]] {new_shape = array} + // CHECK: [[CONST8:%.+]] = tosa.const_shape {value = dense<[2, 6]> : tensor<2xindex>} : () -> !tosa.shape<2> + // tosa.reshape [[GENERIC]], [[CONST8]] %cst12 = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> %1 = tosa.tile %arg0, %cst12: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x6xi8> @@ -1403,8 +1405,9 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () { // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: tosa.reshape [[GENERIC]] {new_shape = array} %cst57 = tosa.const_shape { value = dense<[5, 7]> : tensor<2xindex> } : () -> !tosa.shape<2> + // CHECK: [[CONST13:%.+]] = tosa.const_shape {value = dense<[10, 21]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: tosa.reshape [[GENERIC]], [[CONST13]] %2 = tosa.tile %arg0, %cst57: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<10x21xi8> return @@ -1424,7 +1427,8 @@ func.func @tile_dyn_input(%arg0 : tensor) -> () { // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor) outs(%[[INIT]] : tensor<2x?x1x3xi8>) // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array} + // CHECK: %[[CONST3:.+]] = tosa.const_shape {value = dense<[-1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: tosa.reshape %[[GENERIC]], %[[CONST3]] %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2> %0 = tosa.tile %arg0, %cst21: (tensor, !tosa.shape<2>) -> tensor @@ -1445,7 +1449,8 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () { // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>) // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array} + // CHECK: %[[CONST2:.+]] = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: tosa.reshape %[[GENERIC]], %[[CONST2]] %cst = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2> %0 = tosa.tile %arg0, %cst: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x?xi8> diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir index e83e898644bc0..c2eaba4c563d0 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -6,7 +6,8 @@ // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor // CHECK: return %[[ARG_0]] : tensor func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor + %s = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<0>) -> tensor return %0 : tensor } @@ -18,7 +19,8 @@ func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor) -> tensor // CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor // CHECK: return %[[VAL_1]] : tensor func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor + %s = tosa.const_shape { value = dense<-1> : tensor<1xindex> } : () -> !tosa.shape<1> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<1>) -> tensor return %0 : tensor } @@ -30,7 +32,8 @@ func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor) -> tensor { // CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor // CHECK: return %[[VAL_1]] : tensor func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor + %s = tosa.const_shape { value = dense<1> : tensor<1xindex> } : () -> !tosa.shape<1> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<1>) -> tensor return %0 : tensor } @@ -41,7 +44,8 @@ func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor) -> tensor // CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor into tensor<1xf32> // CHECK: return %[[VAL_0]] : tensor<1xf32> func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor) -> tensor<1xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<1xf32> + %s = tosa.const_shape { value = dense<-1> : tensor<1xindex> } : () -> !tosa.shape<1> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<1>) -> tensor<1xf32> return %0 : tensor<1xf32> } @@ -52,7 +56,8 @@ func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor) -> tensor<1xf32> { // CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor into tensor<1xf32> // CHECK: return %[[VAL_0]] : tensor<1xf32> func.func @test_reshape_0d_up_s2s_explicit(%arg0: tensor) -> tensor<1xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<1xf32> + %s = tosa.const_shape { value = dense<1> : tensor<1xindex> } : () -> !tosa.shape<1> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<1>) -> tensor<1xf32> return %0 : tensor<1xf32> } @@ -64,7 +69,8 @@ func.func @test_reshape_0d_up_s2s_explicit(%arg0: tensor) -> tensor<1xf32> // CHECK: %[[VAL_1:.*]] = tensor.collapse_shape %[[VAL_0]] [] : tensor<1xf32> into tensor // CHECK: return %[[VAL_1]] : tensor func.func @test_reshape_1d_down_d2s_explicit(%arg0: tensor) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor + %s = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<0>) -> tensor return %0 : tensor } @@ -75,7 +81,8 @@ func.func @test_reshape_1d_down_d2s_explicit(%arg0: tensor) -> tensor into tensor // CHECK: return %[[VAL_0]] : tensor func.func @test_reshape_1d_down_s2s_explicit(%arg0: tensor<1xf32>) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1xf32>) -> tensor + %s = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<1xf32>, !tosa.shape<0>) -> tensor return %0 : tensor } @@ -90,7 +97,8 @@ func.func @test_reshape_1d_down_s2s_explicit(%arg0: tensor<1xf32>) -> tensor into tensor<2x?xf32> // CHECK: return %[[EXPANDED]] : tensor<2x?xf32> func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor) -> tensor<2x?xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x?xf32> + %s = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<2>) -> tensor<2x?xf32> return %0 : tensor<2x?xf32> } @@ -101,7 +109,8 @@ func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor) -> tensor<2x?xf32> // CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32> // CHECK: return %[[VAL_0]] : tensor<2x3xf32> func.func @test_reshape_1d_up_s2s_explicit(%arg0: tensor<6xf32>) -> tensor<2x3xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<6xf32>) -> tensor<2x3xf32> + %s = tosa.const_shape { value = dense<[2, 3]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<6xf32>, !tosa.shape<2>) -> tensor<2x3xf32> return %0 : tensor<2x3xf32> } @@ -112,7 +121,8 @@ func.func @test_reshape_1d_up_s2s_explicit(%arg0: tensor<6xf32>) -> tensor<2x3xf // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x?xf32> into tensor // CHECK: return %[[VAL_0]] : tensor func.func @test_reshape_2d_down_d2d_auto(%arg0: tensor<2x?xf32>) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x?xf32>) -> tensor + %s = tosa.const_shape { value = dense<-1> : tensor<1xindex> } : () -> !tosa.shape<1> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x?xf32>, !tosa.shape<1>) -> tensor return %0 : tensor } @@ -123,7 +133,8 @@ func.func @test_reshape_2d_down_d2d_auto(%arg0: tensor<2x?xf32>) -> tensor into tensor<6xf32> // CHECK: return %[[VAL_0]] : tensor<6xf32> func.func @test_reshape_2d_down_s2s_explicit(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x3xf32>) -> tensor<6xf32> + %s = tosa.const_shape { value = dense<6> : tensor<1xindex> } : () -> !tosa.shape<1> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x3xf32>, !tosa.shape<1>) -> tensor<6xf32> return %0 : tensor<6xf32> } @@ -139,7 +150,8 @@ func.func @test_reshape_2d_down_s2s_explicit(%arg0: tensor<2x3xf32>) -> tensor<6 // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [2, %[[DIV]]] : tensor into tensor<2x?xf32> // CHECK: return %[[EXPANDED]] : tensor<2x?xf32> func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor) -> tensor<2x?xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x?xf32> + %s = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<2>) -> tensor<2x?xf32> return %0 : tensor<2x?xf32> } @@ -152,10 +164,12 @@ func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor) -> tensor<2x?xf // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x4xf32>) -> tensor + %s = tosa.const_shape { value = dense<[-1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x4xf32>, !tosa.shape<2>) -> tensor return %0 : tensor } + // ----- // CHECK-LABEL: test_reshape_2d_same_s2d_explicit @@ -165,7 +179,8 @@ func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x4xf32>) -> tensor + %s = tosa.const_shape { value = dense<[4, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x4xf32>, !tosa.shape<2>) -> tensor return %0 : tensor } @@ -177,7 +192,8 @@ func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor into tensor<2x3xf32> // CHECK: return %[[VAL_1]] : tensor<2x3xf32> func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2xf32>) -> tensor<2x3xf32> + %s = tosa.const_shape { value = dense<[2, 3]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<3x2xf32>, !tosa.shape<2>) -> tensor<2x3xf32> return %0 : tensor<2x3xf32> } @@ -194,7 +210,8 @@ func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2 // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<0x3x?xf32> to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2x?xf32>) -> tensor + %s = tosa.const_shape { value = dense<[0, 3, -1]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<3x2x?xf32>, !tosa.shape<3>) -> tensor return %0 : tensor } @@ -211,7 +228,8 @@ func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tens // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x?x?xf32>) -> tensor + %s = tosa.const_shape { value = dense<[2, -1, 4]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x?x?xf32>, !tosa.shape<3>) -> tensor return %0 : tensor } @@ -227,7 +245,8 @@ func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor into tensor<2x3x?xf32> // CHECK: return %[[VAL_1]] : tensor<2x3x?xf32> func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor) -> tensor<2x3x?xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x3x?xf32> + %s = tosa.const_shape { value = dense<[2, 3, -1]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<3>) -> tensor<2x3x?xf32> return %0 : tensor<2x3x?xf32> } @@ -244,7 +263,8 @@ func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor) -> t // CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2x?xf32>) -> tensor + %s = tosa.const_shape { value = dense<[0, 3, 2]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<3x2x?xf32>, !tosa.shape<3>) -> tensor return %0 : tensor } @@ -261,7 +281,8 @@ func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) -> // CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_3d_same_d2d_explicit(%arg0: tensor) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor + %s = tosa.const_shape { value = dense<[2, 3, 4]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<3>) -> tensor return %0 : tensor } @@ -272,7 +293,8 @@ func.func @test_reshape_3d_same_d2d_explicit(%arg0: tensor) -> tensor // CHECK: %[[VAL_0:.*]] = tensor.cast %[[ARG_0]] : tensor to tensor<2x3x?xf32> // CHECK: return %[[VAL_0]] : tensor<2x3x?xf32> func.func @test_reshape_3d_same_d2d_explicit_identity(%arg0: tensor) -> tensor<2x3x?xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x3x?xf32> + %s = tosa.const_shape { value = dense<[2, 3, 4]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<3>) -> tensor<2x3x?xf32> return %0 : tensor<2x3x?xf32> } @@ -289,7 +311,8 @@ func.func @test_reshape_3d_same_d2d_explicit_identity(%arg0: tensor) // CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x?x4xf32> to tensor<2x3x4xf32> // CHECK: return %[[VAL_2]] : tensor<2x3x4xf32> func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor) -> tensor<2x3x4xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x3x4xf32> + %s = tosa.const_shape { value = dense<[2, -1, 4]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<3>) -> tensor<2x3x4xf32> return %0 : tensor<2x3x4xf32> } @@ -306,7 +329,8 @@ func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor) -> tensor<2x3 // CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor to tensor<2x3x4xf32> // CHECK: return %[[VAL_2]] : tensor<2x3x4xf32> func.func @test_reshape_3d_same_d2s_explicit(%arg0: tensor) -> tensor<2x3x4xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x3x4xf32> + %s = tosa.const_shape { value = dense<[2, 3, 4]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<3>) -> tensor<2x3x4xf32> return %0 : tensor<2x3x4xf32> } @@ -316,7 +340,8 @@ func.func @test_reshape_3d_same_d2s_explicit(%arg0: tensor) -> tensor // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x3x4xf32> // CHECK: return %[[ARG_0]] : tensor<2x3x4xf32> func.func @test_reshape_3d_same_s2s_explicit_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + %s = tosa.const_shape { value = dense<[2, 3, 4]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32> return %0 : tensor<2x3x4xf32> } @@ -333,7 +358,8 @@ func.func @test_reshape_3d_same_s2s_explicit_identity(%arg0: tensor<2x3x4xf32>) // CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor to tensor<1x3x2x1xf32> // CHECK: return %[[CAST]] : tensor<1x3x2x1xf32> func.func @test_reshape_3d_up_d2s_explicit(%input: tensor) -> tensor<1x3x2x1xf32> { - %0 = tosa.reshape %input {new_shape = array} : (tensor) -> tensor<1x3x2x1xf32> + %s = tosa.const_shape { value = dense<[1, 3, 2, 1]> : tensor<4xindex> } : () -> !tosa.shape<4> + %0 = tosa.reshape %input, %s : (tensor, !tosa.shape<4>) -> tensor<1x3x2x1xf32> return %0 : tensor<1x3x2x1xf32> } @@ -345,7 +371,8 @@ func.func @test_reshape_3d_up_d2s_explicit(%input: tensor) -> tensor< // CHECK: %[[VAL_1:.*]] = tensor.collapse_shape %[[VAL_0]] [] : tensor<1x1x1x1xf32> into tensor // CHECK: return %[[VAL_1]] : tensor func.func @test_reshape_4d_down_d2s_explicit(%arg0: tensor) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor + %s = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<0>) -> tensor return %0 : tensor } @@ -361,7 +388,8 @@ func.func @test_reshape_4d_down_d2s_explicit(%arg0: tensor) -> tens // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 2, 3] : tensor into tensor // CHECK: return %[[EXPANDED]] : tensor func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor + %s = tosa.const_shape { value = dense<[-1, 2, 3]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<3>) -> tensor return %0 : tensor } @@ -377,7 +405,8 @@ func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor) -> tensor // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 5, 77] : tensor into tensor // CHECK: return %[[EXPANDED]] : tensor func.func @test_reshape_6d_down_d2d_auto(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1x2x?x5x7x11xf32>) -> tensor + %s = tosa.const_shape { value = dense<[-1, 5, 77]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<1x2x?x5x7x11xf32>, !tosa.shape<3>) -> tensor return %0 : tensor } @@ -388,7 +417,8 @@ func.func @test_reshape_6d_down_d2d_auto(%arg0: tensor<1x2x?x5x7x11xf32>) -> ten // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2], [3], [4, 5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32> // CHECK: return %[[VAL_0]] : tensor<6x5x77xf32> func.func @test_reshape_6d_down_s2s_auto(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> + %s = tosa.const_shape { value = dense<[6, 5, -1]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<1x2x3x5x7x11xf32>, !tosa.shape<3>) -> tensor<6x5x77xf32> return %0 : tensor<6x5x77xf32> } @@ -400,10 +430,13 @@ func.func @test_reshape_6d_down_s2s_auto(%arg0: tensor<1x2x3x5x7x11xf32>) -> ten // // See https://github.com/llvm/llvm-project/pull/91521 for a full description. +// ----- + // CHECK-LABEL: reshape_bug_fix // CHECK: tensor.expand_shape func.func @reshape_bug_fix(%arg0: tensor) -> tensor<1x1x1x?xf32> { - %0 = tosa.reshape %arg0 {new_shape = array} : (tensor) -> tensor<1x1x1x?xf32> + %1 = "tosa.const_shape"() {value = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %0 = "tosa.reshape"(%arg0, %1) : (tensor, !tosa.shape<4>) -> tensor<1x1x1x?xf32> return %0 : tensor<1x1x1x?xf32> } @@ -414,21 +447,22 @@ func.func @reshape_bug_fix(%arg0: tensor) -> tensor<1x1x1x?xf32> { // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2], [3], [4, 5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32> // CHECK: return %[[VAL_0]] : tensor<6x5x77xf32> func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> { - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> + %s = tosa.const_shape { value = dense<[6, 5, 77]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<1x2x3x5x7x11xf32>, !tosa.shape<3>) -> tensor<6x5x77xf32> return %0 : tensor<6x5x77xf32> } // ----- // CHECK-LABEL: @test_reshape_samerank_unsigned -// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>) +// CHECK-SAME: (%[[VAL_0:.*]]: tensor<3x2xui8>) func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> { - // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8> - // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8> - // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8> - // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8 - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2xui8>) -> tensor<2x3xui8> - // CHECK-NEXT: return %[[CAST2]] + // CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<3x2xui8> to tensor<3x2xi8> + // CHECK: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8> + // CHECK: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8> + // CHECK: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8 + %s = tosa.const_shape { value = dense<[2, 3]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = "tosa.reshape"(%arg0, %s): (tensor<3x2xui8>, !tosa.shape<2>) -> tensor<2x3xui8> return %0 : tensor<2x3xui8> } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index e0e1de6a94d10..582fd77cd7bc8 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -542,17 +542,20 @@ func.func @reduce_sum_nofold(%arg0: tensor) -> tensor { // CHECK-LABEL: @reshape_canonicalize func.func @reshape_canonicalize(%arg0: tensor) -> tensor { // CHECK: return %arg0 - %0 = tosa.reshape %arg0 {new_shape = array}: (tensor) -> tensor - return %0 : tensor + %0 = "tosa.const_shape"() {value = dense<[-1, 10]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %arg0, %0 : (tensor, !tosa.shape<2>) -> tensor + return %1 : tensor } // ----- // CHECK-LABEL: @reshape_canonicalize_dyn_nofold func.func @reshape_canonicalize_dyn_nofold(%arg0: tensor) -> tensor { - // CHECK: %[[VAR0:.+]] = tosa.reshape %arg0 {new_shape = array} : (tensor) -> tensor + // CHECK: %[[SHAPE:.+]] = tosa.const_shape {value = dense<[-1, 2, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK: %[[VAR0:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor, !tosa.shape<3>) -> tensor // CHECK: return %[[VAR0]] : tensor - %0 = tosa.reshape %arg0 {new_shape = array} : (tensor) -> tensor + %s = "tosa.const_shape"() {value = dense<[-1, 2, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> + %0 = tosa.reshape %arg0, %s : (tensor, !tosa.shape<3>) -> tensor return %0 : tensor } @@ -560,10 +563,13 @@ func.func @reshape_canonicalize_dyn_nofold(%arg0: tensor) -> tensor< // CHECK-LABEL: @reshape_canonicalize_double func.func @reshape_canonicalize_double(%arg0: tensor) -> tensor { - // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array} + // CHECK: %[[VAL_0:.*]] = tosa.const_shape {value = dense<[-1, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_0]] // CHECK: return %[[VAL_1]] - %0 = tosa.reshape %arg0 {new_shape = array}: (tensor) -> tensor<5x?xf32> - %1 = tosa.reshape %0 {new_shape = array}: (tensor<5x?xf32>) -> tensor + %cst0 = "tosa.const_shape"() <{value = dense<[5, -1]> : tensor<2xindex>}> : () -> !tosa.shape<2> + %0 = tosa.reshape %arg0, %cst0 : (tensor, !tosa.shape<2>) -> tensor<5x?xf32> + %cst1 = "tosa.const_shape"() <{value = dense<[-1, 5]> : tensor<2xindex>}> : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %cst1 : (tensor<5x?xf32>, !tosa.shape<2>) -> tensor return %1 : tensor } @@ -574,8 +580,9 @@ func.func @reshape_canonicalize_const() -> tensor<1x5xi32> { // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1, 2, 3, 4]]> : tensor<1x5xi32>} // CHECK: return %[[VAR0]] %0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<5xi32>) -> tensor<1x5xi32> - return %1 : tensor<1x5xi32> + %1 = "tosa.const_shape"() {value = dense<[1, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> + %2 = tosa.reshape %0, %1 : (tensor<5xi32>, !tosa.shape<2>) -> tensor<1x5xi32> + return %2 : tensor<1x5xi32> } // ----- @@ -584,7 +591,8 @@ func.func @reshape_canonicalize_const() -> tensor<1x5xi32> { func.func @reshape_canonicalize_const_dynamic() -> tensor<1x?xi32> { // CHECK: tosa.reshape %0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<5xi32>) -> tensor<1x?xi32> + %2 = "tosa.const_shape"() {value = dense<[1, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<5xi32>, !tosa.shape<2>) -> tensor<1x?xi32> return %1 : tensor<1x?xi32> } @@ -596,7 +604,8 @@ func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi3 // CHECK-DAG: %[[VAR1:.+]] = "tosa.const"() <{value = dense<0> : tensor<1x10xi32>} // CHECK: return %[[VAR0]], %[[VAR1]] %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<10xi32>) -> tensor<1x10xi32> + %2 = "tosa.const_shape"() {value = dense<[1, 10]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<10xi32>, !tosa.shape<2>) -> tensor<1x10xi32> return %0 , %1 : tensor<10xi32>, tensor<1x10xi32> } @@ -606,7 +615,8 @@ func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi3 func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) { // CHECK: tosa.reshape %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : ()-> tensor<3xi32> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> + %2 = "tosa.const_shape"() {value = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32> return %0 , %1 : tensor<3xi32>, tensor<1x3xi32> } @@ -616,9 +626,10 @@ func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32 func.func @reshape_canonicalize_quant_nofold() -> (tensor<1x3x!quant.uniform>) { // disabled folding for quantized element types // CHECK{LITERAL}: "tosa.const"() <{value = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3x!quant.uniform> - // CHECK{LITERAL}: tosa.reshape %0 {new_shape = array} : (tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK{LITERAL}: tosa.reshape %0, %1 : (tensor<3x!quant.uniform>, !tosa.shape<2>) -> tensor<1x3x!quant.uniform> %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi8>} : ()-> tensor<3x!quant.uniform> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + %2 = "tosa.const_shape"() {value = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<3x!quant.uniform>, !tosa.shape<2>) -> tensor<1x3x!quant.uniform> return %1 : tensor<1x3x!quant.uniform> } @@ -626,8 +637,9 @@ func.func @reshape_canonicalize_quant_nofold() -> (tensor<1x3x!quant.uniform (tensor<2x1x3x!quant.uniform>) { - // CHECK: "tosa.const"() <{value = dense<0> : tensor<1x2x3xi8>}> : () -> tensor<1x2x3x!quant.uniform> - // CHECK: tosa.reshape %0 {new_shape = array} : (tensor<1x2x3x!quant.uniform>) -> tensor<2x1x3x!quant.uniform> + // CHECK-DAG: tosa.const_shape {value = dense<[2, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK-DAG: "tosa.const"() <{value = dense<0> : tensor<1x2x3xi8>}> : () -> tensor<1x2x3x!quant.uniform> + // CHECK: tosa.reshape %0, %1 : (tensor<1x2x3x!quant.uniform>, !tosa.shape<3>) -> tensor<2x1x3x!quant.uniform> %perms = "tosa.const"() {value = dense<[1, 0, 2]> : tensor<3xi32>} : () -> tensor<3xi32> %0 = "tosa.const"() {value = dense<0> : tensor<1x2x3xi8>} : ()-> tensor<1x2x3x!quant.uniform> %1 = tosa.transpose %0, %perms : (tensor<1x2x3x!quant.uniform>, tensor<3xi32>) -> tensor<2x1x3x!quant.uniform> @@ -691,7 +703,8 @@ func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> { // CHECK-LABEL: @transpose_is_reshape func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> { - // CHECK: tosa.reshape %arg0 {new_shape = array} : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> + // CHECK: %[[CONST0:.+]] = tosa.const_shape {value = dense<[1, 4, 1, 5]> : tensor<4xindex>} : () -> !tosa.shape<4> + // CHECK: tosa.reshape %arg0, %[[CONST0]] %perms = "tosa.const"() <{value = dense<[3, 1, 0, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> %0 = tosa.transpose %arg0, %perms : (tensor<1x4x5x1xf32>, tensor<4xi32>) -> tensor<1x4x1x5xf32> return %0 : tensor<1x4x1x5xf32> @@ -704,7 +717,8 @@ func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf3 func.func @single_bit_reshape() -> tensor<1xi1> { // CHECK: "tosa.const"() <{value = dense : tensor<1xi1>} %0 = arith.constant dense : tensor<1x1xi1> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<1x1xi1>) -> tensor<1xi1> + %2 = "tosa.const_shape"() <{value = dense<1> : tensor<1xindex>}> : () -> !tosa.shape<1> + %1 = tosa.reshape %0, %2 : (tensor<1x1xi1>, !tosa.shape<1>) -> tensor<1xi1> return %1 : tensor<1xi1> } @@ -870,8 +884,11 @@ func.func nested @fold_tile_rank_zero() -> tensor { // check that segfault is fixed func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> { %0 = "tosa.const"() {value = dense<127> : tensor} : () -> tensor> - %1 = tosa.reshape %0 {new_shape = array} : (tensor>) -> tensor<1x1x1x1x!quant.uniform> - %2 = tosa.rescale %1 {double_round = true, input_zp = -128 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<1x1x1x1x!quant.uniform>) -> tensor<1x1x1x1xi32> + %cst0 = "tosa.const_shape"() {value = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %1 = tosa.reshape %0, %cst0 : (tensor>, !tosa.shape<4>) -> tensor<1x1x1x1x!quant.uniform> + %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32> + %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8> + %2 = tosa.rescale %1 {double_round = true, input_zp = -128 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform>) -> tensor<1x1x1x1xi32> return %2 : tensor<1x1x1x1xi32> } diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 32677f06e2252..40469987d89d0 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -500,7 +500,8 @@ func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) { func.func @reshape_splat() -> tensor<6x5x4xi32> { // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{value = dense<42> : tensor<6x5x4xi32>} %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> - %reshape = tosa.reshape %splat { new_shape = array } : (tensor<4x5x6xi32>) -> tensor<6x5x4xi32> + %const = tosa.const_shape {value = dense<[6, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> + %reshape = tosa.reshape %splat, %const : (tensor<4x5x6xi32>, !tosa.shape<3>) -> tensor<6x5x4xi32> // CHECK: return %[[SPLAT]] return %reshape : tensor<6x5x4xi32> } diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir index e892fdaa27750..2a3065e80d0ea 100644 --- a/mlir/test/Dialect/Tosa/inlining.mlir +++ b/mlir/test/Dialect/Tosa/inlining.mlir @@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor, %arg1: tensor, %arg2: tenso } func.func private @while_body_50(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) { %1 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %3 = "tosa.reshape"(%1) {new_shape = array} : (tensor) -> tensor<1xi32> + %4 = "tosa.const_shape"() {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> + %3 = "tosa.reshape"(%1, %4) : (tensor, !tosa.shape<1>) -> tensor<1xi32> %2 = "tosa.add"(%arg3, %3) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32> return %1, %arg1, %arg2, %2: tensor, tensor, tensor, tensor<10xi32> } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 006c5bd52a9f6..2165e1f7ae3ba 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -316,7 +316,8 @@ func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tenso func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> { %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> - %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<273x3xf32> + %3 = tosa.const_shape {value = dense<[273, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %arg0, %3 : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<273x3xf32> // expected-error@+1 {{'tosa.fully_connected' op weight of fully_connected is not constant}} %2 = tosa.fully_connected %1, %arg1, %0 : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32> return %2 : tensor<273x2xf32> @@ -326,7 +327,8 @@ func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: ten func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2xf32>) -> tensor<273x2xf32> { %0 = "tosa.const"() {value = dense<[[-0.613216758, -0.63714242, -0.73500061], [0.180762768, 0.773053169, -0.933686495]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<273x3xf32> + %3 = tosa.const_shape {value = dense<[273, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %arg0, %3 : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<273x3xf32> // expected-error@+1 {{'tosa.fully_connected' op bias of fully_connected is not constant}} %2 = tosa.fully_connected %1, %0, %arg1 : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32> return %2 : tensor<273x2xf32> @@ -426,81 +428,91 @@ func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor) -> () { // ----- func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () { + %1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}} - %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32> + %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32> return } // ----- func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () { + %s = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}} - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<13x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32> return } // ----- func.func @test_reshape_zero_dim_input(%arg0 : tensor) -> () { + %s = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> // expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor'}} - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<13x0x3xf32> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<3>) -> tensor<13x0x3xf32> return } // ----- func.func @test_reshape_rank_mismatch(%arg0 : tensor) -> () { + %s = tosa.const_shape {value = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> // expected-error@+1 {{'tosa.reshape' op new shape does not match result rank}} - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<2>) -> tensor return } // ----- func.func @test_reshape_inconsistent_result_type(%arg0 : tensor) -> () { + %s = tosa.const_shape {value = dense<[2, 4, -1]> : tensor<3xindex>} : () -> !tosa.shape<3> // expected-error@+1 {{'tosa.reshape' op new shape is inconsistent with result shape}} - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<3>) -> tensor return } // ----- func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () { + %s = tosa.const_shape {value = dense<[3, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> // expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 15}} - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x4xf32>) -> tensor<3x5xf32> + %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x4xf32>, !tosa.shape<2>) -> tensor<3x5xf32> return } // ----- func.func @test_reshape_invalid_newshape(%arg0 : tensor<1xf32>) -> () { + %s = tosa.const_shape {value = dense<[-1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> // expected-error@+1 {{'tosa.reshape' op cannot reshape 1 elements into 4}} - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1xf32>) -> tensor + %0 = "tosa.reshape"(%arg0, %s) : (tensor<1xf32>, !tosa.shape<2>) -> tensor return } // ----- func.func @test_reshape_invalid_newshape(%arg0 : tensor<8xf32>) -> () { + %s = tosa.const_shape {value = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> // expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 4}} - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<8xf32>) -> tensor + %0 = "tosa.reshape"(%arg0, %s) : (tensor<8xf32>, !tosa.shape<2>) -> tensor return } // ----- func.func @test_reshape_invalid_placeholders(%arg0 : tensor) -> () { + %s = tosa.const_shape {value = dense<[2, -1, -1]> : tensor<3xindex>} : () -> !tosa.shape<3> // expected-error@+1 {{'tosa.reshape' op expected at most one target dimension to be -1}} - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x?x?xf32> + %0 = "tosa.reshape"(%arg0, %s) : (tensor, !tosa.shape<3>) -> tensor<2x?x?xf32> return } // ----- func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () { + %s = tosa.const_shape {value = dense<[-2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> // expected-error@+1 {{'tosa.reshape' op new shape has invalid tensor dimension size -2}} - %0 = "tosa.reshape" (%arg0) {new_shape = array} : (tensor<4x?xf32>) -> tensor + %0 = "tosa.reshape" (%arg0, %s) : (tensor<4x?xf32>, !tosa.shape<2>) -> tensor return } @@ -514,6 +526,15 @@ func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () { // ----- +func.func @test_reshape_zero_dim_input(%arg0 : tensor) -> () { + %1 = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + // expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor'}} + %0 = "tosa.reshape"(%arg0, %1) : (tensor, !tosa.shape<3>) -> tensor<13x0x3xf32> + return +} + +// ----- + func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> { // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}} %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32> diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 26bebdd898a0d..a7f76f2d0fa64 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -70,8 +70,9 @@ func.func @test_concat(%arg0: tensor<1x1x1x13x21x3x8xf32>, %arg1: tensor<1x1x1x1 // ----- func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> { + %1 = tosa.const_shape {value = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7> // expected-error@+1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}} - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> + %0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32> return %0 : tensor<1x1x1x1x1x1x819xf32> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index d00230d12aab1..baf09e089aa30 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -504,7 +504,8 @@ func.func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf // CHECK-LABEL: reduce_all func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<1x21x3xi1>) -> tensor<21x3xi1> + %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<1x21x3xi1>, !tosa.shape<2>) -> tensor<21x3xi1> return %1 : tensor<21x3xi1> } @@ -512,7 +513,8 @@ func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // CHECK-LABEL: reduce_any func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { %0 = tosa.reduce_any %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<1x21x3xi1>) -> tensor<21x3xi1> + %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<1x21x3xi1>, !tosa.shape<2>) -> tensor<21x3xi1> return %1 : tensor<21x3xi1> } @@ -520,7 +522,8 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // CHECK-LABEL: reduce_max func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %0 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<1x21x3xf32>, !tosa.shape<2>) -> tensor<21x3xf32> return %1 : tensor<21x3xf32> } @@ -528,7 +531,8 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: reduce_min func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<1x21x3xf32>, !tosa.shape<2>) -> tensor<21x3xf32> return %1 : tensor<21x3xf32> } @@ -536,7 +540,8 @@ func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: reduce_product func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %0 = tosa.reduce_prod %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<1x21x3xf32>, !tosa.shape<2>) -> tensor<21x3xf32> return %1 : tensor<21x3xf32> } @@ -544,7 +549,8 @@ func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: reduce_sum func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> - %1 = tosa.reshape %0 {new_shape = array} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.reshape %0, %2 : (tensor<1x21x3xf32>, !tosa.shape<2>) -> tensor<21x3xf32> return %1 : tensor<21x3xf32> } @@ -575,7 +581,8 @@ func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3 // ----- // CHECK-LABEL: reshape func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> { - %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<1x819xf32> + %1 = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2> + %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32> return %0 : tensor<1x819xf32> } @@ -724,7 +731,8 @@ func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor) { ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor<10xi32>): %2 = "tosa.const"() {value = dense<1> : tensor} : () -> tensor %3 = tosa.add %arg3, %2 : (tensor, tensor) -> tensor - %4 = tosa.reshape %2 {new_shape = array} : (tensor) -> tensor<1xi32> + %7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1> + %4 = tosa.reshape %2, %7 : (tensor, !tosa.shape<1>) -> tensor<1xi32> %5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32> %6 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor tosa.yield %6, %3, %5 : tensor, tensor, tensor<10xi32> diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir index e4a2897908072..9aade2fe45eb6 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir @@ -5,13 +5,16 @@ // CHECK-LABEL: @conv2d_as_fully_connected func.func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> { // CHECK-NOT: tosa.conv2d - // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} + // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[400, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[4, 10, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> + // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK-SAME: -> tensor<400x2xf32> - // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array} + // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1, %[[CONST1]] // CHECK-SAME: -> tensor<3x2xf32> // CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2 // CHECK-SAME: -> tensor<400x3xf32> - // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array} + // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[CONST2]] // CHECK-SAME: -> tensor<4x10x10x3xf32> // CHECK: return %[[VAR3]] %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> @@ -23,14 +26,17 @@ func.func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor // CHECK-LABEL: @conv2d_as_fully_connected_quant func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> { // CHECK-NOT: tosa.conv2d - // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} + // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[400, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[4, 10, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> + // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK-SAME: -> tensor<400x2xi8> - // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array} + // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1, %[[CONST1]] // CHECK-SAME: -> tensor<3x2xi8> // CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2 // CHECK-SAME: {input_zp = 42 : i32, weight_zp = 24 : i32} // CHECK-SAME: -> tensor<400x3xi32> - // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array} + // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[CONST2]] // CHECK-SAME: -> tensor<4x10x10x3xi32> // CHECK: return %[[VAR3]] %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> @@ -42,14 +48,14 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t // ----- // CHECK-LABEL: func.func @conv_with_dynamic_dim( -// CHECK-SAME: %[[VAL_0:.*]]: tensor, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<384x1x1x64xi8>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<384xi32>) -> tensor { func.func @conv_with_dynamic_dim(%arg0: tensor, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor { -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array} : (tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8> -// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor, tensor<384x64xi8>, tensor<384xi32>) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[-1, 64]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[384, 64]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[-1, 14, 14, 384]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %arg0, %[[CONST0]] +// CHECK: %[[VAL_4:.*]] = tosa.reshape %arg1, %[[CONST1]] : (tensor<384x1x1x64xi8>, !tosa.shape<2>) -> tensor<384x64xi8> +// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %arg2 {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor, tensor<384x64xi8>, tensor<384xi32>) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]], %[[CONST2]] : (tensor, !tosa.shape<4>) -> tensor // CHECK: return %[[VAL_6]] : tensor // CHECK: } %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8> @@ -62,15 +68,19 @@ func.func @conv_with_dynamic_dim(%arg0: tensor, %arg1: tensor<384 // CHECK-LABEL: @conv2d_as_fully_connected_padded func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> { + // CHECK-DAG: %[[FULLY_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[4, 12, 12, 3]> : tensor<4xindex>} + // CHECK-DAG: %[[INPUT_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[576, 2]> : tensor<2xindex>} + // CHECK-DAG: %[[FILTER_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[3, 2]> : tensor<2xindex>} // CHECK-DAG: %[[PAD_SHAPE:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> // CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor} // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor) -> tensor<4x12x12x2xi8> - // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array} - // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array} + // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]], %[[INPUT_NEW_SHAPE]] + // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1, %[[FILTER_NEW_SHAPE]] // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {input_zp = 42 : i32, weight_zp = 24 : i32} - // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array} + // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]], %[[FULLY_NEW_SHAPE]] %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8> %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x12x12x3xi32> return %0 : tensor<4x12x12x3xi32> } + diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir index ce29d1a498b4f..6562a7c2ab55c 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -5,15 +5,19 @@ // CHECK-LABEL: @depthwise_conv2d_as_mul func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { // CHECK-NOT: tosa.depthwise_conv2d - // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} + // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex> + // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex> + // CHECK-DAG: %[[CONST2:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 6]> : tensor<4xindex> + // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex> + // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK-SAME: -> tensor<4x10x10x2x1xf32> - // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array} + // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1, %[[CONST1]] // CHECK-SAME: -> tensor<1x1x1x2x3xf32> // CHECK: %[[VAR2:.*]] = tosa.mul %[[VAR0]], %[[VAR1]] // CHECK-SAME: -> tensor<4x10x10x2x3xf32> - // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array} + // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[CONST2]] // CHECK-SAME: -> tensor<4x10x10x6xf32> - // CHECK: %[[VAR4:.*]] = tosa.reshape %arg2 {new_shape = array} + // CHECK: %[[VAR4:.*]] = tosa.reshape %arg2, %[[CONST3]] // CHECK-SAME: -> tensor<1x1x1x6xf32> // CHECK: %[[VAR5:.*]] = tosa.add %[[VAR3]], %[[VAR4]] // CHECK-SAME: -> tensor<4x10x10x6xf32> @@ -26,17 +30,22 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1 // CHECK-LABEL: @depthwise_conv2d_as_mul_q func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { + // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex> // CHECK-DAG: %[[iZp:.+]] = "tosa.const"() <{value = dense<7> : tensor<1x1x1x1x1xi32>} // CHECK-DAG: %[[wZp:.+]] = "tosa.const"() <{value = dense<11> : tensor<1x1x1x1xi32>} - // CHECK: %[[rIn:.+]] = tosa.reshape %arg0 {new_shape = array} + // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex> + // CHECK-DAG: %[[CONST4:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 6]> : tensor<4xindex> + // CHECK-DAG: %[[CONST5:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex> + // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: %[[rIn:.+]] = tosa.reshape %arg0, %[[CONST0]] // CHECK: %[[cIn:.+]] = tosa.cast %[[rIn]] : (tensor<4x10x10x2x1xi8>) -> tensor<4x10x10x2x1xi32> // CHECK: %[[cWe:.+]] = tosa.cast %arg1 : (tensor<1x1x2x3xi8>) -> tensor<1x1x2x3xi32> // CHECK: %[[sIn:.+]] = tosa.sub %[[cIn]], %[[iZp]] // CHECK: %[[sWe:.+]] = tosa.sub %[[cWe]], %[[wZp]] - // CHECK: %[[resWe:.+]] = tosa.reshape %[[sWe]] {new_shape = array} - // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]] - // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array} - // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array} + // CHECK: %[[resWe:.+]] = tosa.reshape %[[sWe]], %[[CONST3]] + // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]], %[[SHIFT]] + // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]], %[[CONST4]] + // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2, %[[CONST5]] // CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]] %input_zp = "tosa.const"() {value = dense<7> : tensor<1xi8>} : () -> tensor<1xi8> %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8> @@ -48,14 +57,19 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor< // CHECK-LABEL: @depthwise_conv2d_as_mul_padded func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> { - // CHECK-DAG: %[[pad:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10> + // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex>} + // CHECK-DAG: %[[pad:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10> // CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor} - // CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array} + // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} + // CHECK-DAG: %[[CONST4:.+]] = tosa.const_shape {value = dense<[4, 12, 12, 6]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST5:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>} + // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: %[[reIn:.+]] = tosa.reshape %arg0, %[[CONST0]] // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor) -> tensor<4x12x12x2x1xf32> - // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]] - // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array} - // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array} + // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1, %[[CONST3]] + // CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]], %[[SHIFT]] + // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]], %[[CONST4]] + // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2, %[[CONST5]] // CHECK: %[[add:.+]] = tosa.add %[[reOut]], %[[reArg2]] %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32> return %0 : tensor<4x12x12x6xf32> diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 82838cc7e1545..bd18b7ea0fdff 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -56,11 +56,15 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor< // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] - // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array} + // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[5, 2, 2, 2, 3, 3]> : tensor<6xindex>} + // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]], %[[CONST1]] // CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]] - // CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]] {new_shape = array} + // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[30, 2, 2, 3]> : tensor<4xindex>} + // CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]], %[[CONST3]] // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[RESW2]] {axis = 1 : i32} // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} + // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} : () -> !tosa.shape<4> + // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // Pad out the input matrix to handle the transpose conv. // CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> @@ -70,13 +74,14 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor< // Manipulate the final shape. // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<30xf32>} // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array, pad = array, stride = array} - // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array} + // CHECK-DAG: %[[CONST6:.+]] = tosa.const_shape {value = dense<[2, 18, 16, 2, 3, 5]> : tensor<6xindex>} + // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]], %[[CONST6]] // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] - // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] - // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> - // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} : () -> !tosa.shape<4> - // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] - // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 + // CHECK-DAG: %[[CONST8:.+]] = tosa.const_shape {value = dense<[2, 36, 48, 5]> : tensor<4xindex> + // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]], %[[CONST8]] + // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] + // CHECK-DAG: %[[CONST9:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 5]> : tensor<4xindex>} + // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST9]] // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32> @@ -92,11 +97,15 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {input_zp = 42 : i32} - // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array} + // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[5, 2, 2, 2, 3, 3]> : tensor<6xindex>} + // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]], %[[CONST1]] // CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]] - // CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]] {new_shape = array} + // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[30, 2, 2, 3]> : tensor<4xindex>} + // CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]], %[[CONST3]] // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[RESW2]] {axis = 1 : i32} // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} + // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} : () -> !tosa.shape<4> + // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // Pad out the input matrix to handle the transpose conv. // CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> @@ -108,13 +117,14 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>} // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>} // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array, pad = array, stride = array} - // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array} + // CHECK-DAG: %[[CONV_NEW_SHAPE:.*]] = tosa.const_shape {value = dense<[2, 18, 16, 2, 3, 5]> : tensor<6xindex>} + // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]], %[[CONV_NEW_SHAPE]] // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] - // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] - // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} - // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} - // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] - // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 + // CHECK-DAG: %[[TEANS_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[2, 36, 48, 5]> : tensor<4xindex>} + // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]], %[[TEANS_NEW_SHAPE]] + // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] + // CHECK-DAG: %[[ARG2_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 5]> : tensor<4xindex>} + // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[ARG2_NEW_SHAPE]] // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> @@ -126,25 +136,31 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // CHECK-LABEL: @transpose_conv2d_strided_overpad func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) { - // CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> + // CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>} + // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[1, 2, 1, 1, 2, 1]> : tensor<6xindex>} // CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} - // CHECK-DAG: %[[INPUT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> + // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[2, 2, 1, 1]> : tensor<4xindex>} + // CHECK-DAG: %[[INPUT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>} + // CHECK-DAG: %[[CONST6:.+]] = tosa.const_shape {value = dense<[1, 17, 1, 1, 2, 1]> : tensor<6xindex>} // CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} - // CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> + // CHECK-DAG: %[[CONST8:.+]] = tosa.const_shape {value = dense<[1, 17, 2, 1]> : tensor<4xindex>} + // CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} + // CHECK-DAG: %[[CONST10:.+]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} + // CHECK-DAG: %[[INPUT_ZP:.*]] = "tosa.const"() <{value = dense<-103> : tensor<1xi8>}> + // CHECK-DAG: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{value = dense<93> : tensor<1xi8>}> // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {input_zp = 93 : i32} - // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array} + // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]], %[[CONST1]] // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]] - // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]] {new_shape = array} + // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]], %[[CONST3]] // CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32} // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {input_zp = -103 : i32} - // CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]] - // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], input_zp = -103 : i32, weight_zp = 93 : i32, stride = [1, 1]} - // CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]] {new_shape = array} + // CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array, pad = array, stride = array} + // CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]], %[[CONST6]] // CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]] - // CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]] {new_shape = array} + // CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]], %[[CONST8]] // CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]] - // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array} + // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST10]] // CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]] %input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8> %weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8> diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 73eabab657f38..bdd403567a4ed 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -376,29 +376,42 @@ func.func @test_table_dynamic(%arg0 : tensor<4x?xi16>, %arg1 : tensor<513xi16>) // CHECK-LABEL: @test_static_reshape func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () { - // CHECK: tosa.reshape %arg0 {new_shape = array} : (tensor<4x4xi32>) -> tensor<16xi32> - %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x4xi32>) -> tensor + // CHECK: %[[CONST3:.+]] = tosa.const_shape {value = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1> + %3 = tosa.const_shape {value = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1> + // CHECK: tosa.reshape %arg0, %[[CONST3]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32> + %0 = tosa.reshape %arg0, %3 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32> - // CHECK: tosa.reshape %arg0 {new_shape = array} : (tensor<4x4xi32>) -> tensor<16xi32> - %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x4xi32>) -> tensor + // CHECK: %[[CONST4:.+]] = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1> + // CHECK: tosa.reshape %arg0, %[[CONST4]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32> + %4 = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1> + %1 = tosa.reshape %arg0, %4 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32> - // CHECK: tosa.reshape %arg0 {new_shape = array} : (tensor<4x4xi32>) -> tensor<2x8xi32> - %2 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x4xi32>) -> tensor + // CHECK: %[[CONST5:.+]] = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: tosa.reshape %arg0, %[[CONST5]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32> + %5 = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> + %2 = tosa.reshape %arg0, %5 : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32> return } + // ----- // CHECK-LABEL: @test_dynamic_reshape func.func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () { - // CHECK: %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x?xi32>) -> tensor<16xi32> - %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x?xi32>) -> tensor - - // CHECK: %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x?xi32>) -> tensor - %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x?xi32>) -> tensor - - // CHECK: %2 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x?xi32>) -> tensor<2x?xi32> - %2 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x?xi32>) -> tensor + // CHECK: %0 = tosa.const_shape {value = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1> + %0 = tosa.const_shape {value = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1> + // CHECK: %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32> + %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor + + // CHECK: %2 = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1> + %2 = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1> + // CHECK: %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor + %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor + + // CHECK: %4 = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> + %4 = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32> + %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor return } diff --git a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir index f274eb9c10a81..947335e45a9d9 100644 --- a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir @@ -141,12 +141,14 @@ func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3 // COM: this case is a reshape we don't convert, since can't fold the transpose into it. // COM: a transform actually occurs underneath the hood, but it results in identical IR. // CHECK-LABEL: @test_basic_non_broadcasting_reshape -// CHECK: "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: tosa.reshape %arg0 {new_shape = array} : (tensor<2x3xi32>) -> tensor<1x3x2xi32> -// CHECK: tosa.transpose %1, %0 : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %arg0, %[[VAL_1]] : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32> +// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_3]], %[[VAL_2]] : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32> func.func @test_basic_non_broadcasting_reshape(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> { + %shape = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> %perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> - %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<2x3xi32>) -> tensor<1x3x2xi32> + %1 = tosa.reshape %arg0, %shape : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32> %2 = tosa.transpose %1, %perms : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32> return %2 : tensor<1x2x3xi32> } @@ -154,11 +156,13 @@ func.func @test_basic_non_broadcasting_reshape(%arg0: tensor<2x3xi32>) -> tensor // ----- // CHECK-LABEL: @test_dynamic_broadcasting_reshape -// CHECK: %[[RES:.*]] = tosa.reshape %arg0 {new_shape = array} : (tensor) -> tensor<1x1x?xi32> +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[1, 1, -1]> : tensor<3xindex>} +// CHECK: %[[RES:.*]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor, !tosa.shape<3>) -> tensor<1x1x?xi32> // CHECK: return %[[RES]] func.func @test_dynamic_broadcasting_reshape(%arg0: tensor) -> tensor<1x1x?xi32> { + %shape = tosa.const_shape {value = dense<[1, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> %perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> - %1 = tosa.reshape %arg0 {new_shape = array} : (tensor) -> tensor<1x?x1xi32> + %1 = tosa.reshape %arg0, %shape : (tensor, !tosa.shape<3>) -> tensor<1x?x1xi32> %2 = tosa.transpose %1, %perms : (tensor<1x?x1xi32>, tensor<3xi32>) -> tensor<1x1x?xi32> return %2 : tensor<1x1x?xi32> } @@ -167,12 +171,14 @@ func.func @test_dynamic_broadcasting_reshape(%arg0: tensor) -> tensor<1x1 // CHECK-LABEL: @test_reshape_for_broadcast // CHECK-DAG: %[[RESHAPE_INPUT:.*]] = "tosa.const"() <{value = dense<[1, 2, 3, 4]> -// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %[[RESHAPE_INPUT]] {new_shape = array} -// CHECK-DAG: %[[ADD:.*]] = tosa.add %arg0, %[[RESHAPE]] +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[4, 1, 1]> : tensor<3xindex>} +// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[RESHAPE_INPUT]], %[[SHAPE]] : (tensor<4xi32>, !tosa.shape<3>) -> tensor<4x1x1xi32> +// CHECK: %[[ADD:.*]] = tosa.add %arg0, %[[RESHAPE]] // CHECK: return %[[ADD]] func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2xi32> { %0 = "tosa.const"() {value = dense<[1,2,3,4]> : tensor<4xi32>} : () -> tensor<4xi32> - %reshape = tosa.reshape %0 {new_shape = array} : (tensor<4xi32>) -> tensor<1x1x4xi32> + %1 = tosa.const_shape {value = dense<[1, 1, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> + %reshape = tosa.reshape %0, %1 : (tensor<4xi32>, !tosa.shape<3>) -> tensor<1x1x4xi32> %perms0 = "tosa.const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32> %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<4x3x2xi32>, tensor<3xi32>) -> tensor<2x3x4xi32> %add = tosa.add %transpose0, %reshape : (tensor<2x3x4xi32>, tensor<1x1x4xi32>) -> tensor<2x3x4xi32> @@ -187,25 +193,28 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x // CHECK-LABEL: @test_resnet18_common_case // COM: note that %74 is now represented by %arg2 -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK-DAG: %[[VAL_6:.*]] = tosa.add %arg1, %[[VAL_4]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> -// CHECK-DAG: %[[VAL_7:.*]] = tosa.pow %[[VAL_6]], %[[VAL_5]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> -// CHECK-DAG: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<64xf32>) -> tensor<64xf32> -// CHECK-DAG: %[[VAL_9:.*]] = tosa.reshape %arg0 {new_shape = array} : (tensor<64xf32>) -> tensor<1x1x1x64xf32> -// CHECK-DAG: %[[VAL_10:.*]] = tosa.sub %arg2, %[[VAL_9]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> -// CHECK-DAG: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<64xf32>) -> tensor<1x1x1x64xf32> -// CHECK-DAG: %[[VAL_12:.*]] = tosa.mul %[[VAL_10]], %[[VAL_11]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> -// CHECK-DAG: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<64xf32>) -> tensor<1x1x1x64xf32> -// CHECK-DAG: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> -// CHECK-DAG: %[[VAL_15:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<64xf32>) -> tensor<1x1x1x64xf32> -// CHECK-DAG: %[[VAL_16:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> -// CHECK-DAG: %[[VAL_17:.*]] = tosa.clamp %[[VAL_16]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> -// CHECK: return %[[VAL_17]] : tensor<1x112x112x64xf32> - +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_7:.*]] = tosa.add %arg1, %[[VAL_5]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> +// CHECK-DAG: %[[VAL_8:.*]] = tosa.pow %[[VAL_7]], %[[VAL_6]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> +// CHECK-DAG: %[[VAL_9:.*]] = tosa.reciprocal %[[VAL_8]] : (tensor<64xf32>) -> tensor<64xf32> +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_11:.*]] = tosa.reshape %arg0, %[[VAL_10]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> +// CHECK-DAG: %[[VAL_12:.*]] = tosa.sub %arg2, %[[VAL_11]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> +// CHECK-DAG: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_14:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_13]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> +// CHECK-DAG: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> +// CHECK-DAG: %[[VAL_16:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_17:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_16]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> +// CHECK-DAG: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> +// CHECK-DAG: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_20:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_19]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> +// CHECK-DAG: %[[VAL_21:.*]] = tosa.add %[[VAL_18]], %[[VAL_20]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> +// CHECK-DAG: %[[VAL_22:.*]] = tosa.clamp %[[VAL_21]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> { + %58 = tosa.const_shape {value = dense<[1, 64, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> %59 = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> %60 = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> %63 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> @@ -216,13 +225,13 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32 %76 = tosa.add %arg1, %69 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> %77 = tosa.pow %76, %70 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> %78 = tosa.reciprocal %77 : (tensor<64xf32>) -> tensor<64xf32> - %79 = tosa.reshape %arg0 {new_shape = array} : (tensor<64xf32>) -> tensor<1x64x1x1xf32> + %79 = tosa.reshape %arg0, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> %80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> - %81 = tosa.reshape %78 {new_shape = array} : (tensor<64xf32>) -> tensor<1x64x1x1xf32> + %81 = tosa.reshape %78, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> %82 = tosa.mul %80, %81 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> - %83 = tosa.reshape %60 {new_shape = array} : (tensor<64xf32>) -> tensor<1x64x1x1xf32> + %83 = tosa.reshape %60, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> %84 = tosa.mul %82, %83 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> - %85 = tosa.reshape %59 {new_shape = array} : (tensor<64xf32>) -> tensor<1x64x1x1xf32> + %85 = tosa.reshape %59, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> %86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> %87 = tosa.clamp %86 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32> %88 = tosa.transpose %87, %63 : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32> @@ -285,7 +294,8 @@ func.func @test_no_transform_if_outside_fan_in_cone(%arg0: tensor<3x3x3x3xi32>) // CHECK: return %[[RESHAPE]], %[[CLAMP]] func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<1x1x64xf32>) { %0 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> - %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<64xf32>) -> tensor<1x64x1xf32> + %shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %1 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32> %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32> %3 = tosa.transpose %1, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32> %4 = tosa.transpose %2, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32> @@ -305,7 +315,8 @@ func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: t func.func @test_two_different_downstream_converge_to_reshape_different_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<64x1x1xf32>) { %0 = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> %1 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> - %2 = tosa.reshape %arg0 {new_shape = array} : (tensor<64xf32>) -> tensor<1x64x1xf32> + %shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %2 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32> %3 = tosa.clamp %2 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32> %4 = tosa.transpose %2, %1 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32> %5 = tosa.transpose %3, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<64x1x1xf32>