diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 80229d4566f..944c0f28a99 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -1238,7 +1238,7 @@ loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node) int32_t dim = raw_dim < 0 ? raw_dim + input_shape.rank() : raw_dim; if (dim < 0 || static_cast(dim) >= input_shape.rank() || - input_shape.dim(dim).value() != 1) + (input_shape.dim(dim).known() && input_shape.dim(dim).value() != 1)) { INTERNAL_EXN("invalid dimention specified to Squeeze"); } diff --git a/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp b/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp index bc73eafa715..8de0cb7924c 100644 --- a/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp @@ -44,6 +44,23 @@ TEST(ShapeRuleTest, squeeze_simple) ASSERT_EQ(1, shape.dim(2).value()); } +TEST(ShapeRuleTest, squeeze_incorrect_dim_NEG) +{ + luci::CircleInput input; + luci::CircleSqueeze squeeze; + + input.shape({2, 4, 3, 1}); + input.shape_status(luci::ShapeStatus::VALID); + + squeeze.input(&input); + squeeze.squeeze_dims({0}); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_THROW(shape_inf_rule.infer(&squeeze, shape), oops::InternalExn); +} + TEST(ShapeRuleTest, squeeze_all) { luci::CircleInput input; @@ -64,6 +81,33 @@ TEST(ShapeRuleTest, squeeze_all) ASSERT_EQ(3, shape.dim(1).value()); } +TEST(ShapeRuleTest, squeeze_dyn_squeezed_dims) +{ + luci::CircleInput input; + luci::CircleSqueeze squeeze; + + input.rank(5); + input.dim(0) = loco::Dimension(1); + input.dim(1) = loco::Dimension(); + input.dim(2) = loco::Dimension(4); + input.dim(3) = loco::Dimension(); + input.dim(4) = loco::Dimension(1); + input.shape_status(luci::ShapeStatus::VALID); + + squeeze.input(&input); + squeeze.squeeze_dims({4}); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&squeeze, shape)); + ASSERT_EQ(4, shape.rank()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_FALSE(shape.dim(1).known()); + ASSERT_EQ(4, shape.dim(2).value()); + ASSERT_FALSE(shape.dim(3).known()); +} + TEST(CloneNodeTest, clone_Squeeze) { auto g = loco::make_graph();