Skip to content

Commit b36aa95

Browse files
committed
check input shape is static
1 parent 8d72440 commit b36aa95

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/Conversion/ONNXToTOSA/Tensor/Gather.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class ONNXGatherLoweringToTOSA : public OpConversionPattern<ONNXGatherOp> {
4646
if (!onnx_mlir::isRankedShapedType(inputType))
4747
return rewriter.notifyMatchFailure(op, "input is not a ranked tensor");
4848

49-
if (!hasStaticShape(result.getType()))
49+
if (!hasStaticShape(inputType) || !hasStaticShape(result.getType()))
5050
return rewriter.notifyMatchFailure(op, "dynamic shapes not supported");
5151

5252
auto resultTy = dyn_cast<TensorType>(op.getType());

test/mlir/conversion/onnx_to_tosa/Tensor/Gather.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,12 @@ func.func @test_gather_dynamic_shape_indices_i32(%arg0 : tensor<?x4xf32>, %indic
175175
// CHECK-LABEL: test_gather_dynamic_shape_indices_i32
176176
// CHECK: onnx.Gather
177177
}
178+
179+
// -----
180+
181+
func.func @test_gather_dynamic_input_static_output(%arg0 : tensor<?x2xf32>, %indices: tensor<?xi64>) -> tensor<1x2xf32> {
182+
%0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64, onnx_node_name = "/Gather_16"} : (tensor<?x2xf32>, tensor<?xi64>) -> tensor<1x2xf32>
183+
"func.return"(%0) : (tensor<1x2xf32>) -> ()
184+
// CHECK-LABEL: test_gather_dynamic_input_static_output
185+
// CHECK: onnx.Gather
186+
}

0 commit comments

Comments
 (0)