@@ -268,11 +268,11 @@ inline bool Mean(const T* input_data, const int* input_dims,
268268 return true ;
269269}
270270
271- template <typename T>
272271inline void Mean (const tflite::MeanParams& op_params,
273272 const RuntimeShape& unextended_input_shape,
274- const T* input_data,
275- const RuntimeShape& unextended_output_shape, T* output_data) {
273+ const float * input_data,
274+ const RuntimeShape& unextended_output_shape,
275+ float * output_data) {
276276 ruy::profiler::ScopeLabel label (" Mean4D" );
277277
278278 // Current implementation only supports dimension equals 4 and simultaneous
@@ -312,78 +312,21 @@ inline void Mean(const tflite::MeanParams& op_params,
312312 }
313313}
314314
315- inline void Mean (const tflite::MeanParams& op_params,
316- const RuntimeShape& unextended_input_shape,
317- const uint8_t * input_data, int32_t input_zero_point,
318- float input_scale, const RuntimeShape& unextended_output_shape,
319- uint8_t * output_data, int32_t output_zero_point,
320- float output_scale) {
321- ruy::profiler::ScopeLabel label (" Mean4D/Uint8" );
322-
323- // Current implementation only supports dimension equals 4 and simultaneous
324- // reduction over width and height.
325- TFLITE_CHECK_EQ (unextended_input_shape.DimensionsCount (), 4 );
326- TFLITE_CHECK_LE (unextended_output_shape.DimensionsCount (), 4 );
327- const RuntimeShape input_shape =
328- RuntimeShape::ExtendedShape (4 , unextended_input_shape);
329- const RuntimeShape output_shape =
330- RuntimeShape::ExtendedShape (4 , unextended_output_shape);
331- const int output_batch = output_shape.Dims (0 );
332- const int output_height = output_shape.Dims (1 );
333- const int output_width = output_shape.Dims (2 );
334- const int output_depth = output_shape.Dims (3 );
335- const int input_height = input_shape.Dims (1 );
336- const int input_width = input_shape.Dims (2 );
337- const float num_elements_in_axis = input_width * input_height;
338-
339- TFLITE_CHECK_EQ (op_params.axis_count , 2 );
340- TFLITE_CHECK ((op_params.axis [0 ] == 1 && op_params.axis [1 ] == 2 ) ||
341- (op_params.axis [0 ] == 2 && op_params.axis [1 ] == 1 ));
342- TFLITE_CHECK_EQ (output_height, 1 );
343- TFLITE_CHECK_EQ (output_width, 1 );
344-
345- constexpr int32_t kMinValue = std::numeric_limits<uint8_t >::min ();
346- constexpr int32_t kMaxValue = std::numeric_limits<uint8_t >::max ();
347-
348- float temp = input_zero_point * input_scale / output_scale;
349- temp = temp > 0 ? temp + 0 .5f : temp - 0 .5f ;
350- int32_t bias = output_zero_point - static_cast <int32_t >(temp);
351- double real_scale =
352- static_cast <double >(input_scale / (num_elements_in_axis * output_scale));
353-
354- int32_t multiplier;
355- int shift;
356- QuantizeMultiplier (real_scale, &multiplier, &shift);
357- for (int out_b = 0 ; out_b < output_batch; ++out_b) {
358- for (int out_d = 0 ; out_d < output_depth; ++out_d) {
359- int32_t acc = 0 ;
360- for (int in_h = 0 ; in_h < input_height; ++in_h) {
361- for (int in_w = 0 ; in_w < input_width; ++in_w) {
362- acc += input_data[Offset (input_shape, out_b, in_h, in_w, out_d)];
363- }
364- }
365- acc = MultiplyByQuantizedMultiplier (acc, multiplier, shift);
366- acc += bias;
367- acc = std::min (std::max (acc, kMinValue ), kMaxValue );
368- output_data[Offset (output_shape, out_b, 0 , 0 , out_d)] =
369- static_cast <uint8_t >(acc);
370- }
371- }
372- }
373-
374315// Computes the mean of elements across dimensions given in axis.
375316// It does so in two stages, first calculates the sum of elements along the axis
376317// then divides it by the number of element in axis for quantized values.
377318template <typename T, typename U>
378319inline bool QuantizedMeanOrSum (const T* input_data, int32_t input_zero_point,
379- float input_scale , const int * input_dims ,
380- const int input_num_dims, T* output_data,
381- int32_t output_zero_point, float output_scale ,
320+ const int * input_dims , const int input_num_dims ,
321+ T* output_data, int32_t output_multiplier ,
322+ int output_shift, int32_t output_zero_point ,
382323 const int * output_dims,
383324 const int output_num_dims, const int * axis,
384325 const int num_axis_dimensions, bool keep_dims,
385326 int * temp_index, int * resolved_axis, U* temp_sum,
386327 bool compute_sum) {
328+ const int32_t kMinValue = std::numeric_limits<T>::min ();
329+ const int32_t kMaxValue = std::numeric_limits<T>::max ();
387330 const bool uint8_case = std::is_same<T, uint8_t >::value;
388331 const bool int16_case = std::is_same<T, int16_t >::value;
389332 if (uint8_case) {
@@ -430,40 +373,46 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point,
430373 }
431374
432375 // Calculate mean by dividing output_data by num of aggregated element.
433- size_t num_elements_in_axis = 1 ;
376+ int64_t num_elements_in_axis = 1 ;
434377 for (int idx = 0 ; idx < num_resolved_axis; ++idx) {
435378 size_t current = static_cast <size_t >(input_dims[resolved_axis[idx]]);
436379 // Overflow prevention.
437- if (current > (std::numeric_limits<size_t >::max () / num_elements_in_axis)) {
380+ if (current > static_cast <size_t >(std::numeric_limits<int64_t >::max () /
381+ num_elements_in_axis)) {
438382 return false ;
439383 }
440384 num_elements_in_axis *= current;
441385 }
442386
443- if (num_elements_in_axis > 0 ) {
444- const float scale = input_scale / output_scale;
445- if (compute_sum) {
446- // TODO(b/116341117): Eliminate float and do this completely in 8bit.
447- const float bias = -input_zero_point * scale * num_elements_in_axis;
448- for (size_t idx = 0 ; idx < num_outputs; ++idx) {
449- const U value =
450- static_cast <U>(TfLiteRound (temp_sum[idx] * scale + bias)) +
451- output_zero_point;
452- output_data[idx] = static_cast <T>(value);
453- }
454- } else {
455- const float bias = -input_zero_point * scale;
456- for (size_t idx = 0 ; idx < num_outputs; ++idx) {
457- float float_mean = static_cast <float >(temp_sum[idx]) /
458- static_cast <float >(num_elements_in_axis);
459- float result = TfLiteMin (
460- TfLiteRound (float_mean * scale + bias) + output_zero_point,
461- static_cast <float >(std::numeric_limits<T>::max ()));
462- result = TfLiteMax (result,
463- static_cast <float >(std::numeric_limits<T>::min ()));
464- output_data[idx] = static_cast <T>(result);
465- }
466- }
387+ if (num_elements_in_axis == 0 ) {
388+ return true ;
389+ }
390+
391+ // Readapt output rescaling when calculating the mean to integrate a
392+ // 1/num_elements_in_axis multiplier.
393+ if (!compute_sum) {
394+ TFLITE_DCHECK_GE (num_elements_in_axis, 0 );
395+ int shift =
396+ 63 - CountLeadingZeros (static_cast <uint64_t >(num_elements_in_axis));
397+ // To avoid any overflow risk 'shift' should be <= 32 and to satisfy
398+ // 'MultiplyByQuantizedMultiplier' pre-conditions 'output_shift - shift'
399+ // should be >= -31. Clamp the value at the price of some precision loss.
400+ shift = std::min (shift, 32 );
401+ shift = std::min (shift, 31 + output_shift);
402+ output_multiplier = static_cast <int32_t >(
403+ (static_cast <int64_t >(output_multiplier) << shift) /
404+ num_elements_in_axis);
405+ output_shift = output_shift - shift;
406+ }
407+
408+ for (size_t idx = 0 ; idx < num_outputs; ++idx) {
409+ const U shifted_sum =
410+ static_cast <U>(temp_sum[idx] - input_zero_point * num_elements_in_axis);
411+ int32_t output = MultiplyByQuantizedMultiplier (
412+ shifted_sum, output_multiplier, output_shift) +
413+ output_zero_point;
414+ output = std::min (std::max (output, kMinValue ), kMaxValue );
415+ output_data[idx] = static_cast <T>(output);
467416 }
468417 return true ;
469418}
@@ -478,8 +427,8 @@ inline bool QuantizedMeanOrSumExtraArgs(
478427 bool keep_dims, int * temp_index, int * resolved_axis, U* temp_sum,
479428 bool compute_sum) {
480429 return QuantizedMeanOrSum<T, U>(
481- input_data, input_zero_point, input_scale, input_dims, input_num_dims,
482- output_data, output_zero_point, output_scale , output_dims,
430+ input_data, input_zero_point, input_dims, input_num_dims, output_data ,
431+ output_multiplier, output_shift, output_zero_point , output_dims,
483432 output_num_dims, axis, num_axis_dimensions, keep_dims, temp_index,
484433 resolved_axis, temp_sum, compute_sum);
485434}
0 commit comments