Skip to content

Commit 9350461

Browse files
authored
[luci] Handle dynamic dimensions in Squeeze shape inference (#14857)
This commit skip verification of squeezed dimension if such dimension is dynamic. ONE-DCO-1.0-Signed-off-by: Mateusz Bencer <m.bencer@partner.samsung.com>
1 parent 0d55156 commit 9350461

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

compiler/luci/service/src/CircleShapeInferenceRule.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
12381238
int32_t dim = raw_dim < 0 ? raw_dim + input_shape.rank() : raw_dim;
12391239

12401240
if (dim < 0 || static_cast<uint32_t>(dim) >= input_shape.rank() ||
1241-
input_shape.dim(dim).value() != 1)
1241+
(input_shape.dim(dim).known() && input_shape.dim(dim).value() != 1))
12421242
{
12431243
INTERNAL_EXN("invalid dimention specified to Squeeze");
12441244
}

compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ TEST(ShapeRuleTest, squeeze_simple)
4444
ASSERT_EQ(1, shape.dim(2).value());
4545
}
4646

47+
TEST(ShapeRuleTest, squeeze_incorrect_dim_NEG)
48+
{
49+
luci::CircleInput input;
50+
luci::CircleSqueeze squeeze;
51+
52+
input.shape({2, 4, 3, 1});
53+
input.shape_status(luci::ShapeStatus::VALID);
54+
55+
squeeze.input(&input);
56+
squeeze.squeeze_dims({0});
57+
58+
loco::TensorShape shape;
59+
luci::sinf::Rule shape_inf_rule;
60+
61+
ASSERT_THROW(shape_inf_rule.infer(&squeeze, shape), oops::InternalExn);
62+
}
63+
4764
TEST(ShapeRuleTest, squeeze_all)
4865
{
4966
luci::CircleInput input;
@@ -64,6 +81,33 @@ TEST(ShapeRuleTest, squeeze_all)
6481
ASSERT_EQ(3, shape.dim(1).value());
6582
}
6683

84+
TEST(ShapeRuleTest, squeeze_dyn_squeezed_dims)
85+
{
86+
luci::CircleInput input;
87+
luci::CircleSqueeze squeeze;
88+
89+
input.rank(5);
90+
input.dim(0) = loco::Dimension(1);
91+
input.dim(1) = loco::Dimension();
92+
input.dim(2) = loco::Dimension(4);
93+
input.dim(3) = loco::Dimension();
94+
input.dim(4) = loco::Dimension(1);
95+
input.shape_status(luci::ShapeStatus::VALID);
96+
97+
squeeze.input(&input);
98+
squeeze.squeeze_dims({4});
99+
100+
loco::TensorShape shape;
101+
luci::sinf::Rule shape_inf_rule;
102+
103+
ASSERT_TRUE(shape_inf_rule.infer(&squeeze, shape));
104+
ASSERT_EQ(4, shape.rank());
105+
ASSERT_EQ(1, shape.dim(0).value());
106+
ASSERT_FALSE(shape.dim(1).known());
107+
ASSERT_EQ(4, shape.dim(2).value());
108+
ASSERT_FALSE(shape.dim(3).known());
109+
}
110+
67111
TEST(CloneNodeTest, clone_Squeeze)
68112
{
69113
auto g = loco::make_graph();

0 commit comments

Comments
 (0)