
Commit bdcbf80

Adapt fp32/SQ8 distance functions: IP and cosine [MOD-13392] (#882)
* Add SQ8-to-SQ8 distance functions and optimizations
  - Implemented inner product and cosine distance functions for SQ8-to-SQ8 vectors in SVE, NEON, and AVX512 architectures.
  - Added corresponding distance function selection logic in IP_space.cpp and function headers in IP_space.h.
  - Created benchmarks for SQ8-to-SQ8 distance functions to evaluate performance across different architectures.
  - Developed unit tests to validate the correctness of the new distance functions against expected results.
  - Ensured compatibility with existing optimization features for various CPU architectures.
* Add SQ8-to-SQ8 benchmark tests and update related scripts
* Format
* Organizing
* Add full SQ8 benchmarks
* Optimize the SQ8-to-SQ8 functions
* Optimize SQ8 distance functions for NEON by reducing operations and improving performance
* Format
* Add NEON DOTPROD-optimized distance functions for SQ8-to-SQ8 calculations
* PR
* Remove NEON DOTPROD-optimized distance functions for INT8, UINT8, and SQ8-to-SQ8 calculations
* Fix vector layout documentation by removing inv_norm from comments in NEON and AVX512 headers
* Remove 'constexpr' from ones vector declaration in NEON inner product function
* Add SQ8-to-SQ8 L2 squared distance functions with SIMD optimizations
  - Implemented NEON, SVE, and AVX512F optimized functions for calculating L2 squared distance between SQ8 (scalar quantized 8-bit) vectors.
  - Introduced helper functions for processing vector elements using NEON and SVE intrinsics.
  - Updated L2_space.cpp and L2_space.h to include the new distance function for SQ8-to-SQ8.
  - Enhanced AVX512F, NEON, and SVE function selectors to choose the appropriate implementation based on CPU features.
  - Added unit tests to validate the correctness of the new L2 squared distance functions.
  - Updated benchmark tests to include performance measurements for the new implementations.
* Change the name
* Add full range tests for SQ8 distance functions with SIMD optimizations
* Refactor distance functions to remove inv_norm parameter and update documentation accordingly
* Update SQ8 Cosine test to normalize both input vectors and adjust distance assertion tolerance
* Rename 'compressed' to 'quantized' in SQ8 functions for clarity and consistency
* Rename 'compressed' to 'quantized' in SQ8 distance tests for clarity
* Refactor quantization function to remove unused normalization calculations
* Add TODO to store vector's norm and sum in L2 squared distance calculation
* Implement SQ8-to-SQ8 distance functions with precomputed sum and norm using AVX512 VNNI; add benchmarks and tests for new functionality
* Add edge case tests for SQ8-to-SQ8 precomputed cosine distance functions
* Refactor SQ8 test cases to use CreateSQ8QuantizedVector for vector population
* Implement SQ8-to-SQ8 precomputed distance functions using ARM NEON, SVE, and AVX512; add corresponding selection functions and update tests for consistency.
* Implement SQ8-to-SQ8 precomputed inner product and cosine functions; update benchmarks and tests for new functionality
* Refactor SQ8 distance functions and remove precomputed variants
  - Updated distance function declarations in IP_space.h to clarify that SQ8-to-SQ8 functions use precomputed sum/norm.
  - Removed precomputed distance function implementations for AVX512F, NEON, and SVE architectures from their respective source files.
  - Adjusted benchmark tests to remove references to precomputed distance functions and ensure they utilize the updated quantization methods.
  - Modified utility functions to support the creation of SQ8 quantized vectors with precomputed sum and norm.
  - Updated unit tests to reflect changes in the quantization process and removed tests specifically for precomputed distance functions.
* Refactor SQ8 distance functions and tests for improved clarity and consistency
  - Updated include paths in AVX512F_BW_VL_VNNI.cpp to reflect new naming conventions.
  - Modified unit tests in test_spaces.cpp to streamline vector initialization and quantization processes.
  - Replaced repetitive code with utility functions for populating and quantizing vectors.
  - Enhanced assertions in tests to ensure optimized distance functions are correctly chosen and validated.
  - Removed unnecessary parameters from utility functions to simplify their interfaces.
  - Improved test coverage for edge cases, including zero and constant vectors, ensuring accuracy across various scenarios.
* Refactor SQ8 benchmarks by removing precomputed variants and updating vector population methods
* Format
* Remove serialization benchmark script for HNSW disk serialization
* Refactor SQ8 distance functions and tests to remove precomputed norm references
* Format
* Refactor SQ8 distance tests to use compressed vectors and improve normalization calculations
* Update vector layout documentation to reflect removal of sum of squares in SQ8 implementations
* Refactor L2 SQ8 distance computation to remove unused accumulators and streamline calculations
* Refactor SQ8 distance functions to remove norm computation
  - Updated comments and documentation to reflect that the SQ8-to-SQ8 distance functions now only utilize precomputed sums, removing references to norms.
  - Modified function signatures and implementations across various SIMD architectures (AVX512F, NEON, SVE) to align with the new approach.
  - Adjusted utility functions for populating SQ8 vectors to include metadata for sums and normalization.
  - Updated unit tests and benchmarks to ensure compatibility with the new SQ8 vector population methods and to validate the correctness of distance calculations.
* Update SQ8-to-SQ8 distance function comment to remove norm reference
* Refactor cosine similarity functions to remove unnecessary subtraction in AVX2, SSE4, and SVE implementations
* Refactor L2 SQ8 distance functions to eliminate unused accumulators and streamline calculations
* Refactor SQ8 L2 and IP implementations to use a common inner product function
  - Introduced SQ8_SQ8_InnerProduct_Impl for shared inner product calculations in SQ8 space.
  - Updated SQ8_SQ8_L2Sqr to utilize the new inner product implementation, improving performance and reducing code duplication.
  - Modified AVX512 and NEON SIMD implementations to leverage the common inner product function for L2 squared distance calculations.
  - Removed redundant code and tests related to full range vector comparisons, streamlining the test suite.
  - Ensured that vector layouts include sum of squares for optimized distance calculations.
* Refactor cosine similarity functions to use specific SIMD implementations for improved clarity and performance
* Refactor L2 distance functions for SQ8 vectors to utilize the common inner product implementation and update metadata extraction in tests
* Refactor benchmark setup to allocate additional space for sum and sum_squares in SQ8 vector tests
* Add CPU feature checks to disable optimizations for AArch64 in SQ8 distance function
* Add CPU feature checks to disable optimizations for AArch64 in SQ8 distance function tests
* Fix formatting issues in SQ8 inner product function and clean up conditional compilation in tests
* Refactor SQ8 distance functions and tests for improved readability and consistency
* Refactor SQ8 L2Sqr tests to use quantized vectors and improve alignment checks
* Enhance SQ8 inner product implementations with optimized dot product calculations
  - Refactored inner product calculations for SQ8 vectors using NEON and SVE optimizations.
  - Integrated UINT8_InnerProductImp for efficient dot product computation in NEON and SVE implementations.
  - Updated inner product functions to handle 64-element chunks for improved performance.
  - Adjusted distance function selection logic to ensure optimizations are applied only for dimensions >= 16.
  - Added tests for zero vectors and constant vectors to validate optimized implementations against baseline results.
  - Ensured consistency in assertions for symmetry tests across various optimization flags.
  - Improved code readability and maintainability by removing redundant code and comments.
* Fix header guard duplication and update test assertion for floating-point comparison
* Add missing pragma once directive in NEON header files
* Refactor SQ8 distance functions for improved performance and clarity
  - Updated inner product functions for NEON, SSE4, and SVE to streamline dequantization and reduce unnecessary calculations.
  - Consolidated common logic for inner product and cosine calculations across different SIMD implementations.
  - Enhanced the handling of vector normalization and quantization in unit tests, ensuring consistency in compressed vector sizes.
  - Adjusted benchmark tests to reflect changes in vector compression and distance function calls.
  - Corrected include paths for AVX512 implementations to maintain consistency across the codebase.
* Update SQ8 vector population functions to include metadata and adjust compressed size calculations
* Refactor SQ8 inner product functions for improved clarity and performance
* Refactor L2 distance functions to utilize common inner product implementations for improved clarity and performance
* Rename inner product implementation functions for AVX2 and AVX512 for clarity
* Refactor SQ8 cosine function to utilize the inner product function for improved clarity
* Remove redundant inner product edge case tests for SQ8 distance functions
* Add SVE2 support to SQ8-to-SQ8 inner product distance function
* Fix SQ8_Cosine to call the correct inner product function for improved accuracy
* Remove SVE2 and other optimizations from SQ8 cosine function test for ARM architecture
* Add L2 distance function without optimizations for testing purposes
* Refactor L2 distance function and update test assertions for precision
* Update L2 squared distance functions to support 64 residuals in NEON implementation
* Refactor L2 distance function conditions for NEON optimizations
* Adjust NEON_DOTPROD benchmark initialization to use a dimension of 16
* Update NEON benchmarks to support 64 dimensions for L2 and Cosine metrics
* Optimize SQ8 inner product implementation
  - Refactor the SQ8 inner product computation to eliminate unnecessary dequantization steps, improving performance.
  - Introduce a new helper function `InnerProductStepSQ8` that computes the inner product directly using quantized values.
  - Update the main inner product function `SQ8_InnerProductSIMD_SVE_IMP` to utilize the new helper function, streamlining the computation process.
  - Modify the test suite to validate the new implementation, ensuring correctness against the baseline non-optimized version.
  - Add edge case tests for self-distance, symmetry, zero vectors, constant vectors, and extreme values to ensure robustness of the SQ8 cosine distance function.
  - Introduce utility functions for preprocessing and populating SQ8 queries, enhancing test clarity and maintainability.
* Refactor SQ8 inner product functions to clarify FMA usage and improve performance
* Update SQ8 test cases to improve alignment checks and adjust quantized size calculations
* Add optimized SQ8 inner product implementation and update test cases
* Fix pointer usage in SQ8 inner product implementation to reference original vectors
* Add sq8 type definition and update inner product implementations for quantization parameters
* Refactor SQ8 inner product implementations to use structured quantization parameters and clean up code formatting
* Fix SQ8 EdgeCases test by adjusting vector size for constant vector test
* Fix formatting in SQ8_EdgeCases test by adjusting vector initialization
* Refactor SQ8 inner product implementations to use precomputed y_sum from query blob (see the layout sketch after this commit summary)
* Fix formatting in SQ8_EdgeCases test for better readability
* Refactor SQ8 cosine distance calculation to use optimized function
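
For context on the layout the adapted functions expect, below is a minimal, self-contained sketch of how a stored SQ8 vector and a query blob could be prepared. It is not the library's actual helper code (CreateSQ8QuantizedVector and the test utilities mentioned above may differ in details): the stored blob layout [dim uint8 quantized values][min_val (float)][delta (float)] and the precomputed y_sum appended after the query's dim floats are taken from the diffs below, while the quantization rule delta = (max - min) / 255, q_i = round((x_i - min) / delta) is an assumption implied by the dequantization min + delta * q_i.

// Hedged sketch of SQ8 blob preparation (illustrative only, not the library's helpers).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<uint8_t> quantize_sq8_sketch(const std::vector<float> &x) {
    const float min_val = *std::min_element(x.begin(), x.end());
    const float max_val = *std::max_element(x.begin(), x.end());
    const float delta = (max_val - min_val) / 255.0f;
    const float inv_delta = (delta != 0.0f) ? 1.0f / delta : 0.0f;

    // Layout: [dim bytes of quantized values][min_val (float)][delta (float)]
    std::vector<uint8_t> blob(x.size() + 2 * sizeof(float));
    for (size_t i = 0; i < x.size(); i++)
        blob[i] = static_cast<uint8_t>(std::lround((x[i] - min_val) * inv_delta));
    std::memcpy(blob.data() + x.size(), &min_val, sizeof(float));
    std::memcpy(blob.data() + x.size() + sizeof(float), &delta, sizeof(float));
    return blob;
}

std::vector<float> prepare_sq8_query_sketch(const std::vector<float> &y) {
    // Layout: [dim floats][y_sum] -- the precomputed sum read via sq8::SUM_QUERY in the diffs.
    std::vector<float> blob(y.begin(), y.end());
    float y_sum = 0.0f;
    for (float v : y)
        y_sum += v;
    blob.push_back(y_sum);
    return blob;
}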
1 parent (f6df960) · commit bdcbf80

File tree

12 files changed: +887 −526 lines changed


src/VecSim/spaces/IP/IP.cpp

Lines changed: 45 additions & 27 deletions
@@ -16,39 +16,57 @@ using bfloat16 = vecsim_types::bfloat16;
 using float16 = vecsim_types::float16;
 using sq8 = vecsim_types::sq8;
 
-float FLOAT_INTEGER_InnerProduct(const float *pVect1v, const uint8_t *pVect2v, size_t dimension,
-                                 float min_val, float delta) {
-    float res = 0;
-    for (size_t i = 0; i < dimension; i++) {
-        float dequantized_V2 = (pVect2v[i] * delta + min_val);
-        res += pVect1v[i] * dequantized_V2;
-    }
-    return res;
-}
-
+/*
+ * Optimized asymmetric SQ8 inner product using algebraic identity:
+ * IP(x, y) = Σ(x_i * y_i)
+ *          ≈ Σ((min + delta * q_i) * y_i)
+ *          = min * Σy_i + delta * Σ(q_i * y_i)
+ *          = min * y_sum + delta * quantized_dot_product
+ *
+ * Uses 4x loop unrolling with multiple accumulators for ILP.
+ * pVect1 is a vector of float32, pVect2 is a quantized uint8_t vector
+ */
 float SQ8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension) {
+
     const auto *pVect1 = static_cast<const float *>(pVect1v);
     const auto *pVect2 = static_cast<const uint8_t *>(pVect2v);
-    // pVect2 is a vector of uint8_t, so we need to de-quantize it, normalize it and then multiply
-    // it. it is structured as [quantized values (int8_t * dim)][min_val (float)][delta
-    // (float)]] The last two values are used to dequantize the vector.
-    const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
-    const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
-    // Compute inner product with dequantization
-    const float res = FLOAT_INTEGER_InnerProduct(pVect1, pVect2, dimension, min_val, delta);
-    return 1.0f - res;
+
+    // Use 4 accumulators for instruction-level parallelism
+    float sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0;
+
+    // Main loop: process 4 elements per iteration
+    size_t i = 0;
+    size_t dim4 = dimension & ~size_t(3); // dim4 is a multiple of 4
+    for (; i < dim4; i += 4) {
+        sum0 += pVect1[i + 0] * static_cast<float>(pVect2[i + 0]);
+        sum1 += pVect1[i + 1] * static_cast<float>(pVect2[i + 1]);
+        sum2 += pVect1[i + 2] * static_cast<float>(pVect2[i + 2]);
+        sum3 += pVect1[i + 3] * static_cast<float>(pVect2[i + 3]);
+    }
+
+    // Handle remainder (0-3 elements)
+    for (; i < dimension; i++) {
+        sum0 += pVect1[i] * static_cast<float>(pVect2[i]);
+    }
+
+    // Combine accumulators
+    float quantized_dot = (sum0 + sum1) + (sum2 + sum3);
+
+    // Get quantization parameters from stored vector
+    const float *params = reinterpret_cast<const float *>(pVect2 + dimension);
+    const float min_val = params[sq8::MIN_VAL];
+    const float delta = params[sq8::DELTA];
+
+    // Get precomputed y_sum from query blob (stored after the dim floats)
+    const float y_sum = pVect1[dimension + sq8::SUM_QUERY];
+
+    // Apply formula: IP = min * y_sum + delta * Σ(q_i * y_i)
+    const float ip = min_val * y_sum + delta * quantized_dot;
+    return 1.0f - ip;
 }
 
 float SQ8_Cosine(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const auto *pVect1 = static_cast<const float *>(pVect1v);
-    const auto *pVect2 = static_cast<const uint8_t *>(pVect2v);
-
-    // Get quantization parameters
-    const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
-    const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
-    // Compute inner product with dequantization
-    const float res = FLOAT_INTEGER_InnerProduct(pVect1, pVect2, dimension, min_val, delta);
-    return 1.0f - res;
+    return SQ8_InnerProduct(pVect1v, pVect2v, dimension);
 }
 
 // SQ8-to-SQ8: Common inner product implementation that returns the raw inner product value
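
A small self-contained check (illustrative only, not part of the commit) of the algebraic identity the rewritten SQ8_InnerProduct above relies on: Σ((min + delta * q_i) * y_i) equals min * Σy_i + delta * Σ(q_i * y_i), up to float rounding. The sample values are arbitrary.

// Verifies dequantize-then-multiply (old path) against the factored form (new path).
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
    const uint8_t q[4] = {0, 17, 128, 255};
    const float y[4] = {0.5f, -1.25f, 2.0f, 0.75f};
    const float min_val = -0.8f, delta = 0.01f;

    float direct = 0.0f, y_sum = 0.0f, quantized_dot = 0.0f;
    for (int i = 0; i < 4; i++) {
        direct += (min_val + delta * q[i]) * y[i]; // dequantize each element, then multiply
        y_sum += y[i];                             // Σy_i (what the query blob precomputes)
        quantized_dot += q[i] * y[i];              // Σ(q_i * y_i) (what the new kernels accumulate)
    }
    const float factored = min_val * y_sum + delta * quantized_dot;
    assert(std::fabs(direct - factored) < 1e-5f);
    return 0;
}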

src/VecSim/spaces/IP/IP_AVX2_FMA_SQ8.h

Lines changed: 52 additions & 48 deletions
@@ -6,91 +6,96 @@
  * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
  * GNU Affero General Public License v3 (AGPLv3).
  */
+#pragma once
 #include "VecSim/spaces/space_includes.h"
 #include "VecSim/spaces/AVX_utils.h"
+#include "VecSim/types/sq8.h"
+using sq8 = vecsim_types::sq8;
 
+/*
+ * Optimized asymmetric SQ8 inner product using algebraic identity:
+ *
+ * IP(x, y) = Σ(x_i * y_i)
+ *          ≈ Σ((min + delta * q_i) * y_i)
+ *          = min * Σy_i + delta * Σ(q_i * y_i)
+ *          = min * y_sum + delta * quantized_dot_product
+ *
+ * where y_sum = Σy_i is precomputed and stored in the query blob.
+ * This avoids dequantization in the hot loop - we only compute Σ(q_i * y_i).
+ *
+ * This version uses FMA instructions for better performance.
+ */
+
+// Helper: compute Σ(q_i * y_i) for 8 elements using FMA (no dequantization)
 static inline void InnerProductStepSQ8_FMA(const float *&pVect1, const uint8_t *&pVect2,
-                                           __m256 &sum256, const __m256 &min_val_vec,
-                                           const __m256 &delta_vec) {
-    // Load 8 float elements from pVect1
+                                           __m256 &sum256) {
+    // Load 8 float elements from query
     __m256 v1 = _mm256_loadu_ps(pVect1);
     pVect1 += 8;
 
-    // Load 8 uint8 elements from pVect2, convert to int32, then to float
-    __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
+    // Load 8 uint8 elements and convert to float
+    __m128i v2_128 = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(pVect2));
    pVect2 += 8;
 
-    // Zero-extend uint8 to int32
     __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);
-
-    // Convert int32 to float
     __m256 v2_f = _mm256_cvtepi32_ps(v2_256);
 
-    // Dequantize and compute dot product in one step using FMA
-    // (val * delta) + min_val -> v2_dequant
-    // sum256 += v1 * v2_dequant
-    // Using FMA: sum256 = v1 * v2_dequant + sum256
-
-    // First, compute v2_dequant = v2_f * delta_vec + min_val_vec
-    __m256 v2_dequant = _mm256_fmadd_ps(v2_f, delta_vec, min_val_vec);
-
-    // Then, compute sum256 += v1 * v2_dequant using FMA
-    sum256 = _mm256_fmadd_ps(v1, v2_dequant, sum256);
+    // Accumulate q_i * y_i using FMA (no dequantization!)
+    sum256 = _mm256_fmadd_ps(v2_f, v1, sum256);
 }
 
 template <unsigned char residual> // 0..15
 float SQ8_InnerProductImp_FMA(const void *pVect1v, const void *pVect2v, size_t dimension) {
     const float *pVect1 = static_cast<const float *>(pVect1v);
-    // pVect2 is a quantized uint8_t vector
     const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
     const float *pEnd1 = pVect1 + dimension;
 
-    // Get dequantization parameters from the end of quantized vector
-    const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
-    const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
-    // Create broadcast vectors for SIMD operations
-    __m256 min_val_vec = _mm256_set1_ps(min_val);
-    __m256 delta_vec = _mm256_set1_ps(delta);
-
+    // Initialize sum accumulator for Σ(q_i * y_i)
     __m256 sum256 = _mm256_setzero_ps();
 
-    // Deal with 1-7 floats with mask loading, if needed. `dim` is >16, so we have at least one
-    // 16-float block, so mask loading is guaranteed to be safe.
+    // Handle residual elements first (0-7 elements)
     if constexpr (residual % 8) {
         __mmask8 constexpr mask = (1 << (residual % 8)) - 1;
         __m256 v1 = my_mm256_maskz_loadu_ps<mask>(pVect1);
         pVect1 += residual % 8;
 
-        // Load quantized values and dequantize
-        __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
+        // Load uint8 elements and convert to float
+        __m128i v2_128 = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(pVect2));
         pVect2 += residual % 8;
 
-        // Zero-extend uint8 to int32
         __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);
-
-        // Convert int32 to float
         __m256 v2_f = _mm256_cvtepi32_ps(v2_256);
 
-        // Dequantize using FMA: (val * delta) + min_val
-        __m256 v2_dequant = _mm256_fmadd_ps(v2_f, delta_vec, min_val_vec);
-
-        // Compute dot product with masking
-        sum256 = _mm256_mul_ps(v1, v2_dequant);
+        // Compute q_i * y_i (no dequantization)
+        sum256 = _mm256_mul_ps(v1, v2_f);
     }
 
-    // If the reminder is >=8, have another step of 8 floats
+    // If the residual is >=8, have another step of 8 floats
     if constexpr (residual >= 8) {
-        InnerProductStepSQ8_FMA(pVect1, pVect2, sum256, min_val_vec, delta_vec);
+        InnerProductStepSQ8_FMA(pVect1, pVect2, sum256);
     }
 
-    // We dealt with the residual part. We are left with some multiple of 16 floats.
-    // In each iteration we calculate 16 floats = 512 bits.
+    // Process remaining full chunks of 16 elements (2x8)
+    // Using do-while since dim > 16 guarantees at least one iteration
    do {
-        InnerProductStepSQ8_FMA(pVect1, pVect2, sum256, min_val_vec, delta_vec);
-        InnerProductStepSQ8_FMA(pVect1, pVect2, sum256, min_val_vec, delta_vec);
+        InnerProductStepSQ8_FMA(pVect1, pVect2, sum256);
+        InnerProductStepSQ8_FMA(pVect1, pVect2, sum256);
     } while (pVect1 < pEnd1);
 
-    return my_mm256_reduce_add_ps(sum256);
+    // Reduce to get Σ(q_i * y_i)
+    float quantized_dot = my_mm256_reduce_add_ps(sum256);
+
+    // Get quantization parameters from stored vector (after quantized data)
+    const uint8_t *pVect2Base = static_cast<const uint8_t *>(pVect2v);
+    const float *params2 = reinterpret_cast<const float *>(pVect2Base + dimension);
+    const float min_val = params2[sq8::MIN_VAL];
+    const float delta = params2[sq8::DELTA];
+
+    // Get precomputed y_sum from query blob (stored after the dim floats)
+    const float y_sum = static_cast<const float *>(pVect1v)[dimension + sq8::SUM_QUERY];
+
+    // Apply the algebraic formula: IP = min * y_sum + delta * Σ(q_i * y_i)
+    return min_val * y_sum + delta * quantized_dot;
 }
 
 template <unsigned char residual> // 0..15
@@ -100,7 +105,6 @@ float SQ8_InnerProductSIMD16_AVX2_FMA(const void *pVect1v, const void *pVect2v,
 
 template <unsigned char residual> // 0..15
 float SQ8_CosineSIMD16_AVX2_FMA(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    // Calculate inner product using common implementation with normalization
-    float ip = SQ8_InnerProductImp_FMA<residual>(pVect1v, pVect2v, dimension);
-    return 1.0f - ip;
+    // Cosine distance = 1 - IP (vectors are pre-normalized)
+    return SQ8_InnerProductSIMD16_AVX2_FMA<residual>(pVect1v, pVect2v, dimension);
 }
src/VecSim/spaces/IP/IP_AVX2_SQ8.h

Lines changed: 52 additions & 42 deletions
@@ -6,85 +6,96 @@
  * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
  * GNU Affero General Public License v3 (AGPLv3).
  */
+#pragma once
 #include "VecSim/spaces/space_includes.h"
 #include "VecSim/spaces/AVX_utils.h"
+#include "VecSim/types/sq8.h"
 
-static inline void InnerProductStepSQ8(const float *&pVect1, const uint8_t *&pVect2, __m256 &sum256,
-                                       const __m256 &min_val_vec, const __m256 &delta_vec) {
-    // Load 8 float elements from pVect1
+using sq8 = vecsim_types::sq8;
+
+/*
+ * Optimized asymmetric SQ8 inner product using algebraic identity:
+ *
+ * IP(x, y) = Σ(x_i * y_i)
+ *          ≈ Σ((min + delta * q_i) * y_i)
+ *          = min * Σy_i + delta * Σ(q_i * y_i)
+ *          = min * y_sum + delta * quantized_dot_product
+ *
+ * where y_sum = Σy_i is precomputed and stored in the query blob.
+ * This avoids dequantization in the hot loop - we only compute Σ(q_i * y_i).
+ */
+
+// Helper: compute Σ(q_i * y_i) for 8 elements (no dequantization)
+static inline void InnerProductStepSQ8(const float *&pVect1, const uint8_t *&pVect2,
+                                       __m256 &sum256) {
+    // Load 8 float elements from query
     __m256 v1 = _mm256_loadu_ps(pVect1);
     pVect1 += 8;
 
-    // Load 8 uint8 elements from pVect2, convert to int32, then to float
-    __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
+    // Load 8 uint8 elements and convert to float
+    __m128i v2_128 = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(pVect2));
     pVect2 += 8;
 
-    // Zero-extend uint8 to int32
     __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);
-
-    // Convert int32 to float
     __m256 v2_f = _mm256_cvtepi32_ps(v2_256);
 
-    // Dequantize: (val * delta) + min_val
-    __m256 v2_dequant = _mm256_add_ps(_mm256_mul_ps(v2_f, delta_vec), min_val_vec);
-
-    // Compute dot product and add to sum
-    sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2_dequant));
+    // Accumulate q_i * y_i (no dequantization!)
+    // Using mul + add since this is the non-FMA version
+    sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v2_f, v1));
 }
 
 template <unsigned char residual> // 0..15
 float SQ8_InnerProductImp_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) {
     const float *pVect1 = static_cast<const float *>(pVect1v);
-    // pVect2 is a quantized uint8_t vector
     const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
     const float *pEnd1 = pVect1 + dimension;
 
-    // Get dequantization parameters from the end of quantized vector
-    const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
-    const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
-    // Create broadcast vectors for SIMD operations
-    __m256 min_val_vec = _mm256_set1_ps(min_val);
-    __m256 delta_vec = _mm256_set1_ps(delta);
-
+    // Initialize sum accumulator for Σ(q_i * y_i)
     __m256 sum256 = _mm256_setzero_ps();
 
-    // Deal with 1-7 floats with mask loading, if needed. `dim` is >16, so we have at least one
-    // 16-float block, so mask loading is guaranteed to be safe.
+    // Handle residual elements first (0-7 elements)
     if constexpr (residual % 8) {
         __mmask8 constexpr mask = (1 << (residual % 8)) - 1;
         __m256 v1 = my_mm256_maskz_loadu_ps<mask>(pVect1);
         pVect1 += residual % 8;
 
-        // Load quantized values and dequantize
-        __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
+        // Load uint8 elements and convert to float
+        __m128i v2_128 = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(pVect2));
         pVect2 += residual % 8;
 
-        // Zero-extend uint8 to int32
         __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);
-
-        // Convert int32 to float
         __m256 v2_f = _mm256_cvtepi32_ps(v2_256);
 
-        // Dequantize: (val * delta) + min_val
-        __m256 v2_dequant = _mm256_add_ps(_mm256_mul_ps(v2_f, delta_vec), min_val_vec);
-
-        // Compute dot product with masking
-        sum256 = _mm256_mul_ps(v1, v2_dequant);
+        // Compute q_i * y_i (no dequantization)
+        sum256 = _mm256_mul_ps(v1, v2_f);
     }
 
-    // If the reminder is >=8, have another step of 8 floats
+    // If the residual is >=8, have another step of 8 floats
     if constexpr (residual >= 8) {
-        InnerProductStepSQ8(pVect1, pVect2, sum256, min_val_vec, delta_vec);
+        InnerProductStepSQ8(pVect1, pVect2, sum256);
     }
 
-    // We dealt with the residual part. We are left with some multiple of 16 floats.
-    // In each iteration we calculate 16 floats = 512 bits.
+    // Process remaining full chunks of 16 elements (2x8)
+    // Using do-while since dim > 16 guarantees at least one iteration
     do {
-        InnerProductStepSQ8(pVect1, pVect2, sum256, min_val_vec, delta_vec);
-        InnerProductStepSQ8(pVect1, pVect2, sum256, min_val_vec, delta_vec);
+        InnerProductStepSQ8(pVect1, pVect2, sum256);
+        InnerProductStepSQ8(pVect1, pVect2, sum256);
     } while (pVect1 < pEnd1);
 
-    return my_mm256_reduce_add_ps(sum256);
+    // Reduce to get Σ(q_i * y_i)
+    float quantized_dot = my_mm256_reduce_add_ps(sum256);
+
+    // Get quantization parameters from stored vector (after quantized data)
+    const uint8_t *pVect2Base = static_cast<const uint8_t *>(pVect2v);
+    const float *params2 = reinterpret_cast<const float *>(pVect2Base + dimension);
+    const float min_val = params2[sq8::MIN_VAL];
+    const float delta = params2[sq8::DELTA];
+
+    // Get precomputed y_sum from query blob (stored after the dim floats)
+    const float y_sum = static_cast<const float *>(pVect1v)[dimension + sq8::SUM_QUERY];
+
+    // Apply the algebraic formula: IP = min * y_sum + delta * Σ(q_i * y_i)
+    return min_val * y_sum + delta * quantized_dot;
 }
 
 template <unsigned char residual> // 0..15
@@ -95,6 +106,5 @@ float SQ8_InnerProductSIMD16_AVX2(const void *pVect1v, const void *pVect2v, size
 template <unsigned char residual> // 0..15
 float SQ8_CosineSIMD16_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) {
     // Calculate inner product using common implementation with normalization
-    float ip = SQ8_InnerProductImp_AVX2<residual>(pVect1v, pVect2v, dimension);
-    return 1.0f - ip;
+    return SQ8_InnerProductSIMD16_AVX2<residual>(pVect1v, pVect2v, dimension);
 }
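
The only functional difference between this header and the FMA variant above is the accumulate step: a fused _mm256_fmadd_ps versus separate _mm256_mul_ps + _mm256_add_ps. A minimal sketch (illustrative only, compile with something like -mavx2 -mfma) showing the two forms side by side; results match up to rounding:

// Contrast of the accumulate step used by the two AVX2 headers above.
#include <immintrin.h>
#include <cstdio>

int main() {
    __m256 a = _mm256_set1_ps(1.5f);
    __m256 b = _mm256_set1_ps(2.0f);
    __m256 acc0 = _mm256_setzero_ps();
    __m256 acc1 = _mm256_setzero_ps();

    acc0 = _mm256_fmadd_ps(a, b, acc0);              // FMA path (IP_AVX2_FMA_SQ8.h)
    acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(a, b)); // mul + add path (IP_AVX2_SQ8.h)

    float r0[8], r1[8];
    _mm256_storeu_ps(r0, acc0);
    _mm256_storeu_ps(r1, acc1);
    std::printf("fma: %f, mul+add: %f\n", r0[0], r1[0]);
    return 0;
}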
