@@ -135,6 +135,59 @@ TEST(ShapeRuleTest, reshape_should_infer)
135135 ASSERT_EQ (4 , output_shape.dim (1 ).value ());
136136}
137137
138+ TEST (ShapeRuleTest, reshape_should_infer_incorrect_zero_NEG)
139+ {
140+ auto g = loco::make_graph ();
141+ auto node_reshape = g->nodes ()->create <luci::CircleReshape>();
142+ auto tensor_input = g->nodes ()->create <luci::CircleInput>();
143+ auto shape_by_input = g->nodes ()->create <luci::CircleConst>();
144+
145+ tensor_input->dtype (loco::DataType::S32);
146+ tensor_input->shape ({2 , 4 });
147+ tensor_input->shape_status (luci::ShapeStatus::VALID);
148+
149+ shape_by_input->dtype (loco::DataType::S32);
150+ shape_by_input->size <loco::DataType::S32>(3 );
151+ shape_by_input->at <loco::DataType::S32>(0 ) = 2 ;
152+ shape_by_input->at <loco::DataType::S32>(1 ) = 2 ;
153+ shape_by_input->at <loco::DataType::S32>(2 ) = 0 ;
154+ shape_by_input->shape_status (luci::ShapeStatus::VALID);
155+
156+ node_reshape->tensor (tensor_input);
157+ node_reshape->shape (shape_by_input);
158+
159+ loco::TensorShape output_shape;
160+ luci::sinf::Rule shape_inf_rule;
161+
162+ ASSERT_THROW (shape_inf_rule.infer (node_reshape, output_shape), oops::InternalExn);
163+ }
164+
165+ TEST (ShapeRuleTest, reshape_should_infer_incorrect_target_shape_NEG)
166+ {
167+ auto g = loco::make_graph ();
168+ auto node_reshape = g->nodes ()->create <luci::CircleReshape>();
169+ auto tensor_input = g->nodes ()->create <luci::CircleInput>();
170+ auto shape_by_input = g->nodes ()->create <luci::CircleConst>();
171+
172+ tensor_input->dtype (loco::DataType::S32);
173+ tensor_input->shape ({2 , 4 });
174+ tensor_input->shape_status (luci::ShapeStatus::VALID);
175+
176+ shape_by_input->dtype (loco::DataType::S32);
177+ shape_by_input->size <loco::DataType::S32>(3 );
178+ shape_by_input->at <loco::DataType::S32>(0 ) = 6 ;
179+ shape_by_input->at <loco::DataType::S32>(2 ) = -1 ;
180+ shape_by_input->shape_status (luci::ShapeStatus::VALID);
181+
182+ node_reshape->tensor (tensor_input);
183+ node_reshape->shape (shape_by_input);
184+
185+ loco::TensorShape output_shape;
186+ luci::sinf::Rule shape_inf_rule;
187+
188+ ASSERT_THROW (shape_inf_rule.infer (node_reshape, output_shape), oops::InternalExn);
189+ }
190+
138191TEST (ShapeRuleTest, reshape_by_input_node)
139192{
140193 auto g = loco::make_graph ();
@@ -197,3 +250,79 @@ TEST(ShapeRuleTest, reshape_by_newShape)
197250 ASSERT_EQ (2 , output_shape.dim (0 ).value ());
198251 ASSERT_EQ (12 , output_shape.dim (1 ).value ());
199252}
253+
254+ TEST (ShapeRuleTest, reshape_by_newShape_dynamic)
255+ {
256+ auto g = loco::make_graph ();
257+ auto node_reshape = g->nodes ()->create <luci::CircleReshape>();
258+ auto tensor_input = g->nodes ()->create <luci::CircleInput>();
259+ auto shape_dummy = g->nodes ()->create <luci::CircleOutputDummy>();
260+
261+ tensor_input->dtype (loco::DataType::S32);
262+ tensor_input->shape ({2 , 3 , 4 });
263+ tensor_input->shape_status (luci::ShapeStatus::VALID);
264+
265+ shape_dummy->dtype (loco::DataType::S32);
266+ shape_dummy->shape ({});
267+ shape_dummy->shape_status (luci::ShapeStatus::VALID);
268+
269+ node_reshape->tensor (tensor_input);
270+ node_reshape->shape (shape_dummy);
271+
272+ // reshape to {-1, 12}
273+ node_reshape->newShape ()->rank (2 );
274+ node_reshape->newShape ()->dim (0 ) = -1 ;
275+ node_reshape->newShape ()->dim (1 ) = 12 ;
276+
277+ loco::TensorShape output_shape;
278+ luci::sinf::Rule shape_inf_rule;
279+
280+ ASSERT_TRUE (shape_inf_rule.infer (node_reshape, output_shape));
281+
282+ ASSERT_EQ (2 , output_shape.rank ());
283+ ASSERT_FALSE (output_shape.dim (0 ).known ());
284+ ASSERT_TRUE (output_shape.dim (1 ).known ());
285+ ASSERT_EQ (12 , output_shape.dim (1 ).value ());
286+ }
287+
288+ TEST (ShapeRuleTest, merge_shape_from_newShape_and_input_node)
289+ {
290+ auto g = loco::make_graph ();
291+ auto node_reshape = g->nodes ()->create <luci::CircleReshape>();
292+ auto tensor_input = g->nodes ()->create <luci::CircleInput>();
293+ auto shape_by_input = g->nodes ()->create <luci::CircleConst>();
294+
295+ node_reshape->tensor (tensor_input);
296+
297+ tensor_input->dtype (loco::DataType::S32);
298+ tensor_input->shape ({2 , 3 , 4 , 5 });
299+ tensor_input->shape_status (luci::ShapeStatus::VALID);
300+
301+ shape_by_input->dtype (loco::DataType::S32);
302+ shape_by_input->size <loco::DataType::S32>(3 );
303+ shape_by_input->at <loco::DataType::S32>(0 ) = 2 ;
304+ shape_by_input->at <loco::DataType::S32>(1 ) = -1 ;
305+ shape_by_input->at <loco::DataType::S32>(2 ) = -1 ;
306+ shape_by_input->shape_status (luci::ShapeStatus::VALID);
307+
308+ node_reshape->tensor (tensor_input);
309+ node_reshape->shape (shape_by_input);
310+
311+ node_reshape->newShape ()->rank (3 );
312+ node_reshape->newShape ()->dim (0 ) = -1 ; // unknow here but pass by shape_by_input
313+ node_reshape->newShape ()->dim (1 ) = 12 ;
314+ node_reshape->newShape ()->dim (2 ) = 5 ;
315+
316+ loco::TensorShape output_shape;
317+ luci::sinf::Rule shape_inf_rule;
318+
319+ ASSERT_TRUE (shape_inf_rule.infer (node_reshape, output_shape));
320+
321+ ASSERT_EQ (3 , output_shape.rank ());
322+ ASSERT_TRUE (output_shape.dim (0 ).known ());
323+ EXPECT_EQ (2 , output_shape.dim (0 ).value ());
324+ ASSERT_TRUE (output_shape.dim (1 ).known ());
325+ EXPECT_EQ (12 , output_shape.dim (1 ).value ());
326+ ASSERT_TRUE (output_shape.dim (2 ).known ());
327+ EXPECT_EQ (5 , output_shape.dim (2 ).value ());
328+ }
0 commit comments