Skip to content

Commit d57f2e7

Browse files
committed
[luci] Merge target shape from input node and attribute
This commit improve mechanism of shape inference for Reshape operator. If some dimension from input node is unknown we are trying to find such information in attribute. ONE-DCO-1.0-Signed-off-by: Mateusz Bencer <m.bencer@partner.samsung.com>
1 parent 0dfc2a5 commit d57f2e7

File tree

2 files changed

+163
-2
lines changed

2 files changed

+163
-2
lines changed

compiler/luci/service/src/Nodes/CircleReshape.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,27 @@ luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReshape *node)
6666
namespace sinf
6767
{
6868

69+
namespace
70+
{
71+
loco::TensorShape merge_shapes(const loco::TensorShape &base_shape,
72+
const loco::TensorShape &merged_shape)
73+
{
74+
loco::TensorShape result_shape = base_shape;
75+
if (base_shape.rank() == merged_shape.rank())
76+
{
77+
for (uint32_t axis = 0; axis < base_shape.rank(); ++axis)
78+
{
79+
if (!base_shape.dim(axis).known() && merged_shape.dim(axis).known())
80+
{
81+
result_shape.dim(axis) = merged_shape.dim(axis);
82+
}
83+
}
84+
}
85+
return result_shape;
86+
}
87+
88+
} // namespace
89+
6990
loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
7091
{
7192
LOGGER(l);
@@ -154,7 +175,14 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
154175

155176
for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
156177
{
157-
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
178+
if (node->newShape()->dim(axis) > 0)
179+
{
180+
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
181+
}
182+
else
183+
{
184+
shape_by_attr.dim(axis).unset(); // unset means unknown dimension
185+
}
158186
}
159187
}
160188

@@ -165,7 +193,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
165193
INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl;
166194
}
167195

168-
loco::TensorShape output_shape = shape_by_input;
196+
loco::TensorShape output_shape = merge_shapes(shape_by_input, shape_by_attr);
169197

170198
// One of the dimensions can have special value -1, meaning its actual value should be inferred.
171199
const auto input = loco::must_cast<luci::CircleNode *>(node->tensor());
@@ -210,6 +238,10 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
210238
}
211239
if (unknown_dim_index != UINT32_MAX)
212240
{
241+
if (input_element_count % output_element_count != 0)
242+
{
243+
INTERNAL_EXN("Unknown output dimension cannot be calculated for inputs");
244+
}
213245
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
214246
}
215247
}

compiler/luci/service/src/Nodes/CircleReshape.test.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,59 @@ TEST(ShapeRuleTest, reshape_should_infer)
135135
ASSERT_EQ(4, output_shape.dim(1).value());
136136
}
137137

138+
TEST(ShapeRuleTest, reshape_should_infer_incorrect_zero_NEG)
139+
{
140+
auto g = loco::make_graph();
141+
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
142+
auto tensor_input = g->nodes()->create<luci::CircleInput>();
143+
auto shape_by_input = g->nodes()->create<luci::CircleConst>();
144+
145+
tensor_input->dtype(loco::DataType::S32);
146+
tensor_input->shape({2, 4});
147+
tensor_input->shape_status(luci::ShapeStatus::VALID);
148+
149+
shape_by_input->dtype(loco::DataType::S32);
150+
shape_by_input->size<loco::DataType::S32>(3);
151+
shape_by_input->at<loco::DataType::S32>(0) = 2;
152+
shape_by_input->at<loco::DataType::S32>(1) = 2;
153+
shape_by_input->at<loco::DataType::S32>(2) = 0;
154+
shape_by_input->shape_status(luci::ShapeStatus::VALID);
155+
156+
node_reshape->tensor(tensor_input);
157+
node_reshape->shape(shape_by_input);
158+
159+
loco::TensorShape output_shape;
160+
luci::sinf::Rule shape_inf_rule;
161+
162+
ASSERT_THROW(shape_inf_rule.infer(node_reshape, output_shape), oops::InternalExn);
163+
}
164+
165+
TEST(ShapeRuleTest, reshape_should_infer_incorrect_target_shape_NEG)
166+
{
167+
auto g = loco::make_graph();
168+
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
169+
auto tensor_input = g->nodes()->create<luci::CircleInput>();
170+
auto shape_by_input = g->nodes()->create<luci::CircleConst>();
171+
172+
tensor_input->dtype(loco::DataType::S32);
173+
tensor_input->shape({2, 4});
174+
tensor_input->shape_status(luci::ShapeStatus::VALID);
175+
176+
shape_by_input->dtype(loco::DataType::S32);
177+
shape_by_input->size<loco::DataType::S32>(3);
178+
shape_by_input->at<loco::DataType::S32>(0) = 6;
179+
shape_by_input->at<loco::DataType::S32>(2) = -1;
180+
shape_by_input->shape_status(luci::ShapeStatus::VALID);
181+
182+
node_reshape->tensor(tensor_input);
183+
node_reshape->shape(shape_by_input);
184+
185+
loco::TensorShape output_shape;
186+
luci::sinf::Rule shape_inf_rule;
187+
188+
ASSERT_THROW(shape_inf_rule.infer(node_reshape, output_shape), oops::InternalExn);
189+
}
190+
138191
TEST(ShapeRuleTest, reshape_by_input_node)
139192
{
140193
auto g = loco::make_graph();
@@ -197,3 +250,79 @@ TEST(ShapeRuleTest, reshape_by_newShape)
197250
ASSERT_EQ(2, output_shape.dim(0).value());
198251
ASSERT_EQ(12, output_shape.dim(1).value());
199252
}
253+
254+
TEST(ShapeRuleTest, reshape_by_newShape_dynamic)
255+
{
256+
auto g = loco::make_graph();
257+
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
258+
auto tensor_input = g->nodes()->create<luci::CircleInput>();
259+
auto shape_dummy = g->nodes()->create<luci::CircleOutputDummy>();
260+
261+
tensor_input->dtype(loco::DataType::S32);
262+
tensor_input->shape({2, 3, 4});
263+
tensor_input->shape_status(luci::ShapeStatus::VALID);
264+
265+
shape_dummy->dtype(loco::DataType::S32);
266+
shape_dummy->shape({});
267+
shape_dummy->shape_status(luci::ShapeStatus::VALID);
268+
269+
node_reshape->tensor(tensor_input);
270+
node_reshape->shape(shape_dummy);
271+
272+
// reshape to {-1, 12}
273+
node_reshape->newShape()->rank(2);
274+
node_reshape->newShape()->dim(0) = -1;
275+
node_reshape->newShape()->dim(1) = 12;
276+
277+
loco::TensorShape output_shape;
278+
luci::sinf::Rule shape_inf_rule;
279+
280+
ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));
281+
282+
ASSERT_EQ(2, output_shape.rank());
283+
ASSERT_FALSE(output_shape.dim(0).known());
284+
ASSERT_TRUE(output_shape.dim(1).known());
285+
ASSERT_EQ(12, output_shape.dim(1).value());
286+
}
287+
288+
TEST(ShapeRuleTest, merge_shape_from_newShape_and_input_node)
289+
{
290+
auto g = loco::make_graph();
291+
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
292+
auto tensor_input = g->nodes()->create<luci::CircleInput>();
293+
auto shape_by_input = g->nodes()->create<luci::CircleConst>();
294+
295+
node_reshape->tensor(tensor_input);
296+
297+
tensor_input->dtype(loco::DataType::S32);
298+
tensor_input->shape({2, 3, 4, 5});
299+
tensor_input->shape_status(luci::ShapeStatus::VALID);
300+
301+
shape_by_input->dtype(loco::DataType::S32);
302+
shape_by_input->size<loco::DataType::S32>(3);
303+
shape_by_input->at<loco::DataType::S32>(0) = 2;
304+
shape_by_input->at<loco::DataType::S32>(1) = -1;
305+
shape_by_input->at<loco::DataType::S32>(2) = -1;
306+
shape_by_input->shape_status(luci::ShapeStatus::VALID);
307+
308+
node_reshape->tensor(tensor_input);
309+
node_reshape->shape(shape_by_input);
310+
311+
node_reshape->newShape()->rank(3);
312+
node_reshape->newShape()->dim(0) = -1; // unknow here but pass by shape_by_input
313+
node_reshape->newShape()->dim(1) = 12;
314+
node_reshape->newShape()->dim(2) = 5;
315+
316+
loco::TensorShape output_shape;
317+
luci::sinf::Rule shape_inf_rule;
318+
319+
ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));
320+
321+
ASSERT_EQ(3, output_shape.rank());
322+
ASSERT_TRUE(output_shape.dim(0).known());
323+
EXPECT_EQ(2, output_shape.dim(0).value());
324+
ASSERT_TRUE(output_shape.dim(1).known());
325+
EXPECT_EQ(12, output_shape.dim(1).value());
326+
ASSERT_TRUE(output_shape.dim(2).known());
327+
EXPECT_EQ(5, output_shape.dim(2).value());
328+
}

0 commit comments

Comments
 (0)