Skip to content

Commit 2bd0f01

Browse files
committed
Do not crash when verifiying/shape-infering GridSample if inputs have not static shape
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent e47a8cd commit 2bd0f01

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ LogicalResult ONNXGridSampleOpShapeHelper::computeShape() {
2929

3030
// Read data and indices shapes as dim indices.
3131
ONNXGridSampleOpAdaptor operandAdaptor(operands);
32+
if (!hasShapeAndRank(operandAdaptor.getX()) ||
33+
!hasShapeAndRank(operandAdaptor.getGrid())) {
34+
return failure();
35+
}
3236
DimsExpr inputDims;
3337
DimsExpr gridDims;
3438
createIE->getShapeAsDims(operandAdaptor.getX(), inputDims);
@@ -78,11 +82,18 @@ LogicalResult ONNXGridSampleOp::verify() {
7882
if (!hasShapeAndRank(getOperation()))
7983
return success();
8084

81-
auto inputShape =
82-
mlir::cast<ShapedType>(operandAdaptor.getX().getType()).getShape();
83-
int64_t inputRank = inputShape.size();
84-
auto gridShape =
85-
mlir::cast<ShapedType>(operandAdaptor.getGrid().getType()).getShape();
85+
auto inputType = mlir::cast<ShapedType>(operandAdaptor.getX().getType());
86+
if (!inputType.hasStaticShape()) {
87+
return success();
88+
}
89+
const auto inputShape = inputType.getShape();
90+
const int64_t inputRank = inputShape.size();
91+
92+
auto gridType = mlir::cast<ShapedType>(operandAdaptor.getGrid().getType());
93+
if (!gridType.hasStaticShape()) {
94+
return success();
95+
}
96+
const auto gridShape = gridType.getShape();
8697

8798
// Check whether the ranks of input and grid are valid and are equal.
8899
// Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr)

test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4270,6 +4270,19 @@ func.func @test_grid_sample_same_dims(%arg0: tensor<1x3x1152x1344xf32>, %arg1: t
42704270
// CHECK: }
42714271
}
42724272

4273+
4274+
func.func @test_grid_sample_one_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
4275+
%0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<*xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32>
4276+
return %0 : tensor<*xf32>
4277+
4278+
// COM: Check that we do not crash
4279+
// CHECK-LABEL: func.func @test_grid_sample_one_dynamic
4280+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
4281+
// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<*xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32>
4282+
// CHECK: return [[VAR_0_]] : tensor<*xf32>
4283+
// CHECK: }
4284+
}
4285+
42734286
func.func @test_grid_sample_diff_dims(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x2xf32>) -> tensor<*xf32> {
42744287
%0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<*xf32>
42754288
return %0 : tensor<*xf32>

0 commit comments

Comments
 (0)