Skip to content

Commit c1334a5

Browse files
committed
Merge remote-tracking branch 'upstream/master' into mbencer/ReshapeAdditionalUnitTest
2 parents 02a9832 + d32dbe6 commit c1334a5

File tree

4 files changed

+64
-1
lines changed

4 files changed

+64
-1
lines changed

compiler/luci-interpreter/src/kernels/Gather.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,14 @@ void Gather::configure()
5555
// refer tensorflow/lite/kernels/gather.cc
5656

5757
const Shape &params_shape = params()->shape();
58-
const Shape &indices_shape = indices()->shape();
58+
Shape indices_shape = indices()->shape();
59+
{
60+
// scalar index is treated as a tensor with the shape of [1]
61+
if (indices_shape.num_dims() == 0)
62+
{
63+
indices_shape = Shape({1});
64+
}
65+
}
5966

6067
int axis = _params.axis;
6168
if (axis < 0)

compiler/luci-interpreter/src/kernels/Gather.test.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,32 @@ TEST_F(GatherTest, Simple)
6262
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 4}));
6363
}
6464

65+
TEST_F(GatherTest, Scalar_Index)
66+
{
67+
std::vector<float> params_data{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
68+
std::vector<int32_t> indices_data{1};
69+
std::vector<float> ref_output_data{2.f};
70+
71+
Tensor params_tensor =
72+
makeInputTensor<DataType::FLOAT32>({1, 1, 6}, params_data, _memory_manager.get());
73+
Tensor indices_tensor =
74+
makeInputTensor<DataType::S32>(/* scalar */ {}, indices_data, _memory_manager.get());
75+
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
76+
GatherParams gparams;
77+
78+
gparams.axis = 2;
79+
gparams.batch_dims = 0;
80+
81+
Gather kernel(&params_tensor, &indices_tensor, &output_tensor, gparams);
82+
kernel.configure();
83+
_memory_manager->allocate_memory(output_tensor);
84+
kernel.execute();
85+
86+
EXPECT_THAT(extractTensorData<float>(output_tensor),
87+
::testing::ElementsAreArray(ref_output_data));
88+
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 1, 1}));
89+
}
90+
6591
TEST_F(GatherTest, Simple_Batch)
6692
{
6793
Shape params_shape = {3, 5};

compiler/luci/service/src/Nodes/CircleReshape.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
210210
}
211211
if (unknown_dim_index != UINT32_MAX)
212212
{
213+
if (input_element_count % output_element_count != 0)
214+
{
215+
INTERNAL_EXN("Reshape Op cannot infer unknown dimension from inputs.");
216+
}
213217
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
214218
}
215219
}

compiler/luci/service/src/Nodes/CircleReshape.test.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,32 @@ TEST(ShapeRuleTest, reshape_zero_rank_mismatch_NEG)
162162
ASSERT_THROW(shape_inf_rule.infer(node_reshape, output_shape), oops::InternalExn);
163163
}
164164

165+
TEST(ShapeRuleTest, reshape_wrong_target_shape_NEG)
166+
{
167+
auto g = loco::make_graph();
168+
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
169+
auto tensor_input = g->nodes()->create<luci::CircleInput>();
170+
auto shape_by_input = g->nodes()->create<luci::CircleConst>();
171+
172+
tensor_input->dtype(loco::DataType::S32);
173+
tensor_input->shape({2, 4});
174+
tensor_input->shape_status(luci::ShapeStatus::VALID);
175+
176+
shape_by_input->dtype(loco::DataType::S32);
177+
shape_by_input->size<loco::DataType::S32>(3);
178+
shape_by_input->at<loco::DataType::S32>(0) = 6;
179+
shape_by_input->at<loco::DataType::S32>(2) = -1;
180+
shape_by_input->shape_status(luci::ShapeStatus::VALID);
181+
182+
node_reshape->tensor(tensor_input);
183+
node_reshape->shape(shape_by_input);
184+
185+
loco::TensorShape output_shape;
186+
luci::sinf::Rule shape_inf_rule;
187+
188+
ASSERT_THROW(shape_inf_rule.infer(node_reshape, output_shape), oops::InternalExn);
189+
}
190+
165191
TEST(ShapeRuleTest, reshape_by_input_node)
166192
{
167193
auto g = loco::make_graph();

0 commit comments

Comments
 (0)