 /*
- * Copyright (c) 2019, 2024 Arm Limited.
+ * Copyright (c) 2019, 2024-2025 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -682,20 +682,16 @@ namespace {
     template<typename T>
     inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);

-    /* Load a full 16 byte vector, but mask before accumulation (see above). */
+    /* Load the trailing "odd" bytes (fewer than a full 16-byte block). */
     template<typename T>
-    inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);
-
-    /* Load 8 bytes and mask before accumulation. */
-    template<typename T>
-    inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);
+    inline int16x8_t accumulate_odds_16(const T *ptr, int16x8_t sum, size_t odds);

     /* This function does the actual work for up to 4 rows at a time.
      * It's pulled out so we can template on the row count to generate
      * the 4 different cases. 4 rows are computed at a time as this
      * reduces to a single vector write. */
     template<unsigned int rows, typename T>
-    void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
+    void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, size_t odds, int32x4_t offset_mul) {
         int16x8_t sums[rows];
         int32x4_t finalsums[rows];

@@ -731,14 +727,10 @@ namespace {
             }
         }

-        /* Handle the final masked read if needed. */
-        if (mask_mode > 0) {
+        /* Handle the final odd read if needed. */
+        if (odds > 0) {
             for (unsigned int r=0; r<rows; r++) {
-                if (mask_mode == 1) {
-                    sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
-                } else {
-                    sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
-                }
+                sums[r] = accumulate_odds_16(input + (r * in_stride) + (blocks * 16), sums[r], odds);
             }
         }

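For context, the comment above explains why the worker is templated on the row count: the loop trip count becomes a compile-time constant for each of the four instantiations, and the caller simply dispatches on how many rows remain. The sketch below illustrates that pattern only; it is not ComputeLibrary code, and the names process_rows and process_all are made up for illustration.

// Illustrative sketch only (assumed names, not the library's code): one worker
// templated on a compile-time row count, plus a caller that dispatches on the
// number of rows left, mirroring the compute_some_rows<4/3/2/1> switch later
// in this diff.
template<unsigned int rows>
void process_rows(const int *input, unsigned int in_stride, int *out) {
    // 'rows' is a compile-time constant, so the compiler can fully unroll this.
    for (unsigned int r = 0; r < rows; r++) {
        out[r] = input[r * in_stride];
    }
}

void process_all(const int *input, unsigned int in_stride, int *out, unsigned int height) {
    for (unsigned int row = 0; row < height; row += 4) {
        switch (height - row) {
            default:
            case 4: process_rows<4>(input + row * in_stride, in_stride, out + row); break;
            case 3: process_rows<3>(input + row * in_stride, in_stride, out + row); break;
            case 2: process_rows<2>(input + row * in_stride, in_stride, out + row); break;
            case 1: process_rows<1>(input + row * in_stride, in_stride, out + row); break;
        }
    }
}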
@@ -814,30 +806,13 @@ namespace {
         return vpadalq_s8(sum, vld1q_s8(ptr));
     }

-    template<>
-    int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
-        int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
-        return vpadalq_s8(sum, v);
-    }
-
-    template<>
-    int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
-        uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
-        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
-    }
-
-    template<>
-    int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
-        int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
-        v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
-        return vpadalq_s8(sum, v);
-    }
-
-    template<>
-    int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
-        uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
-        v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
-        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
+    template<typename T>
+    int16x8_t row_sum_helpers::accumulate_odds_16(const T *ptr, int16x8_t sum, size_t odds) {
+        T buffer[16] = {};
+        for (size_t i=0; i<odds; i++) {
+            buffer[i] = ptr[i];
+        }
+        return accumulate_16(buffer, sum);
     }
 }

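To make the new tail strategy concrete: accumulate_odds_16 copies the final 1-15 bytes into a zero-initialised 16-byte buffer and then reuses the existing full-width accumulate_16 path, since zero lanes contribute nothing to the sum. Below is a minimal, self-contained sketch of that idea, not the library code itself: the helper names are assumptions, and it simplifies by accumulating into an unsigned uint16x8_t rather than the library's reinterpreted int16x8_t. Assumes AArch64 NEON.

// Standalone sketch of the zero-padded tail technique (illustrative only).
#include <arm_neon.h>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Full-width step: load 16 bytes and pairwise-accumulate into 8 x u16 lanes.
static uint16x8_t accumulate_16_u8(const uint8_t *ptr, uint16x8_t sum) {
    return vpadalq_u8(sum, vld1q_u8(ptr));
}

// Tail step: zero-pad the last 'odds' bytes, then reuse the same full-width step.
static uint16x8_t accumulate_odds_u8(const uint8_t *ptr, uint16x8_t sum, size_t odds) {
    uint8_t buffer[16] = {};   // zero lanes contribute nothing to the sum
    memcpy(buffer, ptr, odds); // only the valid bytes are read from the input
    return accumulate_16_u8(buffer, sum);
}

int main() {
    uint8_t row[21];
    for (int i = 0; i < 21; i++) row[i] = (uint8_t)(i + 1);   // 1..21, sum = 231

    size_t blocks = 21 / 16, odds = 21 % 16;
    uint16x8_t sum = vdupq_n_u16(0);
    for (size_t b = 0; b < blocks; b++) sum = accumulate_16_u8(row + 16 * b, sum);
    if (odds) sum = accumulate_odds_u8(row + 16 * blocks, sum, odds);

    printf("row sum = %u\n", (unsigned)vaddvq_u16(sum));      // expect 231
    return 0;
}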
@@ -859,40 +834,20 @@ void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int h
     unsigned int blocks = (width / 16);
     const unsigned int odds = width % 16;

-    /* Generate a mask to use on the last iteration, if necessary. */
-    uint64x2_t mask = vdupq_n_u64(0);
-    unsigned int mask_mode = 0;
-
-    if (odds > 0 && odds <= 8) {
-        /* 1-8 odds: mask in the low lane, 0 in the top */
-        uint64_t maskval = (~0ULL) >> (8 * (8-odds));
-
-        mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);
-
-        mask_mode = 1;
-    } else if (odds > 8) {
-        /* 9-15 odds: mask in the top lane, all 1s in the bottom. */
-        uint64_t maskval = (~0ULL) >> (8 * (16-odds));
-
-        mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);
-
-        mask_mode = 2;
-    }
-
     for (unsigned int row=0; row<height; row+=4) {
         switch (height-row) {
             default:
             case 4:
-                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
+                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, odds, offset_mul);
                 break;
             case 3:
-                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
+                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, odds, offset_mul);
                 break;
             case 2:
-                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
+                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, odds, offset_mul);
                 break;
             case 1:
-                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
+                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, odds, offset_mul);
                 break;
         }
     }
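The removed block built a byte mask per 64-bit lane (one shape for 1-8 odd bytes, another for 9-15) and AND-ed it onto a full vector load; the new path instead copies exactly `odds` bytes into a zeroed buffer. Both leave every lane at index >= odds equal to zero, so the accumulated row sums are unchanged. The sketch below checks that equivalence; it is illustrative only, not part of the patch, and it folds the old 8-byte and 16-byte masked variants into a single full-width load for brevity (safe here because the demo source buffer is a full 16 bytes).

// Equivalence check between the old mask construction and the new zero-padded copy.
#include <arm_neon.h>
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
    const uint8_t src[16] = {9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9};

    for (unsigned odds = 1; odds < 16; odds++) {
        // Old scheme: 'odds' leading 0xFF bytes, built per 64-bit lane.
        uint64x2_t mask;
        if (odds <= 8) {
            uint64_t lo = (~0ULL) >> (8 * (8 - odds));
            mask = vsetq_lane_u64(lo, vdupq_n_u64(0), 0);
        } else {
            uint64_t hi = (~0ULL) >> (8 * (16 - odds));
            mask = vsetq_lane_u64(hi, vdupq_n_u64(~0ULL), 1);
        }
        uint8x16_t masked = vandq_u8(vld1q_u8(src), vreinterpretq_u8_u64(mask));

        // New scheme: copy only the valid bytes into a zero-initialised buffer.
        uint8_t buffer[16] = {};
        memcpy(buffer, src, odds);
        uint8x16_t padded = vld1q_u8(buffer);

        // Lane for lane, the two vectors agree.
        uint8_t a[16], b[16];
        vst1q_u8(a, masked);
        vst1q_u8(b, padded);
        assert(memcmp(a, b, 16) == 0);
    }
    return 0;
}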