Skip to content

Conversation

@lhutton1
Copy link
Contributor

This commit allows reshape to be created with an unranked output, allowing it to be inferred by the shape inference pass.

This commit allows reshape to be created with an unranked output,
allowing it to be inferred by the shape inference pass.

Change-Id: I639e68982946eeac6dcbc0d30e6cfa2217592091
@llvmbot
Copy link
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit allows reshape to be created with an unranked output, allowing it to be inferred by the shape inference pass.


Full diff: https://github.com/llvm/llvm-project/pull/140617.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+8-5)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 52bb0eb992b69..86f9ab94ec152 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1959,7 +1959,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_Tensor:$output
   );
 
   list<Availability> availability = [
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b2e471f2bba93..b74b820e11f75 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2012,7 +2012,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     return failure();
   }
   TensorType inputType = getInput1().getType();
-  RankedTensorType outputType = getType();
 
   SmallVector<int64_t> shapeValues;
   if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
@@ -2020,6 +2019,14 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     return mlir::success();
   }
 
+  int missingDims = llvm::count(shapeValues, -1);
+  if (missingDims > 1)
+    return emitOpError() << "expected at most one target dimension to be -1";
+
+  const auto outputType = dyn_cast<RankedTensorType>(getType());
+  if (!outputType)
+    return success();
+
   if ((int64_t)shapeValues.size() != outputType.getRank())
     return emitOpError() << "new shape does not match result rank";
 
@@ -2056,10 +2063,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     }
   }
 
-  int missingDims = llvm::count(shapeValues, -1);
-  if (missingDims > 1)
-    return emitOpError() << "expected at most one target dimension to be -1";
-
   return mlir::success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index f8273190bde40..e727614bd76f9 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -643,6 +643,14 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
   return %0 : tensor<1x819xf32>
 }
 
+// -----
+// CHECK-LABEL: reshape_unranked_output
+func.func @test_reshape_unranked_output(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
+  %1 = tosa.const_shape {values = dense<[21, 13, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
 // -----
 // CHECK-LABEL: reverse
 func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

@lhutton1 lhutton1 merged commit 22a4930 into llvm:main May 21, 2025
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants