Skip to content

Commit e47a8cd

Browse files
committed
Create Cast operations with the correct result type, instead of relying on shape inference
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 3b12e41 commit e47a8cd

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,8 @@ SmallVector<Value, 4> castVariadicInput(PatternRewriter &rewriter, Location loc,
126126
SmallVector<Value, 4> castInputs;
127127
for (Value inp : inputs) {
128128
ShapedType inpType = mlir::cast<ShapedType>(inp.getType());
129-
assert(inpType && "Type is not ShapedType");
130-
ONNXCastOp castOp = rewriter.create<ONNXCastOp>(loc,
131-
UnrankedTensorType::get(inpType.getElementType()), inp, saturate, to);
132-
static_cast<void>(castOp.inferShapes([](Region &region) {}));
129+
ONNXCastOp castOp = rewriter.create<ONNXCastOp>(
130+
loc, inpType.clone(to.getValue()), inp, saturate, to);
133131
castInputs.emplace_back(castOp.getResult());
134132
}
135133
return castInputs;

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,22 @@ func.func @cast_concat_swap(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tenso
8585

8686
// -----
8787

88+
func.func @cast_concat_swap_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi64> {
89+
%0 = "onnx.Concat"(%arg0, %arg1) {axis = 0 : si64} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
90+
%1 = "onnx.Cast"(%0) {to = i64} : (tensor<*xi32>) -> tensor<*xi64>
91+
onnx.Return %1 : tensor<*xi64>
92+
93+
// CHECK-LABEL: func.func @cast_concat_swap_dynamic
94+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xi32>, [[PARAM_1_:%.+]]: tensor<*xi32>) -> tensor<*xi64> {
95+
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i64} : (tensor<*xi32>) -> tensor<*xi64>
96+
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = i64} : (tensor<*xi32>) -> tensor<*xi64>
97+
// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<*xi64>, tensor<*xi64>) -> tensor<*xi64>
98+
// CHECK: onnx.Return [[VAR_2_]] : tensor<*xi64>
99+
// CHECK: }
100+
}
101+
102+
// -----
103+
88104
func.func @cast_slice_swap(%arg0: tensor<3xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>, %arg4: tensor<1xi64>) -> tensor<1xi64> {
89105
%0 = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) {axis = 0 : si64} : (tensor<3xi32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32>
90106
%1 = "onnx.Cast"(%0) {to = i64} : (tensor<1xi32>) -> tensor<1xi64>

0 commit comments

Comments
 (0)