Commit 2760083

Merge pull request #458 from Xilinx/liangta.ConstantOfShape_crash
Fix crash for ConstantOfShape decomposition
2 parents: 3c980df + 2aa5715
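The old ConstantOfShapePattern handed the optional value attribute of onnx.ConstantOfShape to ReshapeElementsAttrToRank0 without checking that it is present; judging from the fix below, that is where the crash came from. Based on the new lit test added in this commit (the function name here is illustrative), a minimal reproducer looks like:

  func.func @repro(%arg0: tensor<?xi64>) -> tensor<*xi32> {
    // No `value` attribute: the ONNX spec defaults the output to float32 zero.
    %0 = onnx.ConstantOfShape(%arg0) : (tensor<?xi64>) -> tensor<*xi32>
    return %0 : tensor<*xi32>
  }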

2 files changed: +26 -4 lines

src/Dialect/ONNX/Transforms/Decompose.td

Lines changed: 10 additions & 2 deletions
@@ -517,10 +517,18 @@ def ReduceSumSquareV13OpPattern2
 // `ONNXExpandOp(ONNXConstantOp {value}, %shape)
 //===----------------------------------------------------------------------===//

-def ConstantOfShapePattern: Pat<
+def ConstantOfShapePattern1: Pat<
   (ONNXConstantOfShapeOp:$res $shape, $value),
   (ONNXExpandOp (ONNXConstantOpFromDenseAttr (ReshapeElementsAttrToRank0 $value)),
-    $shape)
+    $shape),
+  [(AttributeIsNotNull:$value)]
+>;
+
+def ConstantOfShapePattern2: Pat<
+  (ONNXConstantOfShapeOp:$res $shape, $value),
+  (ONNXExpandOp (ONNXConstantOpFromDenseAttr (ReshapeElementsAttrToRank0 (createDenseArrayAttrFromSingleAttr (GetZeroFloatAttr)))),
+    $shape),
+  [(AttributeIsNull:$value)]
 >;

 //===----------------------------------------------------------------------===//
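The original ConstantOfShapePattern becomes ConstantOfShapePattern1, now guarded by AttributeIsNotNull so it only fires when the value attribute is set. The new ConstantOfShapePattern2 covers the null case by materializing the ONNX default, a scalar float32 zero, and expanding it to the requested shape. As checked by the new lit test below, the decomposition of a value-less ConstantOfShape is roughly (SSA names illustrative):

  %cst = onnx.Constant dense<0.000000e+00> : tensor<f32>
  %res = "onnx.Expand"(%cst, %shape) : (tensor<f32>, tensor<?xi64>) -> tensor<*xi32>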

test/mlir/onnx/onnx_decompose.mlir

Lines changed: 16 additions & 2 deletions
@@ -572,11 +572,11 @@ func.func @test_concatfuse_2(%arg0: tensor<?x20xf32>, %arg1: tensor<?x30xf32>) -

 // -----

-func.func @test_constantofshape(%arg0: tensor<?xi64>) -> tensor<*xi32> {
+func.func @test_constantofshape_1(%arg0: tensor<?xi64>) -> tensor<*xi32> {
   %0 = onnx.ConstantOfShape(%arg0) {value = dense<1> : tensor<1xi32>} : (tensor<?xi64>) -> tensor<*xi32>
   return %0 : tensor<*xi32>

-// CHECK-LABEL: func.func @test_constantofshape
+// CHECK-LABEL: func.func @test_constantofshape_1
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?xi64>) -> tensor<*xi32> {
 // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<i32>
 // CHECK: [[VAR_1_:%.+]] = "onnx.Expand"([[VAR_0_]], [[PARAM_0_]]) : (tensor<i32>, tensor<?xi64>) -> tensor<*xi32>
@@ -586,6 +586,20 @@ func.func @test_constantofshape(%arg0: tensor<?xi64>) -> tensor<*xi32> {

 // -----

+func.func @test_constantofshape_2(%arg0: tensor<?xi64>) -> tensor<*xi32> {
+  %0 = onnx.ConstantOfShape(%arg0) : (tensor<?xi64>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+
+// CHECK-LABEL: func.func @test_constantofshape_2
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?xi64>) -> tensor<*xi32> {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor<f32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Expand"([[VAR_0_]], [[PARAM_0_]]) : (tensor<f32>, tensor<?xi64>) -> tensor<*xi32>
+// CHECK: return [[VAR_1_]] : tensor<*xi32>
+// CHECK: }
+}
+
+// -----
+
 func.func @test_hardswish_f32(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
   %0 = "onnx.HardSwish"(%arg0) : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
