77
88package org .elasticsearch .plugin .gpu ;
99
10- import com .nvidia .cuvs .CuVSResources ;
11- import com .nvidia .cuvs .CuVSResourcesInfo ;
1210import com .nvidia .cuvs .GPUInfo ;
1311import com .nvidia .cuvs .GPUInfoProvider ;
14- import com .nvidia .cuvs .spi .CuVSProvider ;
15- import com .nvidia .cuvs .spi .CuVSServiceProvider ;
1612
1713import org .elasticsearch .common .settings .Settings ;
1814import org .elasticsearch .index .IndexService ;
2319import org .elasticsearch .plugins .Plugin ;
2420import org .elasticsearch .test .ESIntegTestCase ;
2521import org .elasticsearch .xpack .gpu .GPUPlugin ;
26- import org .junit .After ;
2722
2823import java .util .Collection ;
2924import java .util .List ;
30- import java .util .function .Function ;
3125
3226import static org .hamcrest .Matchers .containsString ;
33- import static org .hamcrest .Matchers .equalTo ;
3427import static org .hamcrest .Matchers .startsWith ;
3528
36- public class GPUPluginInitializationIT extends ESIntegTestCase {
29+ public class GPUPluginInitializationWithGPUIT extends ESIntegTestCase {
3730
38- private static final Function < CuVSProvider , GPUInfoProvider > SUPPORTED_GPU_PROVIDER =
39- p -> new TestCuVSServiceProvider .TestGPUInfoProvider (
31+ static {
32+ TestCuVSServiceProvider . mockedGPUInfoProvider = SUPPORTEp -> new TestCuVSServiceProvider .TestGPUInfoProvider (
4033 List .of (
4134 new GPUInfo (
4235 0 ,
@@ -49,60 +42,13 @@ public class GPUPluginInitializationIT extends ESIntegTestCase {
4942 )
5043 )
5144 );
52-
53- private static final Function <CuVSProvider , GPUInfoProvider > NO_GPU_PROVIDER = p -> new TestCuVSServiceProvider .TestGPUInfoProvider (
54- List .of ()
55- );
45+ }
5646
5747 @ Override
5848 protected Collection <Class <? extends Plugin >> nodePlugins () {
5949 return List .of (GPUPlugin .class );
6050 }
6151
62- public static class TestCuVSServiceProvider extends CuVSServiceProvider {
63-
64- static final Function <CuVSProvider , GPUInfoProvider > BUILTIN_GPU_INFO_PROVIDER = CuVSProvider ::gpuInfoProvider ;
65- static Function <CuVSProvider , GPUInfoProvider > mockedGPUInfoProvider = BUILTIN_GPU_INFO_PROVIDER ;
66-
67- @ Override
68- public CuVSProvider get (CuVSProvider builtin ) {
69- return new CuVSProviderDelegate (builtin ) {
70- @ Override
71- public GPUInfoProvider gpuInfoProvider () {
72- return mockedGPUInfoProvider .apply (builtin );
73- }
74- };
75- }
76-
77- private static class TestGPUInfoProvider implements GPUInfoProvider {
78- private final List <GPUInfo > gpuList ;
79-
80- private TestGPUInfoProvider (List <GPUInfo > gpuList ) {
81- this .gpuList = gpuList ;
82- }
83-
84- @ Override
85- public List <GPUInfo > availableGPUs () {
86- return gpuList ;
87- }
88-
89- @ Override
90- public List <GPUInfo > compatibleGPUs () {
91- return gpuList ;
92- }
93-
94- @ Override
95- public CuVSResourcesInfo getCurrentInfo (CuVSResources cuVSResources ) {
96- return null ;
97- }
98- }
99- }
100-
101- @ After
102- public void disableMock () {
103- TestCuVSServiceProvider .mockedGPUInfoProvider = TestCuVSServiceProvider .BUILTIN_GPU_INFO_PROVIDER ;
104- }
105-
10652 public void testFFOff () {
10753 assumeFalse ("GPU_FORMAT feature flag disabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
10854
@@ -127,7 +73,6 @@ public void testFFOffIndexSettingNotSupported() {
12773
12874 public void testFFOffGPUFormatNull () {
12975 assumeFalse ("GPU_FORMAT feature flag disabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
130- TestCuVSServiceProvider .mockedGPUInfoProvider = SUPPORTED_GPU_PROVIDER ;
13176
13277 GPUPlugin gpuPlugin = internalCluster ().getInstance (GPUPlugin .class );
13378 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
@@ -146,7 +91,6 @@ public void testFFOffGPUFormatNull() {
14691
14792 public void testIndexSettingOnIndexTypeSupportedGPUSupported () {
14893 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
149- TestCuVSServiceProvider .mockedGPUInfoProvider = SUPPORTED_GPU_PROVIDER ;
15094
15195 GPUPlugin gpuPlugin = internalCluster ().getInstance (GPUPlugin .class );
15296 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
@@ -165,7 +109,6 @@ public void testIndexSettingOnIndexTypeSupportedGPUSupported() {
165109
166110 public void testIndexSettingOnIndexTypeNotSupportedThrows () {
167111 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
168- TestCuVSServiceProvider .mockedGPUInfoProvider = SUPPORTED_GPU_PROVIDER ;
169112
170113 GPUPlugin gpuPlugin = internalCluster ().getInstance (GPUPlugin .class );
171114 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
@@ -185,60 +128,8 @@ public void testIndexSettingOnIndexTypeNotSupportedThrows() {
185128 assertThat (ex .getMessage (), startsWith ("[index.vectors.indexing.use_gpu] doesn't support [index_options.type] of" ));
186129 }
187130
188- public void testIndexSettingOnGPUNotSupportedThrows () {
189- assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
190- TestCuVSServiceProvider .mockedGPUInfoProvider = NO_GPU_PROVIDER ;
191-
192- GPUPlugin gpuPlugin = internalCluster ().getInstance (GPUPlugin .class );
193- VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
194-
195- createIndex ("index1" , Settings .builder ().put (GPUPlugin .VECTORS_INDEXING_USE_GPU_SETTING .getKey (), GPUPlugin .GpuMode .TRUE ).build ());
196- IndexSettings settings = getIndexSettings ();
197- final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
198-
199- var ex = expectThrows (
200- IllegalArgumentException .class ,
201- () -> vectorsFormatProvider .getKnnVectorsFormat (
202- settings ,
203- indexOptions ,
204- DenseVectorFieldTypeTests .randomGPUSupportedSimilarity (indexOptions .getType ())
205- )
206- );
207- assertThat (
208- ex .getMessage (),
209- equalTo ("[index.vectors.indexing.use_gpu] was set to [true], but GPU resources are not accessible on the node." )
210- );
211- }
212-
213- public void testIndexSettingOnGPUSupportThrowsRethrows () {
214- assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
215- // Mocks a cuvs-java UnsupportedProvider
216- TestCuVSServiceProvider .mockedGPUInfoProvider = p -> { throw new UnsupportedOperationException ("cuvs-java UnsupportedProvider" ); };
217-
218- GPUPlugin gpuPlugin = internalCluster ().getInstance (GPUPlugin .class );
219- VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
220-
221- createIndex ("index1" , Settings .builder ().put (GPUPlugin .VECTORS_INDEXING_USE_GPU_SETTING .getKey (), GPUPlugin .GpuMode .TRUE ).build ());
222- IndexSettings settings = getIndexSettings ();
223- final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
224-
225- var ex = expectThrows (
226- IllegalArgumentException .class ,
227- () -> vectorsFormatProvider .getKnnVectorsFormat (
228- settings ,
229- indexOptions ,
230- DenseVectorFieldTypeTests .randomGPUSupportedSimilarity (indexOptions .getType ())
231- )
232- );
233- assertThat (
234- ex .getMessage (),
235- equalTo ("[index.vectors.indexing.use_gpu] was set to [true], but GPU resources are not accessible on the node." )
236- );
237- }
238-
239131 public void testIndexSettingAutoIndexTypeSupportedGPUSupported () {
240132 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
241- TestCuVSServiceProvider .mockedGPUInfoProvider = SUPPORTED_GPU_PROVIDER ;
242133
243134 GPUPlugin gpuPlugin = internalCluster ().getInstance (GPUPlugin .class );
244135 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
@@ -255,28 +146,8 @@ public void testIndexSettingAutoIndexTypeSupportedGPUSupported() {
255146 assertNotNull (format );
256147 }
257148
258- public void testIndexSettingAutoGPUNotSupported () {
259- assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
260- TestCuVSServiceProvider .mockedGPUInfoProvider = NO_GPU_PROVIDER ;
261-
262- GPUPlugin gpuPlugin = internalCluster ().getInstance (GPUPlugin .class );
263- VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
264-
265- createIndex ("index1" , Settings .builder ().put (GPUPlugin .VECTORS_INDEXING_USE_GPU_SETTING .getKey (), GPUPlugin .GpuMode .AUTO ).build ());
266- IndexSettings settings = getIndexSettings ();
267- final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
268-
269- var format = vectorsFormatProvider .getKnnVectorsFormat (
270- settings ,
271- indexOptions ,
272- DenseVectorFieldTypeTests .randomGPUSupportedSimilarity (indexOptions .getType ())
273- );
274- assertNull (format );
275- }
276-
277149 public void testIndexSettingAutoIndexTypeNotSupported () {
278150 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
279- TestCuVSServiceProvider .mockedGPUInfoProvider = SUPPORTED_GPU_PROVIDER ;
280151
281152 GPUPlugin gpuPlugin = internalCluster ().getInstance (GPUPlugin .class );
282153 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
@@ -295,7 +166,6 @@ public void testIndexSettingAutoIndexTypeNotSupported() {
295166
296167 public void testIndexSettingOff () {
297168 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
298- TestCuVSServiceProvider .mockedGPUInfoProvider = SUPPORTED_GPU_PROVIDER ;
299169
300170 GPUPlugin gpuPlugin = internalCluster ().getInstance (GPUPlugin .class );
301171 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
0 commit comments