11/* *
2- * @file l2_distance_avx.h
2+ * @file l2_distance.h
33 *
44 * @section LICENSE
55 *
3030 * C++ functions for computing L2 distance between two feature vectors
3131 *
3232 * This file contains a number of different implementations of the L2 distance
33- * computation, including naive, unrolled, and start/stop versions. The
33+ * computation, including naive, unrolled, and start/stop versions. The
3434 * implementations are templated on the type of the feature vector, and
35- * concepts are used to select between feature vectors of float and uint8_t.
35+ * concepts are used to select between feature vectors of float and feature
36+ * vectors of uint8_t or int8_t.
3637 *
3738 * Implementations include:
3839 * - naive_sum_of_squares: single loop, one statement in inner loop
4142 * - unroll4_sum_of_squares with start and stop
4243 *
4344 * The unrolled versions use simple 4x unrolling, and are faster than the
44- * naive versions. The start/stop versions are useful for computing the
45+ * naive versions. The start/stop versions are useful for computing the
4546 * "sub" distance, that is, the distance between just a portion of two vectors,
4647 * which is used in pq distance computation.
4748 *
@@ -292,7 +293,8 @@ inline float naive_sum_of_squares(
292293 */
293294template <feature_vector V, feature_vector W>
294295 requires std::same_as<typename V::value_type, float > &&
295- std::same_as<typename W::value_type, uint8_t >
296+ (std::same_as<typename W::value_type, uint8_t > ||
297+ std::same_as<typename W::value_type, int8_t >)
296298inline float naive_sum_of_squares (
297299 const V& a, const W& b, size_t start, size_t stop) {
298300 float sum = 0.0 ;
@@ -307,8 +309,9 @@ inline float naive_sum_of_squares(
307309 * Compute l2 distance between vector of uint8_t or int8_t and vector of float
308310 */
309311template <feature_vector V, feature_vector W>
310- requires std::same_as<typename V::value_type, uint8_t > &&
311- std::same_as<typename W::value_type, float >
312+ requires (std::same_as<typename V::value_type, uint8_t > ||
313+ std::same_as<typename V::value_type, int8_t >) &&
314+ std::same_as<typename W::value_type, float>
312315inline float naive_sum_of_squares(
313316 const V& a, const W& b, size_t start, size_t stop) {
314317 float sum = 0.0 ;
@@ -323,8 +326,10 @@ inline float naive_sum_of_squares(
323326 * Compute l2 distance between two vectors of uint8_t or int8_t (in any combination)
324327 */
325328template <feature_vector V, feature_vector W>
326- requires std::same_as<typename V::value_type, uint8_t > &&
327- std::same_as<typename W::value_type, uint8_t >
329+ requires (std::same_as<typename V::value_type, uint8_t > ||
330+ std::same_as<typename V::value_type, int8_t >) &&
331+ (std::same_as<typename W::value_type, uint8_t > ||
332+ std::same_as<typename W::value_type, int8_t >)
328333inline float naive_sum_of_squares(
329334 const V& a, const W& b, size_t start, size_t stop) {
330335 float sum = 0.0 ;
@@ -337,7 +342,7 @@ inline float naive_sum_of_squares(
337342
338343/* ***************************************************************
339344 *
340- * 4x unrolled algorithms with start and stop. We have separate
345+ * 4x unrolled algorithms with start and stop. We have separate
341346 * functions despite the code duplication to guarantee
342347 * performance in the common case (no start / stop).
343348 *
@@ -375,7 +380,8 @@ inline float unroll4_sum_of_squares(
375380 */
376381template <feature_vector V, feature_vector W>
377382 requires std::same_as<typename V::value_type, float > &&
378- std::same_as<typename W::value_type, uint8_t >
383+ (std::same_as<typename W::value_type, uint8_t > ||
384+ std::same_as<typename W::value_type, int8_t >)
379385inline float unroll4_sum_of_squares (
380386 const V& a, const W& b, size_t begin, size_t end) {
381387 size_t loops = 4 * ((end - begin) / 4 );
@@ -401,8 +407,9 @@ inline float unroll4_sum_of_squares(
401407 * Unrolled l2 distance between vector of uint8_t or int8_t and vector of float
402408 */
403409template <feature_vector V, feature_vector W>
404- requires std::same_as<typename V::value_type, uint8_t > &&
405- std::same_as<typename W::value_type, float >
410+ requires (std::same_as<typename V::value_type, uint8_t > ||
411+ std::same_as<typename V::value_type, int8_t >) &&
412+ std::same_as<typename W::value_type, float>
406413inline float unroll4_sum_of_squares(
407414 const V& a, const W& b, size_t begin, size_t end) {
408415 size_t loops = 4 * ((end - begin) / 4 );
@@ -428,8 +435,10 @@ inline float unroll4_sum_of_squares(
428435 * Unrolled l2 distance between two vectors of uint8_t or int8_t (in any combination)
429436 */
430437template <feature_vector V, feature_vector W>
431- requires std::same_as<typename V::value_type, uint8_t > &&
432- std::same_as<typename W::value_type, uint8_t >
438+ requires (std::same_as<typename V::value_type, uint8_t > ||
439+ std::same_as<typename V::value_type, int8_t >) &&
440+ (std::same_as<typename W::value_type, uint8_t > ||
441+ std::same_as<typename W::value_type, int8_t >)
433442inline float unroll4_sum_of_squares(
434443 const V& a, const W& b, size_t begin, size_t end) {
435444 size_t loops = 4 * ((end - begin) / 4 );
0 commit comments