
Commit 32d0279

Fp32 sq8 dist functions L2Sqr [MOD-13392] (#885)
* Add SQ8-to-SQ8 distance functions and optimizations
  - Implemented inner product and cosine distance functions for SQ8-to-SQ8 vectors in SVE, NEON, and AVX512 architectures.
  - Added corresponding distance function selection logic in IP_space.cpp and function headers in IP_space.h.
  - Created benchmarks for SQ8-to-SQ8 distance functions to evaluate performance across different architectures.
  - Developed unit tests to validate the correctness of the new distance functions against expected results.
  - Ensured compatibility with existing optimization features for various CPU architectures.
* Add SQ8-to-SQ8 benchmark tests and update related scripts
* Format
* Organizing
* Add full SQ8 benchmarks
* Optimize the SQ8-to-SQ8 functions
* Optimize SQ8 distance functions for NEON by reducing operations and improving performance
* Format
* Add NEON DOTPROD-optimized distance functions for SQ8-to-SQ8 calculations
* PR
* Remove NEON DOTPROD-optimized distance functions for INT8, UINT8, and SQ8-to-SQ8 calculations
* Fix vector layout documentation by removing inv_norm from comments in NEON and AVX512 headers
* Remove 'constexpr' from ones vector declaration in NEON inner product function
* Add SQ8-to-SQ8 L2 squared distance functions with SIMD optimizations
  - Implemented NEON, SVE, and AVX512F optimized functions for calculating L2 squared distance between SQ8 (scalar quantized 8-bit) vectors.
  - Introduced helper functions for processing vector elements using NEON and SVE intrinsics.
  - Updated L2_space.cpp and L2_space.h to include the new distance function for SQ8-to-SQ8.
  - Enhanced AVX512F, NEON, and SVE function selectors to choose the appropriate implementation based on CPU features.
  - Added unit tests to validate the correctness of the new L2 squared distance functions.
  - Updated benchmark tests to include performance measurements for the new implementations.
* Change the name
* Add full-range tests for SQ8 distance functions with SIMD optimizations
* Refactor distance functions to remove the inv_norm parameter and update documentation accordingly
* Update SQ8 Cosine test to normalize both input vectors and adjust distance assertion tolerance
* Rename 'compressed' to 'quantized' in SQ8 functions for clarity and consistency
* Rename 'compressed' to 'quantized' in SQ8 distance tests for clarity
* Refactor quantization function to remove unused normalization calculations
* Add TODO to store the vector's norm and sum in the L2 squared distance calculation
* Implement SQ8-to-SQ8 distance functions with precomputed sum and norm using AVX512 VNNI; add benchmarks and tests for the new functionality
* Add edge-case tests for SQ8-to-SQ8 precomputed cosine distance functions
* Refactor SQ8 test cases to use CreateSQ8QuantizedVector for vector population
* Implement SQ8-to-SQ8 precomputed distance functions using ARM NEON, SVE, and AVX512; add corresponding selection functions and update tests for consistency
* Implement SQ8-to-SQ8 precomputed inner product and cosine functions; update benchmarks and tests for the new functionality
* Refactor SQ8 distance functions and remove precomputed variants
  - Updated distance function declarations in IP_space.h to clarify that SQ8-to-SQ8 functions use precomputed sum/norm.
  - Removed precomputed distance function implementations for AVX512F, NEON, and SVE architectures from their respective source files.
  - Adjusted benchmark tests to remove references to precomputed distance functions and ensure they use the updated quantization methods.
  - Modified utility functions to support the creation of SQ8 quantized vectors with precomputed sum and norm.
  - Updated unit tests to reflect changes in the quantization process and removed tests specific to precomputed distance functions.
* Refactor SQ8 distance functions and tests for improved clarity and consistency
  - Updated include paths in AVX512F_BW_VL_VNNI.cpp to reflect new naming conventions.
  - Modified unit tests in test_spaces.cpp to streamline vector initialization and quantization.
  - Replaced repetitive code with utility functions for populating and quantizing vectors.
  - Enhanced assertions in tests to ensure optimized distance functions are correctly chosen and validated.
  - Removed unnecessary parameters from utility functions to simplify their interfaces.
  - Improved test coverage for edge cases, including zero and constant vectors, ensuring accuracy across various scenarios.
* Refactor SQ8 benchmarks by removing precomputed variants and updating vector population methods
* Format
* Remove serialization benchmark script for HNSW disk serialization
* Refactor SQ8 distance functions and tests to remove precomputed norm references
* Format
* Refactor SQ8 distance tests to use compressed vectors and improve normalization calculations
* Update vector layout documentation to reflect removal of the sum of squares in SQ8 implementations
* Refactor L2 SQ8 distance computation to remove unused accumulators and streamline calculations
* Refactor SQ8 distance functions to remove norm computation
  - Updated comments and documentation to reflect that the SQ8-to-SQ8 distance functions now use only precomputed sums, removing references to norms.
  - Modified function signatures and implementations across the SIMD architectures (AVX512F, NEON, SVE) to align with the new approach.
  - Adjusted utility functions for populating SQ8 vectors to include metadata for sums and normalization.
  - Updated unit tests and benchmarks to ensure compatibility with the new SQ8 vector population methods and to validate the correctness of distance calculations.
* Update SQ8-to-SQ8 distance function comment to remove the norm reference
* Refactor cosine similarity functions to remove an unnecessary subtraction in the AVX2, SSE4, and SVE implementations
* Refactor L2 SQ8 distance functions to eliminate unused accumulators and streamline calculations
* Refactor SQ8 L2 and IP implementations to use a common inner product function
  - Introduced SQ8_SQ8_InnerProduct_Impl for shared inner product calculations in the SQ8 space.
  - Updated SQ8_SQ8_L2Sqr to use the new inner product implementation, improving performance and reducing code duplication.
  - Modified the AVX512 and NEON SIMD implementations to leverage the common inner product function for L2 squared distance calculations.
  - Removed redundant code and tests related to full-range vector comparisons, streamlining the test suite.
  - Ensured that vector layouts include the sum of squares for optimized distance calculations.
* Refactor cosine similarity functions to use specific SIMD implementations for improved clarity and performance
* Refactor L2 distance functions for SQ8 vectors to use the common inner product implementation and update metadata extraction in tests
* Refactor benchmark setup to allocate additional space for sum and sum_squares in SQ8 vector tests
* Add CPU feature checks to disable optimizations for AArch64 in the SQ8 distance function
* Add CPU feature checks to disable optimizations for AArch64 in the SQ8 distance function tests
* Fix formatting issues in the SQ8 inner product function and clean up conditional compilation in tests
* Refactor SQ8 distance functions and tests for improved readability and consistency
* Refactor SQ8 L2Sqr tests to use quantized vectors and improve alignment checks
* Enhance SQ8 inner product implementations with optimized dot product calculations
  - Refactored inner product calculations for SQ8 vectors using NEON and SVE optimizations.
  - Integrated UINT8_InnerProductImp for efficient dot product computation in the NEON and SVE implementations.
  - Updated inner product functions to handle 64-element chunks for improved performance.
  - Adjusted distance function selection logic so optimizations are applied only for dimensions >= 16.
  - Added tests for zero vectors and constant vectors to validate optimized implementations against baseline results.
  - Ensured consistency in assertions for symmetry tests across the various optimization flags.
  - Improved code readability and maintainability by removing redundant code and comments.
* Fix header guard duplication and update a test assertion for floating-point comparison
* Add missing pragma once directives in NEON header files
* Refactor SQ8 distance functions for improved performance and clarity
  - Updated inner product functions for NEON, SSE4, and SVE to streamline dequantization and reduce unnecessary calculations.
  - Consolidated common logic for inner product and cosine calculations across the different SIMD implementations.
  - Enhanced the handling of vector normalization and quantization in unit tests, ensuring consistency in compressed vector sizes.
  - Adjusted benchmark tests to reflect changes in vector compression and distance function calls.
  - Corrected include paths for the AVX512 implementations to maintain consistency across the codebase.
* Update SQ8 vector population functions to include metadata and adjust compressed size calculations
* Refactor SQ8 inner product functions for improved clarity and performance
* Refactor L2 distance functions to use the common inner product implementations for improved clarity and performance
* Rename the inner product implementation functions for AVX2 and AVX512 for clarity
* Refactor the SQ8 cosine function to use the inner product function for improved clarity
* Remove redundant inner product edge-case tests for SQ8 distance functions
* Add SVE2 support to the SQ8-to-SQ8 inner product distance function
* Fix SQ8_Cosine to call the correct inner product function for improved accuracy
* Remove SVE2 and other optimizations from the SQ8 cosine function test for ARM architecture
* Add an L2 distance function without optimizations for testing purposes
* Refactor the L2 distance function and update test assertions for precision
* Update L2 squared distance functions to support 64 residuals in the NEON implementation
* Refactor L2 distance function conditions for NEON optimizations
* Adjust NEON_DOTPROD benchmark initialization to use a dimension of 16
* Update NEON benchmarks to support 64 dimensions for the L2 and Cosine metrics
* Optimize the SQ8 inner product implementation
  - Refactored the SQ8 inner product computation to eliminate unnecessary dequantization steps, improving performance.
  - Introduced a new helper function `InnerProductStepSQ8` that computes the inner product directly on quantized values.
  - Updated the main inner product function `SQ8_InnerProductSIMD_SVE_IMP` to use the new helper, streamlining the computation.
  - Modified the test suite to validate the new implementation against the baseline non-optimized version.
  - Added edge-case tests for self-distance, symmetry, zero vectors, constant vectors, and extreme values to ensure robustness of the SQ8 cosine distance function.
  - Introduced utility functions for preprocessing and populating SQ8 queries, improving test clarity and maintainability.
* Refactor SQ8 inner product functions to clarify FMA usage and improve performance
* Update SQ8 test cases to improve alignment checks and adjust quantized size calculations
* Add an optimized SQ8 inner product implementation and update test cases
* Fix pointer usage in the SQ8 inner product implementation to reference the original vectors
* Add the sq8 type definition and update inner product implementations to use quantization parameters
* Refactor SQ8 inner product implementations to use structured quantization parameters and clean up code formatting
* Fix the SQ8 EdgeCases test by adjusting the vector size for the constant-vector test
* Fix formatting in the SQ8_EdgeCases test by adjusting vector initialization
* Refactor SQ8 inner product implementations to use the precomputed y_sum from the query blob
* Fix formatting in the SQ8_EdgeCases test for better readability
* Refactor the SQ8 cosine distance calculation to use the optimized function
* Refactor SQ8 L2 squared distance calculations for optimized performance
  - Implemented the algebraic identity for L2 squared distance to avoid dequantization in hot loops across the AVX2, AVX512, NEON, SSE4, and SVE implementations.
  - Updated L2 distance functions to use the precomputed sum and sum of squares, improving efficiency.
  - Modified unit tests to validate the new implementations and ensure consistency with the previous non-optimized calculations.
  - Enhanced test utilities to support preprocessing of float vectors for the SQ8 L2 space.
* Fix formatting in the IP.cpp and IP.h documentation for better readability
* Remove the unused CreateSQ8CompressedVector helper function from test_spaces.cpp
* Add a self-distance L2 test for SQ8 edge cases with optimization checks
* Refactor SQ8 query handling to unify preprocessing for the IP/Cosine/L2 spaces and optimize memory allocation
* Fix the query population seed in the SQ8 benchmark for consistency
1 parent f5e69ac commit 32d0279
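The recurring theme in the message above is the asymmetric SQ8 memory layout: a stored vector carries its quantization parameters and precomputed sums inline, and a query blob carries its own sums, so the distance kernels never make a second pass over the data. The sketch below shows one way such blobs could be built, following the layout comments in the diffs that follow. It is a minimal scalar illustration, not the library's preprocessing code; the helper names QuantizeSQ8Storage and BuildSQ8Query are invented for this example, and a non-empty input is assumed.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Storage blob: [uint8_t codes (dim)] [min_val] [delta] [x_sum] [x_sum_squares]
// Query blob:   [float values (dim)]  [y_sum] [y_sum_squares]
std::vector<uint8_t> QuantizeSQ8Storage(const std::vector<float> &x) {
    const size_t dim = x.size();
    const auto [mn, mx] = std::minmax_element(x.begin(), x.end());
    const float min_val = *mn;
    float delta = (*mx - min_val) / 255.0f; // map [min, max] onto the 8-bit code range
    if (delta == 0.0f)
        delta = 1.0f; // constant vector: every code becomes 0, any delta works

    std::vector<uint8_t> blob(dim + 4 * sizeof(float));
    float x_sum = 0.0f, x_sum_sq = 0.0f;
    for (size_t i = 0; i < dim; i++) {
        const auto code = static_cast<uint8_t>(std::lround((x[i] - min_val) / delta));
        blob[i] = code;
        // Accumulate sums over the *dequantized* values, so the algebraic identities
        // used by the kernels hold exactly (up to float rounding) for the stored codes.
        const float deq = code * delta + min_val;
        x_sum += deq;
        x_sum_sq += deq * deq;
    }
    const float params[4] = {min_val, delta, x_sum, x_sum_sq};
    std::memcpy(blob.data() + dim, params, sizeof(params));
    return blob;
}

std::vector<float> BuildSQ8Query(const std::vector<float> &y) {
    std::vector<float> blob(y.size() + 2);
    std::copy(y.begin(), y.end(), blob.begin());
    float y_sum = 0.0f, y_sum_sq = 0.0f;
    for (float v : y) {
        y_sum += v;
        y_sum_sq += v * v;
    }
    blob[y.size()] = y_sum;        // consumed by IP, Cosine, and L2
    blob[y.size() + 1] = y_sum_sq; // consumed by L2 only
    return blob;
}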

File tree

13 files changed: +437 additions, -701 deletions

src/VecSim/spaces/IP/IP.cpp

Lines changed: 11 additions & 5 deletions
@@ -24,10 +24,13 @@ using sq8 = vecsim_types::sq8;
  * = min * y_sum + delta * quantized_dot_product
  *
  * Uses 4x loop unrolling with multiple accumulators for ILP.
- * pVect1 is a vector of float32, pVect2 is a quantized uint8_t vector
+ * pVect1 is query (FP32): [float values (dim)] [y_sum] [y_sum_squares (L2 only)]
+ * pVect2 is storage (SQ8): [uint8_t values (dim)] [min_val] [delta] [x_sum] [x_sum_squares (L2
+ * only)]
+ *
+ * Returns raw inner product value (not distance). Used by SQ8_InnerProduct, SQ8_Cosine, SQ8_L2Sqr.
  */
-float SQ8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension) {
-
+float SQ8_InnerProduct_Impl(const void *pVect1v, const void *pVect2v, size_t dimension) {
     const auto *pVect1 = static_cast<const float *>(pVect1v);
     const auto *pVect2 = static_cast<const uint8_t *>(pVect2v);

@@ -61,8 +64,11 @@ float SQ8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimensio
     const float y_sum = pVect1[dimension + sq8::SUM_QUERY];

     // Apply formula: IP = min * y_sum + delta * Σ(q_i * y_i)
-    const float ip = min_val * y_sum + delta * quantized_dot;
-    return 1.0f - ip;
+    return min_val * y_sum + delta * quantized_dot;
+}
+
+float SQ8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension) {
+    return 1.0f - SQ8_InnerProduct_Impl(pVect1v, pVect2v, dimension);
 }

 float SQ8_Cosine(const void *pVect1v, const void *pVect2v, size_t dimension) {
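The refactor works because the dequantization is affine: with x_i = q_i * delta + min, the dot product factors into the two precomputed terms named in the comment above. A scalar reference under that assumption (not the repository's unrolled loop; the function name is illustrative):

#include <cstddef>
#include <cstdint>

// Since x_i = q_i * delta + min:
//   Σ x_i * y_i = delta * Σ(q_i * y_i) + min * Σ y_i = min * y_sum + delta * quantized_dot
float SQ8_IP_Reference(const float *y, const uint8_t *q, float min_val, float delta, size_t dim) {
    float quantized_dot = 0.0f; // Σ q_i * y_i — the only per-element work
    float y_sum = 0.0f;         // Σ y_i — precomputed once per query in the real layout
    for (size_t i = 0; i < dim; i++) {
        quantized_dot += static_cast<float>(q[i]) * y[i];
        y_sum += y[i];
    }
    return min_val * y_sum + delta * quantized_dot;
}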

src/VecSim/spaces/IP/IP.h

Lines changed: 7 additions & 0 deletions
@@ -10,6 +10,13 @@

 #include <cstdlib>

+// FP32-to-SQ8: Common inner product implementation that returns the raw inner product value
+// (not distance). Used by SQ8_InnerProduct, SQ8_Cosine, and SQ8_L2Sqr.
+// pVect1 is query (FP32): [float values (dim)] [y_sum] [y_sum_squares (L2 only)]
+// pVect2 is storage (SQ8): [uint8_t values (dim)] [min_val] [delta] [x_sum] [x_sum_squares (L2
+// only)]
+float SQ8_InnerProduct_Impl(const void *pVect1v, const void *pVect2v, size_t dimension);
+
 // pVect1v vector of type fp32 and pVect2v vector of type uint8
 float SQ8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension);

src/VecSim/spaces/L2/L2.cpp

Lines changed: 21 additions & 13 deletions
@@ -18,22 +18,30 @@ using bfloat16 = vecsim_types::bfloat16;
 using float16 = vecsim_types::float16;
 using sq8 = vecsim_types::sq8;

+/*
+ * Optimized asymmetric SQ8 L2 squared distance using algebraic identity:
+ * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i²
+ *            = x_sum_squares - 2 * IP(x, y) + y_sum_squares
+ * where IP(x, y) = min * y_sum + delta * Σ(q_i * y_i)
+ *
+ * pVect1 is query (FP32): [float values (dim)] [y_sum] [y_sum_squares]
+ * pVect2 is storage (SQ8): [uint8_t values (dim)] [min_val] [delta] [x_sum] [x_sum_squares]
+ */
 float SQ8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const auto *pVect1 = static_cast<const float *>(pVect1v);
+    // Get the raw inner product using the common implementation
+    const float ip = SQ8_InnerProduct_Impl(pVect1v, pVect2v, dimension);
+
+    // Get precomputed sum of squares from storage blob
     const auto *pVect2 = static_cast<const uint8_t *>(pVect2v);
-    // pvect2 is a vector of uint8_t, so we need to dequantize it, normalize it and then multiply
-    // it. it structred as [quantized values (uint8_t * dim)][min_val (float)][delta
-    // (float)][inv_norm (float)] The last two values are used to dequantize the vector.
-    const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
-    const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
+    const float *params = reinterpret_cast<const float *>(pVect2 + dimension);
+    const float x_sum_sq = params[sq8::SUM_SQUARES];

-    float res = 0;
-    for (size_t i = 0; i < dimension; i++) {
-        auto dequantized_V2 = (pVect2[i] * delta + min_val);
-        float t = pVect1[i] - dequantized_V2;
-        res += t * t;
-    }
-    return res;
+    // Get precomputed sum of squares from query blob
+    const auto *pVect1 = static_cast<const float *>(pVect1v);
+    const float y_sum_sq = pVect1[dimension + sq8::SUM_SQUARES_QUERY];
+
+    // L2² = ||x||² + ||y||² - 2*IP(x, y)
+    return x_sum_sq + y_sum_sq - 2.0f * ip;
 }

 float FP32_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) {
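A scalar sanity check for the identity this hunk introduces: the removed dequantize-and-subtract loop and the new closed form should agree up to float rounding. This is an illustrative harness, not code from the repository; the function names and tolerance are arbitrary.

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>

// The removed implementation: dequantize each element, subtract, square.
float L2_Naive(const float *y, const uint8_t *q, float min_val, float delta, size_t dim) {
    float res = 0.0f;
    for (size_t i = 0; i < dim; i++) {
        const float x = q[i] * delta + min_val;
        const float t = y[i] - x;
        res += t * t;
    }
    return res;
}

// The new implementation's shape: ||x||² + ||y||² - 2·IP(x, y), with the sums
// computed inline here instead of being read from the blobs.
float L2_Identity(const float *y, const uint8_t *q, float min_val, float delta, size_t dim) {
    float ip = 0.0f, x_sum_sq = 0.0f, y_sum_sq = 0.0f;
    for (size_t i = 0; i < dim; i++) {
        const float x = q[i] * delta + min_val;
        ip += x * y[i];
        x_sum_sq += x * x;       // read from the storage blob in the real code
        y_sum_sq += y[i] * y[i]; // read from the query blob in the real code
    }
    return x_sum_sq + y_sum_sq - 2.0f * ip;
}

void CheckL2Agreement(const float *y, const uint8_t *q, float min_val, float delta, size_t dim) {
    const float a = L2_Naive(y, q, min_val, delta, dim);
    const float b = L2_Identity(y, q, min_val, delta, dim);
    assert(std::fabs(a - b) <= 1e-2f * (1.0f + a)); // relative tolerance for float rounding
}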
src/VecSim/spaces/L2/L2_AVX2_FMA_SQ8.h

Lines changed: 26 additions & 75 deletions
@@ -1,4 +1,3 @@
-
 /*
  * Copyright (c) 2006-Present, Redis Ltd.
  * All rights reserved.
@@ -7,88 +6,40 @@
  * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
  * GNU Affero General Public License v3 (AGPLv3).
  */
+#pragma once
 #include "VecSim/spaces/space_includes.h"
 #include "VecSim/spaces/AVX_utils.h"
+#include "VecSim/spaces/IP/IP_AVX2_FMA_SQ8.h"
+#include "VecSim/types/sq8.h"

-static inline void L2StepSQ8_FMA(const float *&pVect1, const uint8_t *&pVect2, __m256 &sum256,
-                                 const __m256 &min_val_vec, const __m256 &delta_vec) {
-    // Load 8 float elements from pVect1
-    __m256 v1 = _mm256_loadu_ps(pVect1);
-    pVect1 += 8;
-
-    // Load 8 uint8 elements from pVect2, convert to int32, then to float
-    __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
-    pVect2 += 8;
-
-    // Zero-extend uint8 to int32
-    __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);
-
-    // Convert int32 to float
-    __m256 v2_f = _mm256_cvtepi32_ps(v2_256);
-
-    // Dequantize: v2_dequant = v2_f * delta_vec + min_val_vec
-    __m256 v2_dequant = _mm256_fmadd_ps(v2_f, delta_vec, min_val_vec);
-
-    // Calculate squared difference - simple and efficient approach
-    __m256 diff = _mm256_sub_ps(v1, v2_dequant);
+using sq8 = vecsim_types::sq8;

-    // Use FMA for diff² + sum in one instruction
-    sum256 = _mm256_fmadd_ps(diff, diff, sum256);
-}
+/*
+ * Optimized asymmetric SQ8 L2 squared distance using algebraic identity:
+ *
+ * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i²
+ *            = x_sum_squares - 2 * IP(x, y) + y_sum_squares
+ *
+ * where:
+ * - IP(x, y) = min * y_sum + delta * Σ(q_i * y_i) (computed via SQ8_InnerProductImp_FMA)
+ * - x_sum_squares and y_sum_squares are precomputed
+ *
+ * This avoids dequantization in the hot loop.
+ */

 template <unsigned char residual> // 0..15
 float SQ8_L2SqrSIMD16_AVX2_FMA(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const float *pVect1 = static_cast<const float *>(pVect1v);
-    // pVect2 is a quantized uint8_t vector
-    const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
-    const float *pEnd1 = pVect1 + dimension;
-
-    // Get dequantization parameters from the end of quantized vector
-    const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
-    const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
-    // Create broadcast vectors for SIMD operations
-    __m256 min_val_vec = _mm256_set1_ps(min_val);
-    __m256 delta_vec = _mm256_set1_ps(delta);
-
-    __m256 sum256 = _mm256_setzero_ps();
-
-    // Deal with 1-7 floats with mask loading, if needed. `dim` is >16, so we have at least one
-    // 16-float block, so mask loading is guaranteed to be safe.
-    if constexpr (residual % 8) {
-        __mmask8 constexpr mask = (1 << (residual % 8)) - 1;
-        __m256 v1 = my_mm256_maskz_loadu_ps<mask>(pVect1);
-        pVect1 += residual % 8;
-
-        // Load quantized values and dequantize
-        __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
-        pVect2 += residual % 8;
+    // Get the raw inner product using the common SIMD implementation
+    const float ip = SQ8_InnerProductImp_FMA<residual>(pVect1v, pVect2v, dimension);

-        // Zero-extend uint8 to int32
-        __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);
-
-        // Convert int32 to float
-        __m256 v2_f = _mm256_cvtepi32_ps(v2_256);
-
-        // Dequantize using FMA: (val * delta) + min_val
-        __m256 v2_dequant = _mm256_fmadd_ps(v2_f, delta_vec, min_val_vec);
-        v2_dequant = _mm256_blend_ps(_mm256_setzero_ps(), v2_dequant, mask);
-
-        // Calculate squared difference
-        __m256 diff = _mm256_sub_ps(v1, v2_dequant);
-        sum256 = _mm256_mul_ps(diff, diff);
-    }
-
-    // If the reminder is >=8, have another step of 8 floats
-    if constexpr (residual >= 8) {
-        L2StepSQ8_FMA(pVect1, pVect2, sum256, min_val_vec, delta_vec);
-    }
+    // Get precomputed sum of squares from storage blob
+    const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
+    const float *params = reinterpret_cast<const float *>(pVect2 + dimension);
+    const float x_sum_sq = params[sq8::SUM_SQUARES];

-    // We dealt with the residual part. We are left with some multiple of 16 floats.
-    // In each iteration we calculate 16 floats = 512 bits.
-    do {
-        L2StepSQ8_FMA(pVect1, pVect2, sum256, min_val_vec, delta_vec);
-        L2StepSQ8_FMA(pVect1, pVect2, sum256, min_val_vec, delta_vec);
-    } while (pVect1 < pEnd1);
+    // Get precomputed sum of squares from query blob
+    const float y_sum_sq = static_cast<const float *>(pVect1v)[dimension + sq8::SUM_SQUARES_QUERY];

-    return my_mm256_reduce_add_ps(sum256);
+    // L2² = ||x||² + ||y||² - 2*IP(x, y)
+    return x_sum_sq + y_sum_sq - 2.0f * ip;
 }
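The `residual` template parameter bakes dimension % 16 into each instantiation, so the tail handling is resolved at compile time. The sketch below shows how such a kernel is typically selected; the actual selection logic lives in the spaces chooser files (e.g. L2_space.cpp) and may differ. The chooser name is invented, and the header above is assumed to be included.

#include <array>
#include <cstddef>
#include <utility>

// Defined in the header shown above; declared here so the sketch is self-contained.
template <unsigned char residual>
float SQ8_L2SqrSIMD16_AVX2_FMA(const void *pVect1v, const void *pVect2v, size_t dimension);

using DistFn = float (*)(const void *, const void *, size_t);

// One instantiation per possible residual value (dim % 16).
template <std::size_t... Is>
constexpr std::array<DistFn, 16> BuildResidualTable(std::index_sequence<Is...>) {
    return {SQ8_L2SqrSIMD16_AVX2_FMA<static_cast<unsigned char>(Is)>...};
}

DistFn ChooseSQ8L2Sqr_AVX2_FMA(std::size_t dim) {
    static constexpr auto kTable = BuildResidualTable(std::make_index_sequence<16>{});
    return kTable[dim % 16]; // per the comments above, dim is assumed to be >= 16
}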

src/VecSim/spaces/L2/L2_AVX2_SQ8.h

Lines changed: 26 additions & 75 deletions
@@ -6,89 +6,40 @@
  * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
  * GNU Affero General Public License v3 (AGPLv3).
  */
+#pragma once
 #include "VecSim/spaces/space_includes.h"
 #include "VecSim/spaces/AVX_utils.h"
+#include "VecSim/spaces/IP/IP_AVX2_SQ8.h"
+#include "VecSim/types/sq8.h"

-static inline void L2SqrStep(const float *&pVect1, const uint8_t *&pVect2, __m256 &sum,
-                             const __m256 &min_val_vec, const __m256 &delta_vec) {
-    // Load 8 float elements from pVect1
-    __m256 v1 = _mm256_loadu_ps(pVect1);
+using sq8 = vecsim_types::sq8;

-    // Load 8 uint8 elements from pVect2
-    __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
-
-    // Zero-extend uint8 to int32
-    __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);
-
-    // Convert int32 to float
-    __m256 v2_f = _mm256_cvtepi32_ps(v2_256);
-
-    // Dequantize: (val * delta) + min_val
-    __m256 v2_dequant = _mm256_add_ps(_mm256_mul_ps(v2_f, delta_vec), min_val_vec);
-
-    // Compute difference
-    __m256 diff = _mm256_sub_ps(v1, v2_dequant);
-
-    // Square difference and add to sum
-    sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
-
-    // Advance pointers
-    pVect1 += 8;
-    pVect2 += 8;
-}
+/*
+ * Optimized asymmetric SQ8 L2 squared distance using algebraic identity:
+ *
+ * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i²
+ *            = x_sum_squares - 2 * IP(x, y) + y_sum_squares
+ *
+ * where:
+ * - IP(x, y) = min * y_sum + delta * Σ(q_i * y_i) (computed via SQ8_InnerProductImp_AVX2)
+ * - x_sum_squares and y_sum_squares are precomputed
+ *
+ * This avoids dequantization in the hot loop.
+ */

 template <unsigned char residual> // 0..15
 float SQ8_L2SqrSIMD16_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const float *pVect1 = static_cast<const float *>(pVect1v);
-    const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
-    // Get dequantization parameters from the end of quantized vector
-    const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
-    const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
-    // Create broadcast vectors for SIMD operations
-    __m256 min_val_vec = _mm256_set1_ps(min_val);
-    __m256 delta_vec = _mm256_set1_ps(delta);
-
-    const float *pEnd1 = pVect1 + dimension;
-
-    __m256 sum = _mm256_setzero_ps();
-
-    // Deal with 1-7 floats with mask loading, if needed
-    if constexpr (residual % 8) {
-        __mmask8 constexpr mask = (1 << (residual % 8)) - 1;
-        __m256 v1 = my_mm256_maskz_loadu_ps<mask>(pVect1);
-        pVect1 += residual % 8;
-
-        // Direct load - safe because we only process the masked elements
-        __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
-        pVect2 += residual % 8;
-
-        // Zero-extend uint8 to int32
-        __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);
+    // Get the raw inner product using the common SIMD implementation
+    const float ip = SQ8_InnerProductImp_AVX2<residual>(pVect1v, pVect2v, dimension);

-        // Convert int32 to float
-        __m256 v2_f = _mm256_cvtepi32_ps(v2_256);
-
-        // Dequantize: (val * delta) + min_val
-        __m256 v2_dequant = _mm256_add_ps(_mm256_mul_ps(v2_f, delta_vec), min_val_vec);
-
-        // Apply mask to zero out unused elements
-        v2_dequant = _mm256_blend_ps(_mm256_setzero_ps(), v2_dequant, mask);
-
-        __m256 diff = _mm256_sub_ps(v1, v2_dequant);
-        sum = _mm256_mul_ps(diff, diff);
-    }
-
-    // If the reminder is >= 8, have another step of 8 floats
-    if constexpr (residual >= 8) {
-        L2SqrStep(pVect1, pVect2, sum, min_val_vec, delta_vec);
-    }
+    // Get precomputed sum of squares from storage blob
+    const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
+    const float *params = reinterpret_cast<const float *>(pVect2 + dimension);
+    const float x_sum_sq = params[sq8::SUM_SQUARES];

-    // We dealt with the residual part. We are left with some multiple of 16 floats.
-    // In each iteration we calculate 16 floats = 512 bits.
-    do {
-        L2SqrStep(pVect1, pVect2, sum, min_val_vec, delta_vec);
-        L2SqrStep(pVect1, pVect2, sum, min_val_vec, delta_vec);
-    } while (pVect1 < pEnd1);
+    // Get precomputed sum of squares from query blob
+    const float y_sum_sq = static_cast<const float *>(pVect1v)[dimension + sq8::SUM_SQUARES_QUERY];

-    return my_mm256_reduce_add_ps(sum);
+    // L2² = ||x||² + ||y||² - 2*IP(x, y)
+    return x_sum_sq + y_sum_sq - 2.0f * ip;
 }
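After the refactor, both AVX2 variants reduce to a call into their matching inner-product header plus the same closed-form combination, so the remaining difference is which kernel they delegate to. The historical difference, visible in the deleted lines of the two files, was the dequantization multiply-add; the side-by-side below reconstructs just that contrast and is illustrative, not repository code (compile with -mavx2 -mfma).

#include <immintrin.h>

// Dequantize 8 uint8 codes to floats: code * delta + min.
static inline __m256 DequantPlainAVX2(__m128i codes, __m256 delta_vec, __m256 min_vec) {
    __m256 v = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(codes));
    // Two instructions, two rounding steps.
    return _mm256_add_ps(_mm256_mul_ps(v, delta_vec), min_vec);
}

static inline __m256 DequantFMA(__m128i codes, __m256 delta_vec, __m256 min_vec) {
    __m256 v = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(codes));
    // One fused instruction with a single rounding step: typically faster and slightly more accurate.
    return _mm256_fmadd_ps(v, delta_vec, min_vec);
}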
