Skip to content

Commit 6781e30

Browse files
authored
Sync from tflite-micro. (tensorflow#161)
1 parent 9ec2773 commit 6781e30

File tree

5 files changed

+28
-15
lines changed

5 files changed

+28
-15
lines changed

src/tensorflow/lite/kernels/internal/reference/reduce.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,22 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point,
468468
return true;
469469
}
470470

471+
template <typename T, typename U>
472+
inline bool QuantizedMeanOrSumExtraArgs(
473+
const T* input_data, int32_t input_zero_point, float input_scale,
474+
const int* input_dims, const int input_num_dims, T* output_data,
475+
float output_scale, int32_t output_multiplier, int output_shift,
476+
int32_t output_zero_point, const int* output_dims,
477+
const int output_num_dims, const int* axis, const int num_axis_dimensions,
478+
bool keep_dims, int* temp_index, int* resolved_axis, U* temp_sum,
479+
bool compute_sum) {
480+
return QuantizedMeanOrSum<T, U>(
481+
input_data, input_zero_point, input_scale, input_dims, input_num_dims,
482+
output_data, output_zero_point, output_scale, output_dims,
483+
output_num_dims, axis, num_axis_dimensions, keep_dims, temp_index,
484+
resolved_axis, temp_sum, compute_sum);
485+
}
486+
471487
template <typename T>
472488
inline bool QuantizedReduceProd(const T* input_data, int32_t input_zero_point,
473489
const RuntimeShape& input_shape, T* output_data,

src/tensorflow/lite/micro/kernels/micro_ops.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -67,6 +67,7 @@ TfLiteRegistration Register_LOGICAL_OR();
6767
TfLiteRegistration Register_LOGISTIC();
6868
TfLiteRegistration Register_MAX_POOL_2D();
6969
TfLiteRegistration Register_MIRROR_PAD();
70+
TfLiteRegistration Register_NEG();
7071
TfLiteRegistration Register_PRELU();
7172
TfLiteRegistration Register_MUL();
7273
TfLiteRegistration Register_PAD();
@@ -111,7 +112,6 @@ TfLiteRegistration Register_LOG();
111112
TfLiteRegistration Register_LOGICAL_NOT();
112113
TfLiteRegistration Register_MAXIMUM();
113114
TfLiteRegistration Register_MINIMUM();
114-
TfLiteRegistration Register_NEG();
115115
TfLiteRegistration Register_NOT_EQUAL();
116116
TfLiteRegistration Register_PACK();
117117
TfLiteRegistration Register_RESHAPE();

src/tensorflow/lite/micro/kernels/neg.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -21,9 +21,8 @@ limitations under the License.
2121
#include "tensorflow/lite/micro/micro_log.h"
2222

2323
namespace tflite {
24-
namespace ops {
25-
namespace micro {
26-
namespace neg {
24+
25+
namespace {
2726

2827
constexpr int kInputTensor = 0;
2928
constexpr int kOutputTensor = 0;
@@ -49,12 +48,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
4948
return kTfLiteOk;
5049
}
5150

52-
} // namespace neg
51+
} // namespace
5352

5453
TfLiteRegistration Register_NEG() {
55-
return tflite::micro::RegisterOp(nullptr, nullptr, neg::Eval);
54+
return tflite::micro::RegisterOp(nullptr, nullptr, Eval);
5655
}
5756

58-
} // namespace micro
59-
} // namespace ops
6057
} // namespace tflite

src/tensorflow/lite/micro/kernels/reduce_common.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,12 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
147147
TfLiteReducerParams* params =
148148
static_cast<TfLiteReducerParams*>(node->builtin_data);
149149

150-
bool result = reference_ops::QuantizedMeanOrSum<T, int32_t>(
150+
bool result = reference_ops::QuantizedMeanOrSumExtraArgs<T, int32_t>(
151151
tflite::micro::GetTensorData<T>(input), op_data->input_zp,
152152
op_data->input_scale, &input->dims->data[0], input->dims->size,
153-
tflite::micro::GetTensorData<T>(output), op_data->output_zp,
154-
op_data->output_scale, &output->dims->data[0], output->dims->size,
153+
tflite::micro::GetTensorData<T>(output), op_data->output_scale,
154+
op_data->multiplier, op_data->shift, op_data->output_zp,
155+
&output->dims->data[0], output->dims->size,
155156
tflite::micro::GetTensorData<int>(axis), op_data->num_axis,
156157
params->keep_dims, temp_index, resolved_axis, temp_sum, compute_sum);
157158
TF_LITE_ENSURE(context, result);

src/tensorflow/lite/micro/micro_mutable_op_resolver.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,7 @@ class MicroMutableOpResolver : public MicroOpResolver {
390390
}
391391

392392
TfLiteStatus AddNeg() {
393-
return AddBuiltin(BuiltinOperator_NEG, tflite::ops::micro::Register_NEG(),
394-
ParseNeg);
393+
return AddBuiltin(BuiltinOperator_NEG, Register_NEG(), ParseNeg);
395394
}
396395

397396
TfLiteStatus AddNotEqual() {

0 commit comments

Comments
 (0)