Skip to content

Commit b0c0209

Browse files
authored
Update l2 distance to support int8_t (#400)
1 parent f564a8d commit b0c0209

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

src/include/detail/scoring/l2_distance.h

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* @file l2_distance_avx.h
2+
* @file l2_distance.h
33
*
44
* @section LICENSE
55
*
@@ -30,9 +30,10 @@
3030
* C++ functions for computing L2 distance between two feature vectors
3131
*
3232
* This file contains a number of different implementations of the L2 distance
33-
* computation, including naive, unrolled, and start/stop versions. The
33+
* computation, including naive, unrolled, and start/stop versions. The
3434
* implementations are templated on the type of the feature vector, and
35-
* concepts are used to select between feature vectors of float and uint8_t.
35+
* concepts are used to select between feature vectors of float and uint8_t or
36+
* int8_t.
3637
*
3738
* Implementations include:
3839
* - naive_sum_of_squares: single loop, one statement in inner loop
@@ -41,7 +42,7 @@
4142
* - unroll4_sum_of_squares with start and stop
4243
*
4344
* The unrolled versions use simple 4x unrolling, and are faster than the
44-
* naive versions. The start/stop versions are useful for computing the
45+
* naive versions. The start/stop versions are useful for computing the
4546
* "sub" distance, that is, the distance between just a portion of two vectors,
4647
* which is used in pq distance computation.
4748
*
@@ -292,7 +293,8 @@ inline float naive_sum_of_squares(
292293
*/
293294
template <feature_vector V, feature_vector W>
294295
requires std::same_as<typename V::value_type, float> &&
295-
std::same_as<typename W::value_type, uint8_t>
296+
(std::same_as<typename W::value_type, uint8_t> ||
297+
std::same_as<typename W::value_type, int8_t>)
296298
inline float naive_sum_of_squares(
297299
const V& a, const W& b, size_t start, size_t stop) {
298300
float sum = 0.0;
@@ -307,8 +309,9 @@ inline float naive_sum_of_squares(
307309
* Compute l2 distance between vector of uint8_t and vector of float
308310
*/
309311
template <feature_vector V, feature_vector W>
310-
requires std::same_as<typename V::value_type, uint8_t> &&
311-
std::same_as<typename W::value_type, float>
312+
requires(std::same_as<typename V::value_type, uint8_t> ||
313+
std::same_as<typename V::value_type, int8_t>) &&
314+
std::same_as<typename W::value_type, float>
312315
inline float naive_sum_of_squares(
313316
const V& a, const W& b, size_t start, size_t stop) {
314317
float sum = 0.0;
@@ -323,8 +326,10 @@ inline float naive_sum_of_squares(
323326
* Compute l2 distance between vector of uint8_t and vector of uint8_t
324327
*/
325328
template <feature_vector V, feature_vector W>
326-
requires std::same_as<typename V::value_type, uint8_t> &&
327-
std::same_as<typename W::value_type, uint8_t>
329+
requires(std::same_as<typename V::value_type, uint8_t> ||
330+
std::same_as<typename V::value_type, int8_t>) &&
331+
(std::same_as<typename W::value_type, uint8_t> ||
332+
std::same_as<typename W::value_type, int8_t>)
328333
inline float naive_sum_of_squares(
329334
const V& a, const W& b, size_t start, size_t stop) {
330335
float sum = 0.0;
@@ -337,7 +342,7 @@ inline float naive_sum_of_squares(
337342

338343
/****************************************************************
339344
*
340-
* 4x unrolled algorithms with start and stop. We have separate
345+
* 4x unrolled algorithms with start and stop. We have separate
341346
* functions despite the code duplication to make sure about
342347
* performance in the common case (no start / stop).
343348
*
@@ -375,7 +380,8 @@ inline float unroll4_sum_of_squares(
375380
*/
376381
template <feature_vector V, feature_vector W>
377382
requires std::same_as<typename V::value_type, float> &&
378-
std::same_as<typename W::value_type, uint8_t>
383+
(std::same_as<typename W::value_type, uint8_t> ||
384+
std::same_as<typename W::value_type, int8_t>)
379385
inline float unroll4_sum_of_squares(
380386
const V& a, const W& b, size_t begin, size_t end) {
381387
size_t loops = 4 * ((end - begin) / 4);
@@ -401,8 +407,9 @@ inline float unroll4_sum_of_squares(
401407
* Unrolled l2 distance between vector of uint8_t and vector of float
402408
*/
403409
template <feature_vector V, feature_vector W>
404-
requires std::same_as<typename V::value_type, uint8_t> &&
405-
std::same_as<typename W::value_type, float>
410+
requires(std::same_as<typename V::value_type, uint8_t> ||
411+
std::same_as<typename V::value_type, int8_t>) &&
412+
std::same_as<typename W::value_type, float>
406413
inline float unroll4_sum_of_squares(
407414
const V& a, const W& b, size_t begin, size_t end) {
408415
size_t loops = 4 * ((end - begin) / 4);
@@ -428,8 +435,10 @@ inline float unroll4_sum_of_squares(
428435
* Unrolled l2 distance between vector of uint8_t and vector of uint8_t
429436
*/
430437
template <feature_vector V, feature_vector W>
431-
requires std::same_as<typename V::value_type, uint8_t> &&
432-
std::same_as<typename W::value_type, uint8_t>
438+
requires(std::same_as<typename V::value_type, uint8_t> ||
439+
std::same_as<typename V::value_type, int8_t>) &&
440+
(std::same_as<typename W::value_type, uint8_t> ||
441+
std::same_as<typename W::value_type, int8_t>)
433442
inline float unroll4_sum_of_squares(
434443
const V& a, const W& b, size_t begin, size_t end) {
435444
size_t loops = 4 * ((end - begin) / 4);

src/include/test/unit_l2_distance.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@
3434
#include "detail/linalg/matrix.h"
3535
#include "detail/scoring/l2_distance.h"
3636

37-
TEST_CASE("naive_sum_of_squares", "[l2_distance]") {
37+
TEMPLATE_TEST_CASE("naive_sum_of_squares", "[l2_distance]", int8_t, uint8_t) {
3838
// size_t n = GENERATE(1, 3, 127, 1021, 1024);
3939

4040
size_t n = GENERATE(127);
4141

42-
auto u = std::vector<uint8_t>(n);
43-
auto v = std::vector<uint8_t>(n);
42+
auto u = std::vector<TestType>(n);
43+
auto v = std::vector<TestType>(n);
4444
auto x = std::vector<float>(n);
4545
auto y = std::vector<float>(n);
4646

0 commit comments

Comments
 (0)