@@ -999,16 +999,21 @@ TEST(PreprocessorsTest, QuantizationTest) {
999999
10001000 // === Storage blob expected values ===
10011001 // For L2 metric: quantized values + min + delta + sum + sum_squares = dim bytes + 4 floats
1002- constexpr size_t quantized_blob_bytes_count = dim * sizeof (uint8_t ) + 4 * sizeof (float );
1002+ constexpr size_t quantized_blob_bytes_count =
1003+ dim * sizeof (uint8_t ) + sq8::storage_metadata_count<VecSimMetric_L2>() * sizeof (float );
10031004 uint8_t expected_storage_blob[quantized_blob_bytes_count] = {0 };
10041005 ComputeSQ8Quantization (original_blob, dim, expected_storage_blob);
10051006
10061007 // === Query blob expected values ===
10071008 // Query layout: | query_values[dim] | y_sum_squares (for L2) |
1008- constexpr size_t query_blob_bytes_count = (dim + 1 ) * sizeof (float );
1009- // Compute expected sum of squares for L2: 1² + 2² + 3² + 4² + 5² + 6² = 91
1009+ constexpr size_t query_blob_bytes_count =
1010+ (dim + sq8::query_metadata_count<VecSimMetric_L2>()) * sizeof (float );
1011+
1012+ // Compute expected sum and sum of squares for L2:
1013+ float expected_query_sum = 0 ;
10101014 float expected_query_sum_squares = 0 ;
10111015 for (size_t i = 0 ; i < dim; ++i) {
1016+ expected_query_sum += original_blob[i];
10121017 expected_query_sum_squares += original_blob[i] * original_blob[i];
10131018 }
10141019
@@ -1038,7 +1043,8 @@ TEST(PreprocessorsTest, QuantizationTest) {
10381043 // Verify query blob content
10391044 const float *query_floats = static_cast <const float *>(query_blob);
10401045 EXPECT_NO_FATAL_FAILURE (CompareVectors<float >(query_floats, original_blob, dim));
1041- ASSERT_FLOAT_EQ (query_floats[dim], expected_query_sum_squares);
1046+ ASSERT_FLOAT_EQ (query_floats[dim + sq8::SUM_QUERY], expected_query_sum);
1047+ ASSERT_FLOAT_EQ (query_floats[dim + sq8::SUM_SQUARES_QUERY], expected_query_sum_squares);
10421048 }
10431049
10441050 // Test preprocessForStorage
@@ -1060,7 +1066,8 @@ TEST(PreprocessorsTest, QuantizationTest) {
10601066 // Verify query blob content: original floats followed by sum_squares
10611067 const float *query_floats = static_cast <const float *>(query_blob.get ());
10621068 EXPECT_NO_FATAL_FAILURE (CompareVectors<float >(query_floats, original_blob, dim));
1063- ASSERT_FLOAT_EQ (query_floats[dim], expected_query_sum_squares);
1069+ ASSERT_FLOAT_EQ (query_floats[dim + sq8::SUM_QUERY], expected_query_sum);
1070+ ASSERT_FLOAT_EQ (query_floats[dim + sq8::SUM_SQUARES_QUERY], expected_query_sum_squares);
10641071
10651072 // Check address is aligned
10661073 unsigned char address_alignment = (uintptr_t )(query_blob.get ()) % alignment;
@@ -1152,38 +1159,33 @@ class QuantPreprocessorMetricTest : public testing::TestWithParam<VecSimMetric>
11521159 // Storage layout: | quantized_values[dim] | min | delta | sum | (sum_squares for L2) |
11531160 // L2: dim bytes + 4 floats (min, delta, sum, sum_squares)
11541161 // IP/Cosine: dim bytes + 3 floats (min, delta, sum)
1155- static size_t getExpectedStorageSize (VecSimMetric metric) {
1156- size_t extra_floats = (metric == VecSimMetric_L2) ? 4 : 3 ;
1162+ template <VecSimMetric Metric>
1163+ static size_t getExpectedStorageSize () {
1164+ constexpr size_t extra_floats = sq8::storage_metadata_count<Metric>();
11571165 return dim * sizeof (uint8_t ) + extra_floats * sizeof (float );
11581166 }
11591167
11601168 // === Query blob helpers ===
11611169
11621170 // Query layout: | query_values[dim] | y_sum (IP/Cosine) OR y_sum_squares (L2) |
11631171 // All metrics: (dim + 1) floats
1164- static constexpr size_t getExpectedQuerySize () { return (dim + 1 ) * sizeof (float ); }
1165-
1166- // Compute expected precomputed value for query blob based on metric
11671172 template <VecSimMetric Metric>
1168- float getExpectedQueryPrecomputedValue () {
1169- float sum = 0 ;
1170- for (size_t i = 0 ; i < dim; ++i) {
1171- if constexpr (Metric == VecSimMetric_L2) {
1172- // sum of squares: 1² + 2² + 3² + 4² + 5² = 55
1173- sum += original_blob[i] * original_blob[i];
1174- } else {
1175- // sum: 1 + 2 + 3 + 4 + 5 = 15
1176- sum += original_blob[i];
1177- }
1178- }
1179- return sum;
1173+ static constexpr size_t getExpectedQuerySize () {
1174+ return (dim + sq8::query_metadata_count<Metric>()) * sizeof (float );
11801175 }
11811176
11821177 // Helper to run quantization test for a specific metric
11831178 template <VecSimMetric Metric>
11841179 void runQuantizationTest () {
1185- size_t expected_storage_size = getExpectedStorageSize (Metric);
1186- size_t expected_query_size = getExpectedQuerySize ();
1180+ size_t expected_storage_size = getExpectedStorageSize<Metric>();
1181+ size_t expected_query_size = getExpectedQuerySize<Metric>();
1182+
1183+ float expected_query_sum = 0 ;
1184+ float expected_query_sum_squares = 0 ;
1185+ for (size_t i = 0 ; i < dim; ++i) {
1186+ expected_query_sum += original_blob[i];
1187+ expected_query_sum_squares += original_blob[i] * original_blob[i];
1188+ }
11871189
11881190 auto quant_preprocessor = new (allocator) QuantPreprocessor<float , Metric>(allocator, dim);
11891191
@@ -1221,9 +1223,12 @@ class QuantPreprocessorMetricTest : public testing::TestWithParam<VecSimMetric>
12211223 const float *query_floats = static_cast <const float *>(query_blob);
12221224 EXPECT_NO_FATAL_FAILURE (CompareVectors<float >(query_floats, original_blob, dim));
12231225
1224- // Verify precomputed value (sum for IP/Cosine, sum_squares for L2)
1225- float expected_precomputed = getExpectedQueryPrecomputedValue<Metric>();
1226- ASSERT_FLOAT_EQ (query_floats[dim], expected_precomputed);
1226+ // Verify precomputed value (sum for IP/Cosine, sum and sum_squares for L2)
1227+ ASSERT_FLOAT_EQ (query_floats[dim + sq8::SUM_QUERY], expected_query_sum);
1228+ if constexpr (Metric == VecSimMetric_L2) {
1229+ ASSERT_FLOAT_EQ (query_floats[dim + sq8::SUM_SQUARES_QUERY],
1230+ expected_query_sum_squares);
1231+ }
12271232
12281233 allocator->free_allocation (storage_blob);
12291234 allocator->free_allocation (query_blob);
@@ -1254,10 +1259,12 @@ class QuantPreprocessorMetricTest : public testing::TestWithParam<VecSimMetric>
12541259 const float *query_floats = static_cast <const float *>(blob);
12551260 EXPECT_NO_FATAL_FAILURE (CompareVectors<float >(query_floats, original_blob, dim));
12561261
1257- // Verify precomputed value (sum for IP/Cosine, sum_squares for L2)
1258- float expected_precomputed = getExpectedQueryPrecomputedValue<Metric>();
1259- ASSERT_FLOAT_EQ (query_floats[dim], expected_precomputed);
1260-
1262+ // Verify precomputed value (sum for IP/Cosine, sum and sum_squares for L2)
1263+ ASSERT_FLOAT_EQ (query_floats[dim + sq8::SUM_QUERY], expected_query_sum);
1264+ if constexpr (Metric == VecSimMetric_L2) {
1265+ ASSERT_FLOAT_EQ (query_floats[dim + sq8::SUM_SQUARES_QUERY],
1266+ expected_query_sum_squares);
1267+ }
12611268 allocator->free_allocation (blob);
12621269 }
12631270
0 commit comments