Skip to content

Commit f333089

Browse files
authored
Fix vec_caps to test for OS support too (on x64) (#126911) (#126925)
On x64, we are testing if we support vector capabilities (1 = "basic" = AVX2, 2 = "advanced" = AVX-512) in order to enable and choose a native implementation for some vector functions, using CPUID. However, under some circumstances, this is not sufficient: the OS on which we are running also needs to support AVX/AVX2 etc; basically, it needs to acknowledge it knows about the additional register and that it is able to handle them e.g. in context switches. To do that we need to a) test if the CPU has xsave feature and b) use the xgetbv to test if the OS set it (declaring it supports AVX/AVX2/etc). In most cases this is not needed, as all modern OSes do that, but for some virtualized situations (hypervisors, emulators, etc.) all the component along the chain must support it, and in some cases this is not a given. This PR introduces a change to the x64 version of vec_caps to check for OS support too, and a warning on the Java side in case the CPU supports vector capabilities but those are not enabled at OS level. Tested by passing noxsave to my linux box kernel boot options, and ensuring that the avx flags "disappear" from /proc/cpuinfo, and we fall back to the "no native vector" case. Fixes #126809
1 parent 3d598a9 commit f333089

File tree

6 files changed

+52
-7
lines changed

6 files changed

+52
-7
lines changed

docs/changelog/126911.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 126911
2+
summary: Fix `vec_caps` to test for OS support too (on x64)
3+
area: Vector Search
4+
type: bug
5+
issues:
6+
- 126809

libs/native/libraries/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ configurations {
1919
}
2020

2121
var zstdVersion = "1.5.5"
22-
var vecVersion = "1.0.10"
22+
var vecVersion = "1.0.11"
2323

2424
repositories {
2525
exclusiveContent {

libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
4141
try {
4242
int caps = (int) vecCaps$mh.invokeExact();
4343
logger.info("vec_caps=" + caps);
44-
if (caps != 0) {
44+
if (caps > 0) {
4545
if (caps == 2) {
4646
dot7u$mh = downcallHandle(
4747
"dot7u_2",
@@ -67,6 +67,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
6767
}
6868
INSTANCE = new JdkVectorSimilarityFunctions();
6969
} else {
70+
if (caps < 0) {
71+
logger.warn("""
72+
Your CPU supports vector capabilities, but they are disabled at OS level. For optimal performance, \
73+
enable them in your OS/Hypervisor/VM/container""");
74+
}
7075
dot7u$mh = null;
7176
sqr7u$mh = null;
7277
INSTANCE = null;

libs/simdvec/native/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ apply plugin: 'cpp'
1111

1212
var os = org.gradle.internal.os.OperatingSystem.current()
1313

14-
// To update this library run publish_vec_binaries.sh ( or ./gradlew vecSharedLibrary )
14+
// To update this library run publish_vec_binaries.sh ( or ./gradlew buildSharedLibrary )
1515
// Or
1616
// For local development, build the docker image with:
1717
// docker build --platform linux/arm64 --progress=plain --file=Dockerfile.aarch64 . (for aarch64)

libs/simdvec/native/publish_vec_binaries.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
2020
exit 1;
2121
fi
2222

23-
VERSION="1.0.10"
23+
VERSION="1.0.11"
2424
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
2525
TEMP=$(mktemp -d)
2626

libs/simdvec/native/src/vec/c/amd64/vec.c

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,23 @@ static inline void cpuid(int output[4], int functionNumber) {
4646
#endif
4747
}
4848

49+
// Multi-platform XGETBV "intrinsic"
50+
static inline int64_t xgetbv(int ctr) {
51+
#if defined(__GNUC__) || defined(__clang__)
52+
// use inline assembly, Gnu/AT&T syntax
53+
uint32_t a, d;
54+
__asm("xgetbv" : "=a"(a),"=d"(d) : "c"(ctr) : );
55+
return a | (((uint64_t) d) << 32);
56+
57+
#elif (defined (_MSC_FULL_VER) && _MSC_FULL_VER >= 160040000) || (defined (__INTEL_COMPILER) && __INTEL_COMPILER >= 1200)
58+
// Microsoft or Intel compiler supporting _xgetbv intrinsic
59+
return _xgetbv(ctr);
60+
61+
#else
62+
#error Unsupported compiler
63+
#endif
64+
}
65+
4966
// Utility function to horizontally add 8 32-bit integers
5067
static inline int hsum_i32_8(const __m256i a) {
5168
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
@@ -57,11 +74,20 @@ static inline int hsum_i32_8(const __m256i a) {
5774

5875
EXPORT int vec_caps() {
5976
int cpuInfo[4] = {-1};
60-
// Calling __cpuid with 0x0 as the function_id argument
77+
// Calling CPUID function 0x0 as the function_id argument
6178
// gets the number of the highest valid function ID.
6279
cpuid(cpuInfo, 0);
6380
int functionIds = cpuInfo[0];
81+
if (functionIds == 0) {
82+
// No CPUID functions
83+
return 0;
84+
}
85+
// call CPUID function 0x1 for feature flags
86+
cpuid(cpuInfo, 1);
87+
int hasOsXsave = (cpuInfo[2] & (1 << 27)) != 0;
88+
int avxEnabledInOS = hasOsXsave && ((xgetbv(0) & 6) == 6);
6489
if (functionIds >= 7) {
90+
// call CPUID function 0x7 for AVX2/512 flags
6591
cpuid(cpuInfo, 7);
6692
int ebx = cpuInfo[1];
6793
int ecx = cpuInfo[2];
@@ -72,10 +98,18 @@ EXPORT int vec_caps() {
7298
// int avx512_vnni = (ecx & 0x00000800) != 0;
7399
// if (avx512 && avx512_vnni) {
74100
if (avx512) {
75-
return 2;
101+
if (avxEnabledInOS) {
102+
return 2;
103+
} else {
104+
return -2;
105+
}
76106
}
77107
if (avx2) {
78-
return 1;
108+
if (avxEnabledInOS) {
109+
return 1;
110+
} else {
111+
return -1;
112+
}
79113
}
80114
}
81115
return 0;

0 commit comments

Comments
 (0)