Skip to content

Commit e4218b4

Browse files
committed
[mlir][tosa] Allow creation of reshape with unranked output
This commit allows reshape to be created with an unranked output, allowing it to be inferred by the shape inference pass. Change-Id: I639e68982946eeac6dcbc0d30e6cfa2217592091
1 parent 5528770 commit e4218b4

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1959,7 +1959,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
19591959
);
19601960

19611961
let results = (outs
1962-
Tosa_RankedTensor:$output
1962+
Tosa_Tensor:$output
19631963
);
19641964

19651965
list<Availability> availability = [

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2012,14 +2012,21 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
20122012
return failure();
20132013
}
20142014
TensorType inputType = getInput1().getType();
2015-
RankedTensorType outputType = getType();
20162015

20172016
SmallVector<int64_t> shapeValues;
20182017
if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
20192018
// skip following checks if shape is not constant
20202019
return mlir::success();
20212020
}
20222021

2022+
int missingDims = llvm::count(shapeValues, -1);
2023+
if (missingDims > 1)
2024+
return emitOpError() << "expected at most one target dimension to be -1";
2025+
2026+
const auto outputType = dyn_cast<RankedTensorType>(getType());
2027+
if (!outputType)
2028+
return success();
2029+
20232030
if ((int64_t)shapeValues.size() != outputType.getRank())
20242031
return emitOpError() << "new shape does not match result rank";
20252032

@@ -2056,10 +2063,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
20562063
}
20572064
}
20582065

2059-
int missingDims = llvm::count(shapeValues, -1);
2060-
if (missingDims > 1)
2061-
return emitOpError() << "expected at most one target dimension to be -1";
2062-
20632066
return mlir::success();
20642067
}
20652068

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,14 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
643643
return %0 : tensor<1x819xf32>
644644
}
645645

646+
// -----
647+
// CHECK-LABEL: reshape_unranked_output
648+
func.func @test_reshape_unranked_output(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
649+
%1 = tosa.const_shape {values = dense<[21, 13, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
650+
%0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<*xf32>
651+
return %0 : tensor<*xf32>
652+
}
653+
646654
// -----
647655
// CHECK-LABEL: reverse
648656
func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

0 commit comments

Comments
 (0)