@@ -88,29 +88,38 @@ inline float avx2_inner_product(const V& a, const W& b) {
8888
8989template <feature_vector V, feature_vector W>
9090 requires std::same_as<typename V::value_type, float > &&
91- std::same_as<typename W::value_type, uint8_t >
91+ (std::same_as<typename W::value_type, uint8_t > ||
92+ std::same_as<typename W::value_type, int8_t >)
9293inline float avx2_inner_product (const V& a, const W& b) {
9394 // @todo Align on 256 bit boundaries
9495 const size_t start = 0 ;
9596 const size_t size_a = size (a);
9697 const size_t stop = size_a - (size_a % 8 );
9798
9899 const float * a_ptr = a.data ();
99- const uint8_t * b_ptr = b.data ();
100+ // Can be uint8_t* or int8_t*
101+ const auto * b_ptr = b.data ();
100102
101103 __m256 vec_sum = _mm256_setzero_ps ();
102104
103105 for (size_t i = start; i < stop; i += 8 ) {
104106 // Load 8 floats
105- __m256 a_floats = _mm256_loadu_ps (a_ptr + i + 0 );
107+ __m256 a_floats = _mm256_loadu_ps (a_ptr + i);
106108
107- // Load 8 bytes
109+ // Load 8 bytes (uint8_t or int8_t)
108110 __m128i vec_b = _mm_loadu_si64 ((__m64*)(b_ptr + i));
109111
110- // Zero extend 8bit to 32bit ints
111- __m256i b_ints = _mm256_cvtepu8_epi32 (vec_b);
112-
113- // Convert signed integers to floats
112+ // Conditionally convert based on the type of W::value_type
113+ __m256i b_ints;
114+ if constexpr (std::same_as<typename W::value_type, uint8_t >) {
115+ // Zero extend uint8_t to int32_t
116+ b_ints = _mm256_cvtepu8_epi32 (vec_b);
117+ } else if constexpr (std::same_as<typename W::value_type, int8_t >) {
118+ // Sign extend int8_t to int32_t
119+ b_ints = _mm256_cvtepi8_epi32 (vec_b);
120+ }
121+
122+ // Convert the 32-bit integers to floats
114123 __m256 b_floats = _mm256_cvtepi32_ps (b_ints);
115124
116125 // Multiply and accumulate
@@ -139,30 +148,39 @@ inline float avx2_inner_product(const V& a, const W& b) {
139148}
140149
141150template <feature_vector V, feature_vector W>
142- requires std::same_as<typename V::value_type, uint8_t > &&
143- std::same_as<typename W::value_type, float >
151+ requires (std::same_as<typename V::value_type, uint8_t > ||
152+ std::same_as<typename V::value_type, int8_t >) &&
153+ std::same_as<typename W::value_type, float>
144154inline float avx2_inner_product(const V& a, const W& b) {
145155 // @todo Align on 256 bit boundaries
146156 const size_t start = 0 ;
147157 const size_t size_a = size (a);
148158 const size_t stop = size_a - (size_a % 8 );
149159
150- const uint8_t * a_ptr = a.data ();
160+ // Can be uint8_t* or int8_t*
161+ const auto * a_ptr = a.data ();
151162 const float * b_ptr = b.data ();
152163
153164 __m256 vec_sum = _mm256_setzero_ps ();
154165
155166 for (size_t i = start; i < stop; i += 8 ) {
156- // Load 8 bytes == 64 bits -- zeros out top 8 bytes
167+ // Load 8 bytes (either uint8_t or int8_t)
157168 __m128i vec_a = _mm_loadu_si64 ((__m64*)(a_ptr + i));
158169
159170 // Load 8 floats
160171 __m256 b_floats = _mm256_loadu_ps (b_ptr + i + 0 );
161172
162- // Zero extend 8bit to 32bit ints
163- __m256i a_ints = _mm256_cvtepu8_epi32 (vec_a);
164-
165- // Convert signed integers to floats
173+ // Extend 8 bit to 32 bit ints
174+ __m256i a_ints;
175+ if constexpr (std::same_as<typename V::value_type, uint8_t >) {
176+ // Zero extend uint8_t to int32_t
177+ a_ints = _mm256_cvtepu8_epi32 (vec_a);
178+ } else if constexpr (std::same_as<typename V::value_type, int8_t >) {
179+ // Sign extend int8_t to int32_t
180+ a_ints = _mm256_cvtepi8_epi32 (vec_a);
181+ }
182+
183+ // Convert the 32-bit integers to floats
166184 __m256 a_floats = _mm256_cvtepi32_ps (a_ints);
167185
168186 // Multiply and accumulate
@@ -191,29 +209,49 @@ inline float avx2_inner_product(const V& a, const W& b) {
191209}
192210
193211template <feature_vector V, feature_vector W>
194- requires std::same_as<typename V::value_type, uint8_t > &&
195- std::same_as<typename W::value_type, uint8_t >
212+ requires (std::same_as<typename V::value_type, uint8_t > ||
213+ std::same_as<typename V::value_type, int8_t >) &&
214+ (std::same_as<typename W::value_type, uint8_t > ||
215+ std::same_as<typename W::value_type, int8_t >)
196216inline float avx2_inner_product(const V& a, const W& b) {
197217 // @todo Align on 256 bit boundaries
198218 const size_t start = 0 ;
199219 const size_t size_a = size (a);
200220 const size_t stop = size_a - (size_a % 8 );
201221
202- const uint8_t * a_ptr = a.data ();
203- const uint8_t * b_ptr = b.data ();
222+ // Can be either uint8_t* or int8_t*
223+ const auto * a_ptr = a.data ();
224+ // Can be either uint8_t* or int8_t*
225+ const auto * b_ptr = b.data ();
204226
205227 __m256 vec_sum = _mm256_setzero_ps ();
206228
207229 for (size_t i = start; i < stop; i += 8 ) {
208- // Load 8 bytes == 64 bits -- zeros out top 8 bytes
230+ // Load 8 bytes (uint8_t or int8_t) from both vectors
209231 __m128i vec_a = _mm_loadu_si64 ((__m64*)(a_ptr + i));
210232 __m128i vec_b = _mm_loadu_si64 ((__m64*)(b_ptr + i));
211233
212- // Zero extend 8bit to 32bit ints
213- __m256i a_ints = _mm256_cvtepu8_epi32 (vec_a);
214- __m256i b_ints = _mm256_cvtepu8_epi32 (vec_b);
215-
216- // Convert signed integers to floats
234+ // Extend 8 bit to 32 bit ints
235+ __m256i a_ints;
236+ if constexpr (std::same_as<typename V::value_type, uint8_t >) {
237+ // Zero extend uint8_t to int32_t
238+ a_ints = _mm256_cvtepu8_epi32 (vec_a);
239+ } else if constexpr (std::same_as<typename V::value_type, int8_t >) {
240+ // Sign extend int8_t to int32_t
241+ a_ints = _mm256_cvtepi8_epi32 (vec_a);
242+ }
243+
244+ // Conditionally convert based on the type of W::value_type
245+ __m256i b_ints;
246+ if constexpr (std::same_as<typename W::value_type, uint8_t >) {
247+ // Zero extend uint8_t to int32_t
248+ b_ints = _mm256_cvtepu8_epi32 (vec_b);
249+ } else if constexpr (std::same_as<typename W::value_type, int8_t >) {
250+ // Sign extend int8_t to int32_t
251+ b_ints = _mm256_cvtepi8_epi32 (vec_b);
252+ }
253+
254+ // Convert the 32-bit integers to floats for both vectors
217255 __m256 a_floats = _mm256_cvtepi32_ps (a_ints);
218256 __m256 b_floats = _mm256_cvtepi32_ps (b_ints);
219257
0 commit comments