Skip to content

Commit 520608e

Browse files
DavidMansellgunes-arm
authored andcommitted
fix: gemm: don't overread when computing row sums.
Resolves: COMPMID-8396 Signed-off-by: David Mansell <[email protected]> Change-Id: I661a1f6aaf092a2f374adad3a4aa39e024092f39 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/14679 Reviewed-by: Gunes Bayir <[email protected]> Tested-by: Arm Jenkins <[email protected]> Benchmark: Arm Jenkins <[email protected]> Comments-Addressed: Arm Jenkins <[email protected]>
1 parent 24f2767 commit 520608e

File tree

1 file changed

+18
-63
lines changed

1 file changed

+18
-63
lines changed

src/core/NEON/kernels/arm_gemm/quantized.cpp

Lines changed: 18 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019, 2024 Arm Limited.
2+
* Copyright (c) 2019, 2024-2025 Arm Limited.
33
*
44
* SPDX-License-Identifier: MIT
55
*
@@ -682,20 +682,16 @@ namespace {
682682
template<typename T>
683683
inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);
684684

685-
/* Load a full 16 byte vector, but mask before accumulation (see above). */
685+
/* Load "odd" bytes */
686686
template<typename T>
687-
inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);
688-
689-
/* Load 8 bytes and mask before accumulation. */
690-
template<typename T>
691-
inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);
687+
inline int16x8_t accumulate_odds_16(const T *ptr, int16x8_t sum, size_t odds);
692688

693689
/* This function does the actual work for up to 4 rows at a time.
694690
* It's pulled out so we can template on the row count to generate
695691
* the 4 different cases. 4 rows are computed at a time as this
696692
* reduces to a single vector write. */
697693
template<unsigned int rows, typename T>
698-
void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
694+
void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, size_t odds, int32x4_t offset_mul) {
699695
int16x8_t sums[rows];
700696
int32x4_t finalsums[rows];
701697

@@ -731,14 +727,10 @@ namespace {
731727
}
732728
}
733729

734-
/* Handle the final masked read if needed. */
735-
if (mask_mode > 0) {
730+
/* Handle the final odd read if needed. */
731+
if (odds > 0) {
736732
for (unsigned int r=0; r<rows; r++) {
737-
if (mask_mode == 1) {
738-
sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
739-
} else {
740-
sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
741-
}
733+
sums[r] = accumulate_odds_16(input + (r * in_stride) + (blocks * 16), sums[r], odds);
742734
}
743735
}
744736

@@ -814,30 +806,13 @@ namespace {
814806
return vpadalq_s8(sum, vld1q_s8(ptr));
815807
}
816808

817-
template<>
818-
int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
819-
int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
820-
return vpadalq_s8(sum, v);
821-
}
822-
823-
template<>
824-
int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
825-
uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
826-
return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
827-
}
828-
829-
template<>
830-
int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
831-
int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
832-
v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
833-
return vpadalq_s8(sum, v);
834-
}
835-
836-
template<>
837-
int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
838-
uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
839-
v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
840-
return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
809+
template<typename T>
810+
int16x8_t row_sum_helpers::accumulate_odds_16(const T *ptr, int16x8_t sum, size_t odds) {
811+
T buffer[16] = {};
812+
for(size_t i=0; i<odds; i++) {
813+
buffer[i] = ptr[i];
814+
}
815+
return accumulate_16(buffer, sum);
841816
}
842817
}
843818

@@ -859,40 +834,20 @@ void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int h
859834
unsigned int blocks = (width / 16);
860835
const unsigned int odds = width % 16;
861836

862-
/* Generate a mask to use on the last iteration, if necessary. */
863-
uint64x2_t mask = vdupq_n_u64(0);
864-
unsigned int mask_mode = 0;
865-
866-
if (odds > 0 && odds <= 8) {
867-
/* 1-8 odds: mask in the low lane, 0 in the top */
868-
uint64_t maskval = (~0ULL) >> (8 * (8-odds));
869-
870-
mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);
871-
872-
mask_mode = 1;
873-
} else if (odds > 8) {
874-
/* 9-15 odds: mask in the top lane, all 1s in the bottom. */
875-
uint64_t maskval = (~0ULL) >> (8 * (16-odds));
876-
877-
mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);
878-
879-
mask_mode = 2;
880-
}
881-
882837
for (unsigned int row=0; row<height; row+=4) {
883838
switch(height-row) {
884839
default:
885840
case 4:
886-
thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
841+
thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, odds, offset_mul);
887842
break;
888843
case 3:
889-
thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
844+
thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, odds, offset_mul);
890845
break;
891846
case 2:
892-
thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
847+
thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, odds, offset_mul);
893848
break;
894849
case 1:
895-
thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
850+
thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, odds, offset_mul);
896851
break;
897852
}
898853
}

0 commit comments

Comments
 (0)