Skip to content

Commit a4c081c

Browse files
authored
Merge pull request #452 from Xilinx/jrickert.dynamic_shapes
Fix crashes related to dynamic shapes for GridSample and Cast
2 parents 3b12e41 + 2bd0f01 commit a4c081c

File tree

4 files changed

+47
-9
lines changed

4 files changed

+47
-9
lines changed

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,8 @@ SmallVector<Value, 4> castVariadicInput(PatternRewriter &rewriter, Location loc,
126126
SmallVector<Value, 4> castInputs;
127127
for (Value inp : inputs) {
128128
ShapedType inpType = mlir::cast<ShapedType>(inp.getType());
129-
assert(inpType && "Type is not ShapedType");
130-
ONNXCastOp castOp = rewriter.create<ONNXCastOp>(loc,
131-
UnrankedTensorType::get(inpType.getElementType()), inp, saturate, to);
132-
static_cast<void>(castOp.inferShapes([](Region &region) {}));
129+
ONNXCastOp castOp = rewriter.create<ONNXCastOp>(
130+
loc, inpType.clone(to.getValue()), inp, saturate, to);
133131
castInputs.emplace_back(castOp.getResult());
134132
}
135133
return castInputs;

src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ LogicalResult ONNXGridSampleOpShapeHelper::computeShape() {
2929

3030
// Read data and indices shapes as dim indices.
3131
ONNXGridSampleOpAdaptor operandAdaptor(operands);
32+
if (!hasShapeAndRank(operandAdaptor.getX()) ||
33+
!hasShapeAndRank(operandAdaptor.getGrid())) {
34+
return failure();
35+
}
3236
DimsExpr inputDims;
3337
DimsExpr gridDims;
3438
createIE->getShapeAsDims(operandAdaptor.getX(), inputDims);
@@ -78,11 +82,18 @@ LogicalResult ONNXGridSampleOp::verify() {
7882
if (!hasShapeAndRank(getOperation()))
7983
return success();
8084

81-
auto inputShape =
82-
mlir::cast<ShapedType>(operandAdaptor.getX().getType()).getShape();
83-
int64_t inputRank = inputShape.size();
84-
auto gridShape =
85-
mlir::cast<ShapedType>(operandAdaptor.getGrid().getType()).getShape();
85+
auto inputType = mlir::cast<ShapedType>(operandAdaptor.getX().getType());
86+
if (!inputType.hasStaticShape()) {
87+
return success();
88+
}
89+
const auto inputShape = inputType.getShape();
90+
const int64_t inputRank = inputShape.size();
91+
92+
auto gridType = mlir::cast<ShapedType>(operandAdaptor.getGrid().getType());
93+
if (!gridType.hasStaticShape()) {
94+
return success();
95+
}
96+
const auto gridShape = gridType.getShape();
8697

8798
// Check whether the ranks of input and grid are valid and are equal.
8899
// Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr)

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,22 @@ func.func @cast_concat_swap(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tenso
8585

8686
// -----
8787

88+
func.func @cast_concat_swap_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi64> {
89+
%0 = "onnx.Concat"(%arg0, %arg1) {axis = 0 : si64} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
90+
%1 = "onnx.Cast"(%0) {to = i64} : (tensor<*xi32>) -> tensor<*xi64>
91+
onnx.Return %1 : tensor<*xi64>
92+
93+
// CHECK-LABEL: func.func @cast_concat_swap_dynamic
94+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xi32>, [[PARAM_1_:%.+]]: tensor<*xi32>) -> tensor<*xi64> {
95+
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i64} : (tensor<*xi32>) -> tensor<*xi64>
96+
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = i64} : (tensor<*xi32>) -> tensor<*xi64>
97+
// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<*xi64>, tensor<*xi64>) -> tensor<*xi64>
98+
// CHECK: onnx.Return [[VAR_2_]] : tensor<*xi64>
99+
// CHECK: }
100+
}
101+
102+
// -----
103+
88104
func.func @cast_slice_swap(%arg0: tensor<3xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>, %arg4: tensor<1xi64>) -> tensor<1xi64> {
89105
%0 = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) {axis = 0 : si64} : (tensor<3xi32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32>
90106
%1 = "onnx.Cast"(%0) {to = i64} : (tensor<1xi32>) -> tensor<1xi64>

test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4270,6 +4270,19 @@ func.func @test_grid_sample_same_dims(%arg0: tensor<1x3x1152x1344xf32>, %arg1: t
42704270
// CHECK: }
42714271
}
42724272

4273+
4274+
func.func @test_grid_sample_one_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
4275+
%0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<*xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32>
4276+
return %0 : tensor<*xf32>
4277+
4278+
// COM: Check that we do not crash
4279+
// CHECK-LABEL: func.func @test_grid_sample_one_dynamic
4280+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
4281+
// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<*xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32>
4282+
// CHECK: return [[VAR_0_]] : tensor<*xf32>
4283+
// CHECK: }
4284+
}
4285+
42734286
func.func @test_grid_sample_diff_dims(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x2xf32>) -> tensor<*xf32> {
42744287
%0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<*xf32>
42754288
return %0 : tensor<*xf32>

0 commit comments

Comments
 (0)