Skip to content

Commit fd1bafb

Browse files
authored
Fix AVX2 distance function and build with AVX2 if available on host machine (#524)
1 parent 555b2e9 commit fd1bafb

File tree

8 files changed

+413
-149
lines changed

8 files changed

+413
-149
lines changed

src/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,14 @@ if (NOT $ENV{CIBUILDWHEEL} EQUAL 1 AND NOT $ENV{CONDA_BUILD} EQUAL 1)
135135
endif()
136136
endif()
137137

138+
# AVX2 flag
# Runs the CheckAVX2Support() probe and, when COMPILER_SUPPORTS_AVX2 is true,
# adds COMPILER_AVX2_FLAG plus -mfma as compile options and defines the
# AVX2_ENABLED preprocessor symbol.
# NOTE(review): -mfma is appended unconditionally, including when
# COMPILER_AVX2_FLAG is the MSVC-style "/arch:AVX2", and FMA support itself is
# not probed by CheckAVX2Support() -- confirm this is intended.
139+
include(CheckAVX2Support)
140+
CheckAVX2Support()
141+
if (COMPILER_SUPPORTS_AVX2)
142+
add_compile_options(${COMPILER_AVX2_FLAG} -mfma)
143+
add_definitions(-DAVX2_ENABLED)
144+
endif()
145+
138146
# Default to Release build
139147
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
140148
message(STATUS "No build type selected, default to Release")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#
2+
# CheckAVX2Support.cmake
3+
#
4+
#
5+
# The MIT License
6+
#
7+
# Copyright (c) 2018-2021 TileDB, Inc.
8+
#
9+
# Permission is hereby granted, free of charge, to any person obtaining a copy
10+
# of this software and associated documentation files (the "Software"), to deal
11+
# in the Software without restriction, including without limitation the rights
12+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13+
# copies of the Software, and to permit persons to whom the Software is
14+
# furnished to do so, subject to the following conditions:
15+
#
16+
# The above copyright notice and this permission notice shall be included in
17+
# all copies or substantial portions of the Software.
18+
#
19+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25+
# THE SOFTWARE.
26+
#
27+
# This file defines a function to detect toolchain support for AVX2.
28+
#
29+
30+
include(CheckCXXSourceRuns)
31+
include(CMakePushCheckState)
32+
33+
#
34+
# Determines if AVX2 is available.
35+
#
36+
# This function sets two variables in the cache:
37+
# COMPILER_SUPPORTS_AVX2 - Set to true if the AVX2 test program compiles with COMPILER_AVX2_FLAG and runs successfully on the build host.
38+
# COMPILER_AVX2_FLAG - Set to the appropriate flag to enable AVX2.
39+
#
40+
function (CheckAVX2Support)
41+
# Opt-out guard: lets a caller disable the probe by pre-setting the variable.
# If defined to a false value other than "", return without checking for avx2 support
42+
if (DEFINED COMPILER_SUPPORTS_AVX2 AND
43+
NOT COMPILER_SUPPORTS_AVX2 STREQUAL "" AND
44+
NOT COMPILER_SUPPORTS_AVX2)
45+
message("AVX2 compiler support disabled by COMPILER_SUPPORTS_AVX2=${COMPILER_SUPPORTS_AVX2}")
46+
return()
47+
endif()
48+
49+
# Select the compiler-specific spelling of the AVX2 flag and store it in the
# cache as COMPILER_AVX2_FLAG.
if (MSVC)
50+
set(COMPILER_AVX2_FLAG "/arch:AVX2" CACHE STRING "Compiler flag for AVX2 support.")
51+
else()
52+
set(COMPILER_AVX2_FLAG "-mavx2" CACHE STRING "Compiler flag for AVX2 support.")
53+
endif()
54+
55+
# Probe: build and run a tiny program that calls AVX2 intrinsics, with the AVX2
# flag appended to CMAKE_REQUIRED_FLAGS. The push/pop pair restores the check
# state afterwards so the extra flag does not leak into later checks.
# check_cxx_source_runs() stores the outcome in COMPILER_SUPPORTS_AVX2.
cmake_push_check_state()
56+
set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${COMPILER_AVX2_FLAG}")
57+
check_cxx_source_runs("
58+
#include <immintrin.h>
59+
int main() {
60+
__m256i packed = _mm256_set_epi32(-1, -2, -3, -4, -5, -6, -7, -8);
61+
__m256i absolute_values = _mm256_abs_epi32(packed);
62+
return 0;
63+
}"
64+
COMPILER_SUPPORTS_AVX2
65+
)
66+
cmake_pop_check_state()
67+
# Report the probe result in the configure log.
if (COMPILER_SUPPORTS_AVX2)
68+
message(STATUS "AVX2 support detected.")
69+
else()
70+
message(STATUS "AVX2 support not detected.")
71+
endif()
72+
endfunction()

src/include/detail/scoring/inner_product_avx.h

Lines changed: 64 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -88,29 +88,38 @@ inline float avx2_inner_product(const V& a, const W& b) {
8888

8989
template <feature_vector V, feature_vector W>
9090
requires std::same_as<typename V::value_type, float> &&
91-
std::same_as<typename W::value_type, uint8_t>
91+
(std::same_as<typename W::value_type, uint8_t> ||
92+
std::same_as<typename W::value_type, int8_t>)
9293
inline float avx2_inner_product(const V& a, const W& b) {
9394
// @todo Align on 256 bit boundaries
9495
const size_t start = 0;
9596
const size_t size_a = size(a);
9697
const size_t stop = size_a - (size_a % 8);
9798

9899
const float* a_ptr = a.data();
99-
const uint8_t* b_ptr = b.data();
100+
// Can be uint8_t* or int8_t*
101+
const auto* b_ptr = b.data();
100102

101103
__m256 vec_sum = _mm256_setzero_ps();
102104

103105
for (size_t i = start; i < stop; i += 8) {
104106
// Load 8 floats
105-
__m256 a_floats = _mm256_loadu_ps(a_ptr + i + 0);
107+
__m256 a_floats = _mm256_loadu_ps(a_ptr + i);
106108

107-
// Load 8 bytes
109+
// Load 8 bytes (uint8_t or int8_t)
108110
__m128i vec_b = _mm_loadu_si64((__m64*)(b_ptr + i));
109111

110-
// Zero extend 8bit to 32bit ints
111-
__m256i b_ints = _mm256_cvtepu8_epi32(vec_b);
112-
113-
// Convert signed integers to floats
112+
// Conditionally convert based on the type of W::value_type
113+
__m256i b_ints;
114+
if constexpr (std::same_as<typename W::value_type, uint8_t>) {
115+
// Zero extend uint8_t to int32_t
116+
b_ints = _mm256_cvtepu8_epi32(vec_b);
117+
} else if constexpr (std::same_as<typename W::value_type, int8_t>) {
118+
// Sign extend int8_t to int32_t
119+
b_ints = _mm256_cvtepi8_epi32(vec_b);
120+
}
121+
122+
// Convert the 32-bit integers to floats
114123
__m256 b_floats = _mm256_cvtepi32_ps(b_ints);
115124

116125
// Multiply and accumulate
@@ -139,30 +148,39 @@ inline float avx2_inner_product(const V& a, const W& b) {
139148
}
140149

141150
template <feature_vector V, feature_vector W>
142-
requires std::same_as<typename V::value_type, uint8_t> &&
143-
std::same_as<typename W::value_type, float>
151+
requires(std::same_as<typename V::value_type, uint8_t> ||
152+
std::same_as<typename V::value_type, int8_t>) &&
153+
std::same_as<typename W::value_type, float>
144154
inline float avx2_inner_product(const V& a, const W& b) {
145155
// @todo Align on 256 bit boundaries
146156
const size_t start = 0;
147157
const size_t size_a = size(a);
148158
const size_t stop = size_a - (size_a % 8);
149159

150-
const uint8_t* a_ptr = a.data();
160+
// Can be uint8_t* or int8_t*
161+
const auto* a_ptr = a.data();
151162
const float* b_ptr = b.data();
152163

153164
__m256 vec_sum = _mm256_setzero_ps();
154165

155166
for (size_t i = start; i < stop; i += 8) {
156-
// Load 8 bytes == 64 bits -- zeros out top 8 bytes
167+
// Load 8 bytes (either uint8_t or int8_t)
157168
__m128i vec_a = _mm_loadu_si64((__m64*)(a_ptr + i));
158169

159170
// Load 8 floats
160171
__m256 b_floats = _mm256_loadu_ps(b_ptr + i + 0);
161172

162-
// Zero extend 8bit to 32bit ints
163-
__m256i a_ints = _mm256_cvtepu8_epi32(vec_a);
164-
165-
// Convert signed integers to floats
173+
// Extend 8 bit to 32 bit ints
174+
__m256i a_ints;
175+
if constexpr (std::same_as<typename V::value_type, uint8_t>) {
176+
// Zero extend uint8_t to int32_t
177+
a_ints = _mm256_cvtepu8_epi32(vec_a);
178+
} else if constexpr (std::same_as<typename V::value_type, int8_t>) {
179+
// Sign extend int8_t to int32_t
180+
a_ints = _mm256_cvtepi8_epi32(vec_a);
181+
}
182+
183+
// Convert the 32-bit integers to floats
166184
__m256 a_floats = _mm256_cvtepi32_ps(a_ints);
167185

168186
// Multiply and accumulate
@@ -191,29 +209,49 @@ inline float avx2_inner_product(const V& a, const W& b) {
191209
}
192210

193211
template <feature_vector V, feature_vector W>
194-
requires std::same_as<typename V::value_type, uint8_t> &&
195-
std::same_as<typename W::value_type, uint8_t>
212+
requires(std::same_as<typename V::value_type, uint8_t> ||
213+
std::same_as<typename V::value_type, int8_t>) &&
214+
(std::same_as<typename W::value_type, uint8_t> ||
215+
std::same_as<typename W::value_type, int8_t>)
196216
inline float avx2_inner_product(const V& a, const W& b) {
197217
// @todo Align on 256 bit boundaries
198218
const size_t start = 0;
199219
const size_t size_a = size(a);
200220
const size_t stop = size_a - (size_a % 8);
201221

202-
const uint8_t* a_ptr = a.data();
203-
const uint8_t* b_ptr = b.data();
222+
// Can be either uint8_t* or int8_t*
223+
const auto* a_ptr = a.data();
224+
// Can be either uint8_t* or int8_t*
225+
const auto* b_ptr = b.data();
204226

205227
__m256 vec_sum = _mm256_setzero_ps();
206228

207229
for (size_t i = start; i < stop; i += 8) {
208-
// Load 8 bytes == 64 bits -- zeros out top 8 bytes
230+
// Load 8 bytes (uint8_t or int8_t) from both vectors
209231
__m128i vec_a = _mm_loadu_si64((__m64*)(a_ptr + i));
210232
__m128i vec_b = _mm_loadu_si64((__m64*)(b_ptr + i));
211233

212-
// Zero extend 8bit to 32bit ints
213-
__m256i a_ints = _mm256_cvtepu8_epi32(vec_a);
214-
__m256i b_ints = _mm256_cvtepu8_epi32(vec_b);
215-
216-
// Convert signed integers to floats
234+
// Extend 8 bit to 32 bit ints
235+
__m256i a_ints;
236+
if constexpr (std::same_as<typename V::value_type, uint8_t>) {
237+
// Zero extend uint8_t to int32_t
238+
a_ints = _mm256_cvtepu8_epi32(vec_a);
239+
} else if constexpr (std::same_as<typename V::value_type, int8_t>) {
240+
// Sign extend int8_t to int32_t
241+
a_ints = _mm256_cvtepi8_epi32(vec_a);
242+
}
243+
244+
// Conditionally convert based on the type of W::value_type
245+
__m256i b_ints;
246+
if constexpr (std::same_as<typename W::value_type, uint8_t>) {
247+
// Zero extend uint8_t to int32_t
248+
b_ints = _mm256_cvtepu8_epi32(vec_b);
249+
} else if constexpr (std::same_as<typename W::value_type, int8_t>) {
250+
// Sign extend int8_t to int32_t
251+
b_ints = _mm256_cvtepi8_epi32(vec_b);
252+
}
253+
254+
// Convert the 32-bit integers to floats for both vectors
217255
__m256 a_floats = _mm256_cvtepi32_ps(a_ints);
218256
__m256 b_floats = _mm256_cvtepi32_ps(b_ints);
219257

src/include/detail/scoring/l2_distance.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,12 @@ inline float naive_sum_of_squares(const V& a, const W& b) {
7878
}
7979

8080
/**
81-
* Compute l2 distance between vector of float and vector of uint8_t
81+
* Compute l2 distance between vector of float and vector of uint8_t or int8_t
8282
*/
8383
template <feature_vector V, feature_vector W>
8484
requires std::same_as<typename V::value_type, float> &&
85-
std::same_as<typename W::value_type, uint8_t>
85+
(std::same_as<typename W::value_type, uint8_t> ||
86+
std::same_as<typename W::value_type, int8_t>)
8687
inline float naive_sum_of_squares(const V& a, const W& b) {
8788
size_t size_a = size(a);
8889
float sum = 0.0;
@@ -94,11 +95,12 @@ inline float naive_sum_of_squares(const V& a, const W& b) {
9495
}
9596

9697
/**
97-
* Compute l2 distance between vector of uint8_t and vector of float
98+
* Compute l2 distance between vector of uint8_t or int8_t and vector of float
9899
*/
99100
template <feature_vector V, feature_vector W>
100-
requires std::same_as<typename V::value_type, uint8_t> &&
101-
std::same_as<typename W::value_type, float>
101+
requires(std::same_as<typename V::value_type, uint8_t> ||
102+
std::same_as<typename V::value_type, int8_t>) &&
103+
std::same_as<typename W::value_type, float>
102104
inline float naive_sum_of_squares(const V& a, const W& b) {
103105
size_t size_a = size(a);
104106
float sum = 0.0;
@@ -110,11 +112,14 @@ inline float naive_sum_of_squares(const V& a, const W& b) {
110112
}
111113

112114
/**
113-
* Compute l2 distance between vector of uint8_t and vector of uint8_t
115+
* Compute l2 distance between vector of uint8_t or int8_t and vector of uint8_t
116+
* or int8_t
114117
*/
115118
template <feature_vector V, feature_vector W>
116-
requires std::same_as<typename V::value_type, uint8_t> &&
117-
std::same_as<typename W::value_type, uint8_t>
119+
requires(std::same_as<typename V::value_type, uint8_t> ||
120+
std::same_as<typename V::value_type, int8_t>) &&
121+
(std::same_as<typename W::value_type, uint8_t> ||
122+
std::same_as<typename W::value_type, int8_t>)
118123
inline float naive_sum_of_squares(const V& a, const W& b) {
119124
size_t size_a = size(a);
120125
float sum = 0.0;

0 commit comments

Comments
 (0)