Skip to content

Commit 51d5247

Browse files
authored
[MOD-13442] Refactor SQ8 metadata access to use sq8 enum constants (#881)
* Add SQ8 distance functions and tests for metric validation
* Refactor SQ8 inner product and L2 implementations to use metadata indices from sq8 struct
* Refactor SQ8 parameter access in inner product implementations for improved readability
* Refactor metadata handling in SQ8 quantization to use structured constants for improved clarity and maintainability
* Refactor quantization test to use structured constants for metadata access
1 parent 1d52631 commit 51d5247

File tree

14 files changed

+113
-59
lines changed

14 files changed

+113
-59
lines changed

src/VecSim/spaces/IP/IP.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
#include "IP.h"
1010
#include "VecSim/types/bfloat16.h"
1111
#include "VecSim/types/float16.h"
12+
#include "VecSim/types/sq8.h"
1213
#include <cstring>
1314

1415
using bfloat16 = vecsim_types::bfloat16;
1516
using float16 = vecsim_types::float16;
17+
using sq8 = vecsim_types::sq8;
1618

1719
float FLOAT_INTEGER_InnerProduct(const float *pVect1v, const uint8_t *pVect2v, size_t dimension,
1820
float min_val, float delta) {
@@ -63,14 +65,16 @@ float SQ8_SQ8_InnerProduct_Impl(const void *pVect1v, const void *pVect2v, size_t
6365
}
6466

6567
// Get quantization parameters from pVect1
66-
const float min_val1 = *reinterpret_cast<const float *>(pVect1 + dimension);
67-
const float delta1 = *reinterpret_cast<const float *>(pVect1 + dimension + sizeof(float));
68-
const float sum1 = *reinterpret_cast<const float *>(pVect1 + dimension + 2 * sizeof(float));
68+
const float *params1 = reinterpret_cast<const float *>(pVect1 + dimension);
69+
const float min_val1 = params1[sq8::MIN_VAL];
70+
const float delta1 = params1[sq8::DELTA];
71+
const float sum1 = params1[sq8::SUM];
6972

7073
// Get quantization parameters from pVect2
71-
const float min_val2 = *reinterpret_cast<const float *>(pVect2 + dimension);
72-
const float delta2 = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
73-
const float sum2 = *reinterpret_cast<const float *>(pVect2 + dimension + 2 * sizeof(float));
74+
const float *params2 = reinterpret_cast<const float *>(pVect2 + dimension);
75+
const float min_val2 = params2[sq8::MIN_VAL];
76+
const float delta2 = params2[sq8::DELTA];
77+
const float sum2 = params2[sq8::SUM];
7478

7579
// Apply the algebraic formula using precomputed sums:
7680
// IP = min1*sum2 + min2*sum1 + delta1*delta2*Σ(q1[i]*q2[i]) - dim*min1*min2

src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_SQ8.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
#pragma once
1010
#include "VecSim/spaces/space_includes.h"
1111
#include "VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_UINT8.h"
12+
#include "VecSim/types/sq8.h"
1213
#include <immintrin.h>
1314

15+
using sq8 = vecsim_types::sq8;
16+
1417
/**
1518
* SQ8-to-SQ8 distance functions using AVX512 VNNI with precomputed sum.
1619
* These functions compute distance between two SQ8 (scalar quantized 8-bit) vectors,
@@ -45,14 +48,14 @@ float SQ8_SQ8_InnerProductImp(const void *pVec1v, const void *pVec2v, size_t dim
4548
const uint8_t *pVec2 = static_cast<const uint8_t *>(pVec2v);
4649

4750
const float *params1 = reinterpret_cast<const float *>(pVec1 + dimension);
48-
const float min1 = params1[0];
49-
const float delta1 = params1[1];
50-
const float sum1 = params1[2]; // Precomputed sum of original float elements
51+
const float min1 = params1[sq8::MIN_VAL];
52+
const float delta1 = params1[sq8::DELTA];
53+
const float sum1 = params1[sq8::SUM]; // Precomputed sum of original float elements
5154

5255
const float *params2 = reinterpret_cast<const float *>(pVec2 + dimension);
53-
const float min2 = params2[0];
54-
const float delta2 = params2[1];
55-
const float sum2 = params2[2]; // Precomputed sum of original float elements
56+
const float min2 = params2[sq8::MIN_VAL];
57+
const float delta2 = params2[sq8::DELTA];
58+
const float sum2 = params2[sq8::SUM]; // Precomputed sum of original float elements
5659

5760
// Apply the algebraic formula using precomputed sums:
5861
// IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1[i]*q2[i]) - dim*min1*min2

src/VecSim/spaces/IP/IP_NEON_DOTPROD_SQ8_SQ8.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
#pragma once
1010
#include "VecSim/spaces/space_includes.h"
1111
#include "VecSim/spaces/IP/IP_NEON_DOTPROD_UINT8.h"
12+
#include "VecSim/types/sq8.h"
1213
#include <arm_neon.h>
1314

15+
using sq8 = vecsim_types::sq8;
16+
1417
/**
1518
* SQ8-to-SQ8 distance functions using ARM NEON DOTPROD with precomputed sum.
1619
* These functions compute distance between two SQ8 (scalar quantized 8-bit) vectors,
@@ -46,14 +49,14 @@ float SQ8_SQ8_InnerProductSIMD64_NEON_DOTPROD_IMP(const void *pVec1v, const void
4649
const uint8_t *pVec2 = static_cast<const uint8_t *>(pVec2v);
4750

4851
const float *params1 = reinterpret_cast<const float *>(pVec1 + dimension);
49-
const float min1 = params1[0];
50-
const float delta1 = params1[1];
51-
const float sum1 = params1[2]; // Precomputed sum of original float elements
52+
const float min1 = params1[sq8::MIN_VAL];
53+
const float delta1 = params1[sq8::DELTA];
54+
const float sum1 = params1[sq8::SUM]; // Precomputed sum of original float elements
5255

5356
const float *params2 = reinterpret_cast<const float *>(pVec2 + dimension);
54-
const float min2 = params2[0];
55-
const float delta2 = params2[1];
56-
const float sum2 = params2[2]; // Precomputed sum of original float elements
57+
const float min2 = params2[sq8::MIN_VAL];
58+
const float delta2 = params2[sq8::DELTA];
59+
const float sum2 = params2[sq8::SUM]; // Precomputed sum of original float elements
5760

5861
// Apply algebraic formula using precomputed sums:
5962
// IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1*q2) - dim*min1*min2

src/VecSim/spaces/IP/IP_NEON_SQ8_SQ8.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
#pragma once
1010
#include "VecSim/spaces/space_includes.h"
1111
#include "VecSim/spaces/IP/IP_NEON_UINT8.h"
12+
#include "VecSim/types/sq8.h"
1213
#include <arm_neon.h>
1314

15+
using sq8 = vecsim_types::sq8;
16+
1417
/**
1518
* SQ8-to-SQ8 distance functions using ARM NEON with precomputed sum.
1619
* These functions compute distance between two SQ8 (scalar quantized 8-bit) vectors,
@@ -46,15 +49,15 @@ float SQ8_SQ8_InnerProductSIMD64_NEON_IMP(const void *pVec1v, const void *pVec2v
4649
const uint8_t *pVec2 = static_cast<const uint8_t *>(pVec2v);
4750

4851
const float *params1 = reinterpret_cast<const float *>(pVec1 + dimension);
49-
const float min1 = params1[0];
50-
const float delta1 = params1[1];
51-
const float sum1 = params1[2]; // Precomputed sum of original float elements
52+
const float min1 = params1[sq8::MIN_VAL];
53+
const float delta1 = params1[sq8::DELTA];
54+
const float sum1 = params1[sq8::SUM]; // Precomputed sum of original float elements
5255

5356
// Get dequantization parameters and precomputed values from the end of pVec2
5457
const float *params2 = reinterpret_cast<const float *>(pVec2 + dimension);
55-
const float min2 = params2[0];
56-
const float delta2 = params2[1];
57-
const float sum2 = params2[2]; // Precomputed sum of original float elements
58+
const float min2 = params2[sq8::MIN_VAL];
59+
const float delta2 = params2[sq8::DELTA];
60+
const float sum2 = params2[sq8::SUM]; // Precomputed sum of original float elements
5861

5962
// Apply algebraic formula using precomputed sums:
6063
// IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1*q2) - dim*min1*min2

src/VecSim/spaces/IP/IP_SVE_SQ8_SQ8.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
#pragma once
1010
#include "VecSim/spaces/space_includes.h"
1111
#include "VecSim/spaces/IP/IP_SVE_UINT8.h"
12+
#include "VecSim/types/sq8.h"
1213
#include <arm_sve.h>
1314

15+
using sq8 = vecsim_types::sq8;
16+
1417
/**
1518
* SQ8-to-SQ8 distance functions using ARM SVE with precomputed sum.
1619
* These functions compute distance between two SQ8 (scalar quantized 8-bit) vectors,
@@ -46,14 +49,14 @@ float SQ8_SQ8_InnerProductSIMD_SVE_IMP(const void *pVec1v, const void *pVec2v, s
4649
const uint8_t *pVec2 = static_cast<const uint8_t *>(pVec2v);
4750

4851
const float *params1 = reinterpret_cast<const float *>(pVec1 + dimension);
49-
const float min1 = params1[0];
50-
const float delta1 = params1[1];
51-
const float sum1 = params1[2]; // Precomputed sum of original float elements
52+
const float min1 = params1[sq8::MIN_VAL];
53+
const float delta1 = params1[sq8::DELTA];
54+
const float sum1 = params1[sq8::SUM]; // Precomputed sum of original float elements
5255

5356
const float *params2 = reinterpret_cast<const float *>(pVec2 + dimension);
54-
const float min2 = params2[0];
55-
const float delta2 = params2[1];
56-
const float sum2 = params2[2]; // Precomputed sum of original float elements
57+
const float min2 = params2[sq8::MIN_VAL];
58+
const float delta2 = params2[sq8::DELTA];
59+
const float sum2 = params2[sq8::SUM]; // Precomputed sum of original float elements
5760

5861
// Apply algebraic formula with float conversion only at the end:
5962
// IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1*q2) - dim*min1*min2

src/VecSim/spaces/L2/L2.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
#include "VecSim/spaces/IP/IP.h"
1111
#include "VecSim/types/bfloat16.h"
1212
#include "VecSim/types/float16.h"
13+
#include "VecSim/types/sq8.h"
1314
#include <cstring>
1415
#include <iostream>
1516

1617
using bfloat16 = vecsim_types::bfloat16;
1718
using float16 = vecsim_types::float16;
19+
using sq8 = vecsim_types::sq8;
1820

1921
float SQ8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) {
2022
const auto *pVect1 = static_cast<const float *>(pVect1v);
@@ -149,8 +151,10 @@ float SQ8_SQ8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension)
149151

150152
// Get precomputed sum of squares from both vectors
151153
// Layout: [uint8_t values (dim)] [min_val] [delta] [sum] [sum_of_squares]
152-
const float sum_sq_1 = *reinterpret_cast<const float *>(pVect1 + dimension + 3 * sizeof(float));
153-
const float sum_sq_2 = *reinterpret_cast<const float *>(pVect2 + dimension + 3 * sizeof(float));
154+
const float sum_sq_1 =
155+
*reinterpret_cast<const float *>(pVect1 + dimension + sq8::SUM_SQUARES * sizeof(float));
156+
const float sum_sq_2 =
157+
*reinterpret_cast<const float *>(pVect2 + dimension + sq8::SUM_SQUARES * sizeof(float));
154158

155159
// Use the common inner product implementation
156160
const float ip = SQ8_SQ8_InnerProduct_Impl(pVect1v, pVect2v, dimension);

src/VecSim/spaces/L2/L2_NEON_DOTPROD_SQ8_SQ8.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
#pragma once
1010
#include "VecSim/spaces/space_includes.h"
1111
#include "VecSim/spaces/IP/IP_NEON_DOTPROD_SQ8_SQ8.h"
12+
#include "VecSim/types/sq8.h"
13+
14+
using sq8 = vecsim_types::sq8;
1215

1316
/**
1417
* SQ8-to-SQ8 L2 squared distance functions for NEON with DOTPROD extension.
@@ -34,8 +37,10 @@ float SQ8_SQ8_L2SqrSIMD64_NEON_DOTPROD(const void *pVec1v, const void *pVec2v, s
3437

3538
// Get precomputed sum of squares from both vectors
3639
// Layout: [uint8_t values (dim)] [min_val] [delta] [sum] [sum_of_squares]
37-
const float sum_sq_1 = *reinterpret_cast<const float *>(pVec1 + dimension + 3 * sizeof(float));
38-
const float sum_sq_2 = *reinterpret_cast<const float *>(pVec2 + dimension + 3 * sizeof(float));
40+
const float sum_sq_1 =
41+
*reinterpret_cast<const float *>(pVec1 + dimension + sq8::SUM_SQUARES * sizeof(float));
42+
const float sum_sq_2 =
43+
*reinterpret_cast<const float *>(pVec2 + dimension + sq8::SUM_SQUARES * sizeof(float));
3944

4045
// L2² = ||x||² + ||y||² - 2*IP(x, y)
4146
return sum_sq_1 + sum_sq_2 - 2.0f * ip;

src/VecSim/spaces/L2/L2_NEON_SQ8_SQ8.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
#pragma once
1010
#include "VecSim/spaces/space_includes.h"
1111
#include "VecSim/spaces/IP/IP_NEON_SQ8_SQ8.h"
12+
#include "VecSim/types/sq8.h"
13+
14+
using sq8 = vecsim_types::sq8;
1215

1316
/**
1417
* SQ8-to-SQ8 L2 squared distance functions for NEON.
@@ -33,8 +36,10 @@ float SQ8_SQ8_L2SqrSIMD64_NEON(const void *pVec1v, const void *pVec2v, size_t di
3336

3437
// Get precomputed sum of squares from both vectors
3538
// Layout: [uint8_t values (dim)] [min_val] [delta] [sum] [sum_of_squares]
36-
const float sum_sq_1 = *reinterpret_cast<const float *>(pVec1 + dimension + 3 * sizeof(float));
37-
const float sum_sq_2 = *reinterpret_cast<const float *>(pVec2 + dimension + 3 * sizeof(float));
39+
const float sum_sq_1 =
40+
*reinterpret_cast<const float *>(pVec1 + dimension + sq8::SUM_SQUARES * sizeof(float));
41+
const float sum_sq_2 =
42+
*reinterpret_cast<const float *>(pVec2 + dimension + sq8::SUM_SQUARES * sizeof(float));
3843

3944
// L2² = ||x||² + ||y||² - 2*IP(x, y)
4045
return sum_sq_1 + sum_sq_2 - 2.0f * ip;

src/VecSim/spaces/L2/L2_SVE_SQ8_SQ8.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
#pragma once
1010
#include "VecSim/spaces/space_includes.h"
1111
#include "VecSim/spaces/IP/IP_SVE_SQ8_SQ8.h"
12+
#include "VecSim/types/sq8.h"
13+
14+
using sq8 = vecsim_types::sq8;
1215

1316
/**
1417
* SQ8-to-SQ8 L2 squared distance functions for SVE.
@@ -34,8 +37,10 @@ float SQ8_SQ8_L2SqrSIMD_SVE(const void *pVec1v, const void *pVec2v, size_t dimen
3437

3538
// Get precomputed sum of squares from both vectors
3639
// Layout: [uint8_t values (dim)] [min_val] [delta] [sum] [sum_of_squares]
37-
const float sum_sq_1 = *reinterpret_cast<const float *>(pVec1 + dimension + 3 * sizeof(float));
38-
const float sum_sq_2 = *reinterpret_cast<const float *>(pVec2 + dimension + 3 * sizeof(float));
40+
const float sum_sq_1 =
41+
*reinterpret_cast<const float *>(pVec1 + dimension + sq8::SUM_SQUARES * sizeof(float));
42+
const float sum_sq_2 =
43+
*reinterpret_cast<const float *>(pVec2 + dimension + sq8::SUM_SQUARES * sizeof(float));
3944

4045
// L2² = ||x||² + ||y||² - 2*IP(x, y)
4146
return sum_sq_1 + sum_sq_2 - 2.0f * ip;

src/VecSim/spaces/computer/preprocessors.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "VecSim/memory/vecsim_base.h"
1919
#include "VecSim/spaces/spaces.h"
2020
#include "VecSim/memory/memory_utils.h"
21+
#include "VecSim/types/sq8.h"
2122

2223
class PreprocessorInterface : public VecsimBaseObject {
2324
public:
@@ -216,10 +217,8 @@ class CosinePreprocessor : public PreprocessorInterface {
216217
template <typename DataType, VecSimMetric Metric>
217218
class QuantPreprocessor : public PreprocessorInterface {
218219
using OUTPUT_TYPE = uint8_t;
220+
using sq8 = vecsim_types::sq8;
219221

220-
// For L2: store sum + sum_of_squares (2 extra values)
221-
// For IP/Cosine: store only sum (1 extra value)
222-
static constexpr size_t extra_storage_values_count = (Metric == VecSimMetric_L2) ? 2 : 1;
223222
static_assert(Metric == VecSimMetric_L2 || Metric == VecSimMetric_IP ||
224223
Metric == VecSimMetric_Cosine,
225224
"QuantPreprocessor only supports L2, IP and Cosine metrics");
@@ -294,13 +293,13 @@ class QuantPreprocessor : public PreprocessorInterface {
294293
DataType *metadata = reinterpret_cast<DataType *>(quantized + this->dim);
295294

296295
// Store min_val and delta in the metadata
297-
metadata[0] = min_val;
298-
metadata[1] = delta;
296+
metadata[sq8::MIN_VAL] = min_val;
297+
metadata[sq8::DELTA] = delta;
299298

300299
// Store sum (for all metrics) and sum_squares (for L2 only)
301-
metadata[2] = sum;
300+
metadata[sq8::SUM] = sum;
302301
if constexpr (Metric == VecSimMetric_L2) {
303-
metadata[3] = sum_squares;
302+
metadata[sq8::SUM_SQUARES] = sum_squares;
304303
}
305304
}
306305

@@ -352,7 +351,7 @@ class QuantPreprocessor : public PreprocessorInterface {
352351
QuantPreprocessor(std::shared_ptr<VecSimAllocator> allocator, size_t dim)
353352
: PreprocessorInterface(allocator), dim(dim),
354353
storage_bytes_count(dim * sizeof(OUTPUT_TYPE) +
355-
(2 + extra_storage_values_count) * sizeof(DataType)),
354+
(vecsim_types::sq8::metadata_count<Metric>()) * sizeof(DataType)),
356355
query_bytes_count((dim + 1) * sizeof(DataType)) {
357356
static_assert(std::is_floating_point_v<DataType>,
358357
"QuantPreprocessor only supports floating-point types");

0 commit comments

Comments
 (0)