Skip to content

Commit d32dbe6

Browse files
authored
[luci] Add additional check of unknown dimension for Reshape (#14858)
This commit adds checking of number of remaining elements in order to confirm if it matches expected output shape. ONE-DCO-1.0-Signed-off-by: Mateusz Bencer <m.bencer@partner.samsung.com>
1 parent daa19ea commit d32dbe6

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

compiler/luci/service/src/Nodes/CircleReshape.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
210210
}
211211
if (unknown_dim_index != UINT32_MAX)
212212
{
213+
if (input_element_count % output_element_count != 0)
214+
{
215+
INTERNAL_EXN("Reshape Op cannot infer unknown dimension from inputs.");
216+
}
213217
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
214218
}
215219
}

compiler/luci/service/src/Nodes/CircleReshape.test.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,32 @@ TEST(ShapeRuleTest, reshape_should_infer)
135135
ASSERT_EQ(4, output_shape.dim(1).value());
136136
}
137137

138+
TEST(ShapeRuleTest, reshape_wrong_target_shape_NEG)
139+
{
140+
auto g = loco::make_graph();
141+
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
142+
auto tensor_input = g->nodes()->create<luci::CircleInput>();
143+
auto shape_by_input = g->nodes()->create<luci::CircleConst>();
144+
145+
tensor_input->dtype(loco::DataType::S32);
146+
tensor_input->shape({2, 4});
147+
tensor_input->shape_status(luci::ShapeStatus::VALID);
148+
149+
shape_by_input->dtype(loco::DataType::S32);
150+
shape_by_input->size<loco::DataType::S32>(3);
151+
shape_by_input->at<loco::DataType::S32>(0) = 6;
152+
shape_by_input->at<loco::DataType::S32>(2) = -1;
153+
shape_by_input->shape_status(luci::ShapeStatus::VALID);
154+
155+
node_reshape->tensor(tensor_input);
156+
node_reshape->shape(shape_by_input);
157+
158+
loco::TensorShape output_shape;
159+
luci::sinf::Rule shape_inf_rule;
160+
161+
ASSERT_THROW(shape_inf_rule.infer(node_reshape, output_shape), oops::InternalExn);
162+
}
163+
138164
TEST(ShapeRuleTest, reshape_by_input_node)
139165
{
140166
auto g = loco::make_graph();

0 commit comments

Comments
 (0)