Skip to content

Conversation

@tatwaichong
Copy link
Contributor

The shape operand is changed to input shape type since V1.0

Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110

@llvmbot
Copy link
Member

llvmbot commented Feb 5, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-linalg

Author: TatWai Chong (tatwaichong)

Changes

The shape operand is changed to input shape type since V1.0

Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110


Patch is 115.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125789.diff

25 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+3)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+3-2)
  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+7-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+6-2)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+23-9)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (+17-20)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+17-6)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (+9-2)
  • (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+12-7)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+2-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-5)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+72-38)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+37-20)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+2-1)
  • (modified) mlir/test/Dialect/Tosa/inlining.mlir (+2-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+33-12)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+2-1)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+16-8)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+25-12)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+29-15)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+42-26)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+27-14)
  • (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+44-33)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8ede271cc56a8a..869ab913a715ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1621,7 +1621,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    DenseI64ArrayAttr:$new_shape
+    Tosa_Shape:$shape
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 78a8828855437e..88c21629286525 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -230,8 +230,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
 }
 
 // Computes shape value using tosa const_shape op.
+Value getTosaConstShape(ImplicitLocOpBuilder &builder,
+                        llvm::ArrayRef<int64_t> shape);
 Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
                         llvm::ArrayRef<int64_t> shape);
+
 SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
 
 bool getConstShapeValue(Operation *op,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..edb04010d53fd9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1952,9 +1952,10 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
           nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
         });
 
+    auto shapeValue = getTosaConstShape(
+        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultTy, genericOp.getResult(0),
-        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+        op, resultTy, genericOp.getResult(0), shapeValue);
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..fdb8b1e1471a73 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -235,7 +236,12 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
       return rewriter.notifyMatchFailure(reshape.getLoc(),
                                          "expected input type to be tensor");
     }
-    auto newShape = reshape.getNewShape();
+
+    llvm::SmallVector<int64_t> newShape;
+    if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
+                                  newShape)) {
+      return failure();
+    }
 
     // Infer all intermediate types
     auto inputType = inferReshapeInputType(input, newShape);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..229719f5ef84d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
 
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
         op, op.getType(), op.getInput1(),
-        rewriter.getDenseI64ArrayAttr(newShape));
+        getTosaConstShape(rewriter, op.getLoc(), newShape));
     return success();
   }
 };
@@ -948,8 +948,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     if (!getInput1().hasOneUse())
       return {};
 
+    llvm::SmallVector<int64_t> shapeVec;
+    if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
+      return {};
+
     return operand.reshape(
-        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
   }
 
   return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..f88c6df8e2b458 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1309,8 +1309,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
-  llvm::SmallVector<int64_t> newShapeValue =
-      convertToMlirShape(adaptor.getNewShape());
+  llvm::SmallVector<int64_t> newShapeValue;
+  if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
+                                newShapeValue)) {
+    auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  } else {
+    newShapeValue = convertToMlirShape(newShapeValue);
+  }
 
   // We cannot infer from the total number of elements so we must take the
   // shape attribute as exact.
@@ -1346,13 +1354,19 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
-  if ((int64_t)getNewShape().size() != outputType.getRank())
+  SmallVector<int64_t> shapeValues;
+  if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
+    // skip following checks if shape is not constant
+    return mlir::success();
+  }
+
+  if ((int64_t)shapeValues.size() != outputType.getRank())
     return emitOpError() << "new shape does not match result rank";
 
   for (auto [newShapeDim, outputShapeDim] :
-       zip(getNewShape(), outputType.getShape())) {
-    if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
-        newShapeDim != outputShapeDim)
+       zip(shapeValues, outputType.getShape())) {
+    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
+        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
       return emitOpError() << "new shape is inconsistent with result shape";
 
     if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
@@ -1371,10 +1385,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     }
 
     int64_t newShapeElementsNum = std::accumulate(
-        getNewShape().begin(), getNewShape().end(), 1LL,
+        shapeValues.begin(), shapeValues.end(), 1LL,
         [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
     bool isStaticNewShape =
-        llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
     if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
         (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
       return emitOpError() << "cannot reshape " << inputElementsNum
@@ -1382,7 +1396,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     }
   }
 
-  int missingDims = llvm::count(getNewShape(), -1);
+  int missingDims = llvm::count(shapeValues, -1);
   if (missingDims > 1)
     return emitOpError() << "expected at most one target dimension to be -1";
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae3330afe..04e8ad31cf2e2e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -20,12 +20,6 @@ using namespace mlir::tosa;
 
 namespace {
 
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
-  return to_vector(llvm::map_range(shape, [](int64_t dim) {
-    return ShapedType::isDynamic(dim) ? -1 : dim;
-  }));
-}
-
 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
   explicit Conv2DIsFullyConnected(MLIRContext *context)
       : OpRewritePattern(context) {}
@@ -98,12 +92,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
     auto revisedInputShapeType =
         RankedTensorType::get(revisedInputShape, inputType.getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getDenseI64ArrayAttr(
-                                     convertFromMlirShape(revisedInputShape)))
-                             .getResult();
+    auto revisedInputShapeValue = getTosaConstShape(
+        rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape));
+    auto reshapedInput =
+        rewriter
+            .create<tosa::ReshapeOp>(op.getLoc(), revisedInputShapeType, input,
+                                     revisedInputShapeValue)
+            .getResult();
 
     // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
     llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
@@ -111,12 +106,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     auto revisedWeightShapeType = RankedTensorType::get(
         revisedWeightShape,
         dyn_cast<RankedTensorType>(weight.getType()).getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getDenseI64ArrayAttr(
-                                      convertFromMlirShape(revisedWeightShape)))
-                              .getResult();
+    auto revisedWeightShapeValue = getTosaConstShape(
+        rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape));
+    auto reshapedWeight =
+        rewriter
+            .create<tosa::ReshapeOp>(op.getLoc(), revisedWeightShapeType,
+                                     weight, revisedWeightShapeValue)
+            .getResult();
 
     // Perform a fully connected network over the reshaped input and weight.
     llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
@@ -149,9 +145,10 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     // Reshape output to [N, IH, IW, OC].
     llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
                                               inputShape[2], weightShape[0]};
+    auto outputShapeValue = getTosaConstShape(
+        rewriter, op.getLoc(), convertFromMlirShape(outputShape));
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultType, fullyConnectedValue,
-        rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
+        op, resultType, fullyConnectedValue, outputShapeValue);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index ee857f1998a54d..b26397d0e3ed7a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -55,10 +55,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     inputType = RankedTensorType::get(
         revisedInputShape,
         dyn_cast<RankedTensorType>(input.getType()).getElementType());
+    auto revisedInputShapeValue =
+        getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
     input = rewriter
-                .create<tosa::ReshapeOp>(
-                    op.getLoc(), inputType, input,
-                    rewriter.getDenseI64ArrayAttr(revisedInputShape))
+                .create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
+                                         revisedInputShapeValue)
                 .getResult();
 
     Type inputETy = inputType.getElementType();
@@ -153,9 +154,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     auto outputShapeType = RankedTensorType::get(
         outputShape,
         dyn_cast<RankedTensorType>(input.getType()).getElementType());
+    auto outputShapeValue =
+        getTosaConstShape(rewriter, op->getLoc(), outputShape);
     Value outputValue = rewriter.create<tosa::ReshapeOp>(
-        op.getLoc(), outputShapeType, mulValue,
-        rewriter.getDenseI64ArrayAttr(outputShape));
+        op.getLoc(), outputShapeType, mulValue, outputShapeValue);
 
     Value bias = op.getBias();
     if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae224671e304f2..69a66c98307e94 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -160,9 +160,11 @@ class TransposeConvStridedConverter
         outputChannels, weightHeight / stride[0],
         stride[0],      weightWidth / stride[1],
         stride[1],      inputChannels};
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
     weight = CreateOpAndInferShape<tosa::ReshapeOp>(
-        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
+        builder, UnrankedTensorType::get(weightETy), weight,
+        getTosaConstShape(rewriter, loc, weightReshapeDims0));
 
     // Transpose the factored-out stride to the output channels.
     Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
@@ -174,12 +176,13 @@ class TransposeConvStridedConverter
         transposeWeightVal);
 
     // Collapse the strides and output channels into a single dimension.
-    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
+    llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
         outputChannels * stride[0] * stride[1], weightHeight / stride[0],
         weightWidth / stride[1], inputChannels};
+
     weight = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
+        getTosaConstShape(rewriter, loc, weightReshapeDims1));
     ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
 
     weight = CreateOpAndInferShape<tosa::ReverseOp>(
@@ -258,9 +261,13 @@ class TransposeConvStridedConverter
     // Factor striding out of the convolution result.
     llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
         batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+
+    auto convReshapeDims0Value =
+        getTosaConstShape(rewriter, loc, convReshapeDims0);
+
     conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
-        rewriter.getDenseI64ArrayAttr(convReshapeDims0));
+        convReshapeDims0Value);
 
     // Transpose the factored-out stride to the output channels.
     Value transposeConvVal = rewriter.create<tosa::ConstOp>(
@@ -274,9 +281,13 @@ class TransposeConvStridedConverter
     // Fuse striding behavior back into width / height.
     llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
         batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+
+    auto convReshapeDims1Value =
+        getTosaConstShape(rewriter, loc, convReshapeDims1);
+
     conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
-        rewriter.getDenseI64ArrayAttr(convReshapeDims1));
+        convReshapeDims1Value);
 
     // Determine the amount to slice / pad from the result start.
     int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 520f283a3ba888..281f0529a5c081 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -402,13 +402,20 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
     return std::nullopt;
 
   // Do not insert a TransposeOp, instead we fold the reshape and its attribute.
+  llvm::SmallVector<int64_t> newShape;
+  if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
+                                newShape)) {
+    // this mean shape is not constant
+    return std::nullopt;
+  }
+  ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
   auto foldedReshape = rewriter.create<ReshapeOp>(
       reshapeOp.getLoc(),
       RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
                             reshapeOutputType.getElementType()),
       reshapeOp.getInput1(),
-      rewriter.getDenseI64ArrayAttr(
-          applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
+      getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
+                                                      hoistedPerms)));
   return foldedReshape->getResult(0);
 }
 
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 62b0bc1857e395..8ab12d038849f4 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -145,10 +145,10 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
       llvm::cast<RankedTensorType>(lowerTensorValue.getType());
   auto reshapeOutputType = RankedTensorType::get(
       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+  auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape);
 
   auto reshapeLower = builder.create<tosa::ReshapeOp>(
-      reshapeOutputType, lowerTensorValue,
-      builder.getDenseI64ArrayAttr(reshapeOutputShape));
+      reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue);
 
   if (input1Rank > input2Rank) {
     input1 = higherTensorValue;
@@ -161,15 +161,20 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
   return success();
 }
 
-Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
                                     llvm::ArrayRef<int64_t> shape) {
-  auto attr = rewriter.getIndexTensorAttr(shape);
-  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
-  mlir::Operation *mlir_op =
-      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+  auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size());
+  mlir::Operation *mlir_op = builder.create<tosa::ConstShapeOp>(type, attr);
   return mlir_op->getResult(0);
 }
 
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                                    llvm::ArrayRef<int64_t> shape) {
+  ImplicitLocOpBuilder builder(loc, rewriter);
+  return getTosaConstShape(builder, shape);
+}
+
 SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
   return to_vector(llvm::map_range(shape, [](int64_t dim) {
     return ShapedType::isDynamic(dim) ? -1 : dim;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 75b48f2b06d899..460e207d62de6a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -24,7 +24,8 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
   %reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
   %1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
   %0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
-  %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 5, 2025

@llvm/pr-subscribers-mlir

Author: TatWai Chong (tatwaichong)

Changes

The shape operand is changed to input shape type since V1.0

Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110


Patch is 115.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125789.diff

25 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+3)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+3-2)
  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+7-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+6-2)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+23-9)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (+17-20)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+17-6)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (+9-2)
  • (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+12-7)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+2-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-5)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+72-38)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+37-20)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+2-1)
  • (modified) mlir/test/Dialect/Tosa/inlining.mlir (+2-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+33-12)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+2-1)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+16-8)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+25-12)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+29-15)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+42-26)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+27-14)
  • (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+44-33)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8ede271cc56a8a..869ab913a715ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1621,7 +1621,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    DenseI64ArrayAttr:$new_shape
+    Tosa_Shape:$shape
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 78a8828855437e..88c21629286525 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -230,8 +230,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
 }
 
 // Computes shape value using tosa const_shape op.
+Value getTosaConstShape(ImplicitLocOpBuilder &builder,
+                        llvm::ArrayRef<int64_t> shape);
 Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
                         llvm::ArrayRef<int64_t> shape);
+
 SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
 
 bool getConstShapeValue(Operation *op,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..edb04010d53fd9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1952,9 +1952,10 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
           nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
         });
 
+    auto shapeValue = getTosaConstShape(
+        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultTy, genericOp.getResult(0),
-        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+        op, resultTy, genericOp.getResult(0), shapeValue);
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..fdb8b1e1471a73 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -235,7 +236,12 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
       return rewriter.notifyMatchFailure(reshape.getLoc(),
                                          "expected input type to be tensor");
     }
-    auto newShape = reshape.getNewShape();
+
+    llvm::SmallVector<int64_t> newShape;
+    if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
+                                  newShape)) {
+      return failure();
+    }
 
     // Infer all intermediate types
     auto inputType = inferReshapeInputType(input, newShape);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..229719f5ef84d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
 
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
         op, op.getType(), op.getInput1(),
-        rewriter.getDenseI64ArrayAttr(newShape));
+        getTosaConstShape(rewriter, op.getLoc(), newShape));
     return success();
   }
 };
@@ -948,8 +948,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     if (!getInput1().hasOneUse())
       return {};
 
+    llvm::SmallVector<int64_t> shapeVec;
+    if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
+      return {};
+
     return operand.reshape(
-        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
   }
 
   return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..f88c6df8e2b458 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1309,8 +1309,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
-  llvm::SmallVector<int64_t> newShapeValue =
-      convertToMlirShape(adaptor.getNewShape());
+  llvm::SmallVector<int64_t> newShapeValue;
+  if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
+                                newShapeValue)) {
+    auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  } else {
+    newShapeValue = convertToMlirShape(newShapeValue);
+  }
 
   // We cannot infer from the total number of elements so we must take the
   // shape attribute as exact.
@@ -1346,13 +1354,19 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
-  if ((int64_t)getNewShape().size() != outputType.getRank())
+  SmallVector<int64_t> shapeValues;
+  if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
+    // skip following checks if shape is not constant
+    return mlir::success();
+  }
+
+  if ((int64_t)shapeValues.size() != outputType.getRank())
     return emitOpError() << "new shape does not match result rank";
 
   for (auto [newShapeDim, outputShapeDim] :
-       zip(getNewShape(), outputType.getShape())) {
-    if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
-        newShapeDim != outputShapeDim)
+       zip(shapeValues, outputType.getShape())) {
+    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
+        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
       return emitOpError() << "new shape is inconsistent with result shape";
 
     if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
@@ -1371,10 +1385,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     }
 
     int64_t newShapeElementsNum = std::accumulate(
-        getNewShape().begin(), getNewShape().end(), 1LL,
+        shapeValues.begin(), shapeValues.end(), 1LL,
         [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
     bool isStaticNewShape =
-        llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
     if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
         (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
       return emitOpError() << "cannot reshape " << inputElementsNum
@@ -1382,7 +1396,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     }
   }
 
-  int missingDims = llvm::count(getNewShape(), -1);
+  int missingDims = llvm::count(shapeValues, -1);
   if (missingDims > 1)
     return emitOpError() << "expected at most one target dimension to be -1";
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae3330afe..04e8ad31cf2e2e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -20,12 +20,6 @@ using namespace mlir::tosa;
 
 namespace {
 
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
-  return to_vector(llvm::map_range(shape, [](int64_t dim) {
-    return ShapedType::isDynamic(dim) ? -1 : dim;
-  }));
-}
-
 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
   explicit Conv2DIsFullyConnected(MLIRContext *context)
       : OpRewritePattern(context) {}
@@ -98,12 +92,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
     auto revisedInputShapeType =
         RankedTensorType::get(revisedInputShape, inputType.getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getDenseI64ArrayAttr(
-                                     convertFromMlirShape(revisedInputShape)))
-                             .getResult();
+    auto revisedInputShapeValue = getTosaConstShape(
+        rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape));
+    auto reshapedInput =
+        rewriter
+            .create<tosa::ReshapeOp>(op.getLoc(), revisedInputShapeType, input,
+                                     revisedInputShapeValue)
+            .getResult();
 
     // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
     llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
@@ -111,12 +106,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     auto revisedWeightShapeType = RankedTensorType::get(
         revisedWeightShape,
         dyn_cast<RankedTensorType>(weight.getType()).getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getDenseI64ArrayAttr(
-                                      convertFromMlirShape(revisedWeightShape)))
-                              .getResult();
+    auto revisedWeightShapeValue = getTosaConstShape(
+        rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape));
+    auto reshapedWeight =
+        rewriter
+            .create<tosa::ReshapeOp>(op.getLoc(), revisedWeightShapeType,
+                                     weight, revisedWeightShapeValue)
+            .getResult();
 
     // Perform a fully connected network over the reshaped input and weight.
     llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
@@ -149,9 +145,10 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     // Reshape output to [N, IH, IW, OC].
     llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
                                               inputShape[2], weightShape[0]};
+    auto outputShapeValue = getTosaConstShape(
+        rewriter, op.getLoc(), convertFromMlirShape(outputShape));
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultType, fullyConnectedValue,
-        rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
+        op, resultType, fullyConnectedValue, outputShapeValue);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index ee857f1998a54d..b26397d0e3ed7a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -55,10 +55,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     inputType = RankedTensorType::get(
         revisedInputShape,
         dyn_cast<RankedTensorType>(input.getType()).getElementType());
+    auto revisedInputShapeValue =
+        getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
     input = rewriter
-                .create<tosa::ReshapeOp>(
-                    op.getLoc(), inputType, input,
-                    rewriter.getDenseI64ArrayAttr(revisedInputShape))
+                .create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
+                                         revisedInputShapeValue)
                 .getResult();
 
     Type inputETy = inputType.getElementType();
@@ -153,9 +154,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     auto outputShapeType = RankedTensorType::get(
         outputShape,
         dyn_cast<RankedTensorType>(input.getType()).getElementType());
+    auto outputShapeValue =
+        getTosaConstShape(rewriter, op->getLoc(), outputShape);
     Value outputValue = rewriter.create<tosa::ReshapeOp>(
-        op.getLoc(), outputShapeType, mulValue,
-        rewriter.getDenseI64ArrayAttr(outputShape));
+        op.getLoc(), outputShapeType, mulValue, outputShapeValue);
 
     Value bias = op.getBias();
     if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae224671e304f2..69a66c98307e94 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -160,9 +160,11 @@ class TransposeConvStridedConverter
         outputChannels, weightHeight / stride[0],
         stride[0],      weightWidth / stride[1],
         stride[1],      inputChannels};
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
     weight = CreateOpAndInferShape<tosa::ReshapeOp>(
-        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
+        builder, UnrankedTensorType::get(weightETy), weight,
+        getTosaConstShape(rewriter, loc, weightReshapeDims0));
 
     // Transpose the factored-out stride to the output channels.
     Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
@@ -174,12 +176,13 @@ class TransposeConvStridedConverter
         transposeWeightVal);
 
     // Collapse the strides and output channels into a single dimension.
-    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
+    llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
         outputChannels * stride[0] * stride[1], weightHeight / stride[0],
         weightWidth / stride[1], inputChannels};
+
     weight = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
+        getTosaConstShape(rewriter, loc, weightReshapeDims1));
     ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
 
     weight = CreateOpAndInferShape<tosa::ReverseOp>(
@@ -258,9 +261,13 @@ class TransposeConvStridedConverter
     // Factor striding out of the convolution result.
     llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
         batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+
+    auto convReshapeDims0Value =
+        getTosaConstShape(rewriter, loc, convReshapeDims0);
+
     conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
-        rewriter.getDenseI64ArrayAttr(convReshapeDims0));
+        convReshapeDims0Value);
 
     // Transpose the factored-out stride to the output channels.
     Value transposeConvVal = rewriter.create<tosa::ConstOp>(
@@ -274,9 +281,13 @@ class TransposeConvStridedConverter
     // Fuse striding behavior back into width / height.
     llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
         batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+
+    auto convReshapeDims1Value =
+        getTosaConstShape(rewriter, loc, convReshapeDims1);
+
     conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
-        rewriter.getDenseI64ArrayAttr(convReshapeDims1));
+        convReshapeDims1Value);
 
     // Determine the amount to slice / pad from the result start.
     int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 520f283a3ba888..281f0529a5c081 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -402,13 +402,20 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
     return std::nullopt;
 
   // Do not insert a TransposeOp, instead we fold the reshape and its attribute.
+  llvm::SmallVector<int64_t> newShape;
+  if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
+                                newShape)) {
+    // this mean shape is not constant
+    return std::nullopt;
+  }
+  ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
   auto foldedReshape = rewriter.create<ReshapeOp>(
       reshapeOp.getLoc(),
       RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
                             reshapeOutputType.getElementType()),
       reshapeOp.getInput1(),
-      rewriter.getDenseI64ArrayAttr(
-          applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
+      getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
+                                                      hoistedPerms)));
   return foldedReshape->getResult(0);
 }
 
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 62b0bc1857e395..8ab12d038849f4 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -145,10 +145,10 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
       llvm::cast<RankedTensorType>(lowerTensorValue.getType());
   auto reshapeOutputType = RankedTensorType::get(
       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+  auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape);
 
   auto reshapeLower = builder.create<tosa::ReshapeOp>(
-      reshapeOutputType, lowerTensorValue,
-      builder.getDenseI64ArrayAttr(reshapeOutputShape));
+      reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue);
 
   if (input1Rank > input2Rank) {
     input1 = higherTensorValue;
@@ -161,15 +161,20 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
   return success();
 }
 
-Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
                                     llvm::ArrayRef<int64_t> shape) {
-  auto attr = rewriter.getIndexTensorAttr(shape);
-  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
-  mlir::Operation *mlir_op =
-      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+  auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size());
+  mlir::Operation *mlir_op = builder.create<tosa::ConstShapeOp>(type, attr);
   return mlir_op->getResult(0);
 }
 
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                                    llvm::ArrayRef<int64_t> shape) {
+  ImplicitLocOpBuilder builder(loc, rewriter);
+  return getTosaConstShape(builder, shape);
+}
+
 SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
   return to_vector(llvm::map_range(shape, [](int64_t dim) {
     return ShapedType::isDynamic(dim) ? -1 : dim;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 75b48f2b06d899..460e207d62de6a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -24,7 +24,8 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
   %reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
   %1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
   %0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
-  %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<...
[truncated]

…type

Co-authored-by: TatWai Chong <[email protected]>
Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110
@sjarus sjarus requested a review from GeorgeARM February 6, 2025 19:01
Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments.

llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};

auto convReshapeDims1Value =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need the new variable? You can put it directly in the function below?

llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};

auto convReshapeDims0Value =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Directly to the function below?

auto outputShapeType = RankedTensorType::get(
outputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
auto outputShapeValue =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

directly in the function below?

Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
llvm::SmallVector<int64_t> newShapeValue =
convertToMlirShape(adaptor.getNewShape());
llvm::SmallVector<int64_t> newShapeValue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this two-stage setup for the shape obscure. Would prefer an optional but probably worth a separate patch

@Jerry-Ge Jerry-Ge merged commit 571a987 into llvm:main Feb 7, 2025
8 checks passed
@tatwaichong tatwaichong deleted the reshape_shape_type branch February 7, 2025 18:47
bjacob added a commit to iree-org/llvm-project that referenced this pull request Feb 10, 2025
bjacob added a commit to iree-org/iree that referenced this pull request Feb 10, 2025
We had previously cherry-picked
llvm/llvm-project@73f11ac
in #19939.

Now we're integrating up to that commit, so it's no longer a
cherry-pick.

Reverting llvm/llvm-project#125789 because it
breaks TorchToTosa, in torch-mlir. We will need to wait for this to be
resolved in torch-mlir, then simultaneously bump torch-mlir and drop the
revert.

Chery-pick a Bazel fix:
llvm/llvm-project@4df287a

---------

Signed-off-by: Benoit Jacob <[email protected]>
bjacob added a commit to iree-org/llvm-project that referenced this pull request Feb 10, 2025
bjacob added a commit to iree-org/iree that referenced this pull request Feb 10, 2025
Carrying the existing revert of
llvm/llvm-project#125789 because it breaks
TorchToTosa, in torch-mlir. We will need to wait for this to be resolved
in torch-mlir, then simultaneously bump torch-mlir and drop the revert.

Signed-off-by: Benoit Jacob <[email protected]>
bjacob added a commit to iree-org/llvm-project that referenced this pull request Feb 11, 2025
MaheshRavishankar pushed a commit to iree-org/iree that referenced this pull request Feb 11, 2025
Integrate at llvm/llvm-project@001ba42f

Carrying the existing revert of
llvm/llvm-project#125789 because it breaks
TorchToTosa, in torch-mlir. We will need to wait for this to be resolved
in torch-mlir, then simultaneously bump torch-mlir and drop the revert.

Signed-off-by: Benoit Jacob <[email protected]>
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
llvm#125789)

The shape operand is changed to input shape type since V1.0

Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110

Co-authored-by: Won Jeon <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants