Skip to content

Commit 4eace8a

Browse files
GPU: Restrict GPU indexing to FLOAT element types (#139084)
Reject BYTE, BFLOAT16, and BIT element types in GPUPlugin by returning null, falling back to CPU-based vector indexing.
1 parent fcf3550 commit 4eace8a

File tree

3 files changed

+63
-9
lines changed

3 files changed

+63
-9
lines changed

docs/changelog/139084.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 139084
2+
summary: "GPU: Restrict GPU indexing to FLOAT element types"
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/xpack/gpu/GPUPluginInitializationWithGPUIT.java

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ public void testAutoModeSupportedVectorType() {
9797
settings,
9898
indexOptions,
9999
randomGPUSupportedSimilarity(indexOptions.getType()),
100-
// TODO add other type support
101100
DenseVectorFieldMapper.ElementType.FLOAT
102101
);
103102
assertNotNull(format);
@@ -118,12 +117,36 @@ public void testAutoModeUnsupportedVectorType() {
118117
settings,
119118
indexOptions,
120119
randomGPUSupportedSimilarity(indexOptions.getType()),
121-
// TODO add other type support
122120
DenseVectorFieldMapper.ElementType.FLOAT
123121
);
124122
assertNull(format);
125123
}
126124

125+
public void testAutoModeUnsupportedElementType() {
126+
gpuMode = GPUPlugin.GpuMode.AUTO;
127+
assumeTrue("GPU_FORMAT feature flag enabled", GPUPlugin.GPU_FORMAT.isEnabled());
128+
129+
GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
130+
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();
131+
132+
createIndex("index1");
133+
IndexSettings settings = getIndexSettings();
134+
final var indexOptions = DenseVectorFieldTypeTests.randomGpuSupportedIndexOptions();
135+
final var unsupportedElementType = randomFrom(
136+
DenseVectorFieldMapper.ElementType.BYTE,
137+
DenseVectorFieldMapper.ElementType.BFLOAT16,
138+
DenseVectorFieldMapper.ElementType.BIT
139+
);
140+
141+
var format = vectorsFormatProvider.getKnnVectorsFormat(
142+
settings,
143+
indexOptions,
144+
randomGPUSupportedSimilarity(indexOptions.getType()),
145+
unsupportedElementType
146+
);
147+
assertNull(format);
148+
}
149+
127150
public void testAutoModeLicenseNotSupported() {
128151
gpuMode = GPUPlugin.GpuMode.AUTO;
129152
isGpuIndexingFeatureAllowed = false;
@@ -140,7 +163,6 @@ public void testAutoModeLicenseNotSupported() {
140163
settings,
141164
indexOptions,
142165
randomGPUSupportedSimilarity(indexOptions.getType()),
143-
// TODO add other type support
144166
DenseVectorFieldMapper.ElementType.FLOAT
145167
);
146168
assertNull(format);
@@ -162,7 +184,6 @@ public void testTrueModeSupportedVectorType() {
162184
settings,
163185
indexOptions,
164186
randomGPUSupportedSimilarity(indexOptions.getType()),
165-
// TODO add other type support
166187
DenseVectorFieldMapper.ElementType.FLOAT
167188
);
168189
assertNotNull(format);
@@ -183,12 +204,36 @@ public void testTrueModeUnsupportedVectorType() {
183204
settings,
184205
indexOptions,
185206
randomGPUSupportedSimilarity(indexOptions.getType()),
186-
// TODO add other type support
187207
DenseVectorFieldMapper.ElementType.FLOAT
188208
);
189209
assertNull(format);
190210
}
191211

212+
public void testTrueModeUnsupportedElementType() {
213+
gpuMode = GPUPlugin.GpuMode.TRUE;
214+
assumeTrue("GPU_FORMAT feature flag enabled", GPUPlugin.GPU_FORMAT.isEnabled());
215+
216+
GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
217+
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();
218+
219+
createIndex("index1");
220+
IndexSettings settings = getIndexSettings();
221+
final var indexOptions = DenseVectorFieldTypeTests.randomGpuSupportedIndexOptions();
222+
final var unsupportedElementType = randomFrom(
223+
DenseVectorFieldMapper.ElementType.BYTE,
224+
DenseVectorFieldMapper.ElementType.BFLOAT16,
225+
DenseVectorFieldMapper.ElementType.BIT
226+
);
227+
228+
var format = vectorsFormatProvider.getKnnVectorsFormat(
229+
settings,
230+
indexOptions,
231+
randomGPUSupportedSimilarity(indexOptions.getType()),
232+
unsupportedElementType
233+
);
234+
assertNull(format);
235+
}
236+
192237
public void testTrueModeLicenseNotSupported() {
193238
gpuMode = GPUPlugin.GpuMode.TRUE;
194239
isGpuIndexingFeatureAllowed = false;
@@ -205,7 +250,6 @@ public void testTrueModeLicenseNotSupported() {
205250
settings,
206251
indexOptions,
207252
randomGPUSupportedSimilarity(indexOptions.getType()),
208-
// TODO add other type support
209253
DenseVectorFieldMapper.ElementType.FLOAT
210254
);
211255
assertNull(format);
@@ -227,7 +271,6 @@ public void testFalseModeNeverUsesGpu() {
227271
settings,
228272
indexOptions,
229273
randomGPUSupportedSimilarity(indexOptions.getType()),
230-
// TODO add other type support
231274
DenseVectorFieldMapper.ElementType.FLOAT
232275
);
233276
assertNull(format);

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUPlugin.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,21 @@ public VectorsFormatProvider getVectorsFormatProvider() {
117117
return (indexSettings, indexOptions, similarity, elementType) -> {
118118
if (GPU_FORMAT.isEnabled() && isGpuIndexingFeatureAllowed()) {
119119
if ((gpuMode == GpuMode.TRUE || (gpuMode == GpuMode.AUTO && GPUSupport.isSupported()))
120-
&& vectorIndexTypeSupported(indexOptions.getType())) {
120+
&& vectorIndexAndElementTypeSupported(indexOptions.getType(), elementType)) {
121121
return getVectorsFormat(indexOptions, similarity);
122122
}
123123
}
124124
return null;
125125
};
126126
}
127127

128-
private boolean vectorIndexTypeSupported(DenseVectorFieldMapper.VectorIndexType type) {
128+
private boolean vectorIndexAndElementTypeSupported(
129+
DenseVectorFieldMapper.VectorIndexType type,
130+
DenseVectorFieldMapper.ElementType elementType
131+
) {
132+
if (elementType != DenseVectorFieldMapper.ElementType.FLOAT) {
133+
return false;
134+
}
129135
return type == DenseVectorFieldMapper.VectorIndexType.HNSW || type == DenseVectorFieldMapper.VectorIndexType.INT8_HNSW;
130136
}
131137

0 commit comments

Comments
 (0)