2424import java .util .List ;
2525
2626import static org .elasticsearch .xpack .gpu .TestVectorsFormatUtils .randomGPUSupportedSimilarity ;
27- import static org .hamcrest .Matchers .containsString ;
28- import static org .hamcrest .Matchers .equalTo ;
29- import static org .hamcrest .Matchers .startsWith ;
3027
3128public class GPUPluginInitializationWithGPUIT extends ESIntegTestCase {
3229
@@ -47,11 +44,12 @@ public class GPUPluginInitializationWithGPUIT extends ESIntegTestCase {
4744 }
4845
4946 private static boolean isGpuIndexingFeatureAllowed = true ;
47+ private static GPUPlugin .GpuMode gpuMode = GPUPlugin .GpuMode .AUTO ;
5048
5149 public static class TestGPUPlugin extends GPUPlugin {
5250
5351 public TestGPUPlugin () {
54- super ();
52+ super (Settings . builder (). put ( "vectors.indexing.use_gpu" , gpuMode . name ()). build () );
5553 }
5654
5755 @ Override
@@ -63,13 +61,15 @@ protected boolean isGpuIndexingFeatureAllowed() {
6361 @ After
6462 public void reset () {
6563 isGpuIndexingFeatureAllowed = true ;
64+ gpuMode = GPUPlugin .GpuMode .AUTO ;
6665 }
6766
6867 @ Override
6968 protected Collection <Class <? extends Plugin >> nodePlugins () {
7069 return List .of (TestGPUPlugin .class );
7170 }
7271
72+ // Feature flag disabled tests
7373 public void testFFOff () {
7474 assumeFalse ("GPU_FORMAT feature flag disabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
7575
@@ -80,25 +80,15 @@ public void testFFOff() {
8080 assertNull (format );
8181 }
8282
83- public void testFFOffIndexSettingNotSupported () {
84- assumeFalse ("GPU_FORMAT feature flag disabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
85- IllegalArgumentException exception = expectThrows (
86- IllegalArgumentException .class ,
87- () -> createIndex (
88- "index1" ,
89- Settings .builder ().put (GPUPlugin .VECTORS_INDEXING_USE_GPU_SETTING .getKey (), GPUPlugin .GpuMode .TRUE ).build ()
90- )
91- );
92- assertThat (exception .getMessage (), containsString ("unknown setting [index.vectors.indexing.use_gpu]" ));
93- }
94-
95- public void testFFOffGPUFormatNull () {
96- assumeFalse ("GPU_FORMAT feature flag disabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
83+ // AUTO mode tests
84+ public void testAutoModeSupportedVectorType () {
85+ gpuMode = GPUPlugin .GpuMode .AUTO ;
86+ assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
9787
9888 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
9989 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
10090
101- createIndex ("index1" , Settings . EMPTY );
91+ createIndex ("index1" );
10292 IndexSettings settings = getIndexSettings ();
10393 final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
10494
@@ -107,72 +97,57 @@ public void testFFOffGPUFormatNull() {
10797 indexOptions ,
10898 randomGPUSupportedSimilarity (indexOptions .getType ())
10999 );
110- assertNull (format );
100+ assertNotNull (format );
111101 }
112102
113- public void testIndexSettingOnIndexAllSupported () {
103+ public void testAutoModeUnsupportedVectorType () {
104+ gpuMode = GPUPlugin .GpuMode .AUTO ;
114105 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
115106
116107 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
117108 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
118109
119- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . TRUE ). build () );
110+ createIndex ("index1" );
120111 IndexSettings settings = getIndexSettings ();
121- final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
112+ final var indexOptions = DenseVectorFieldTypeTests .randomFlatIndexOptions ();
122113
123114 var format = vectorsFormatProvider .getKnnVectorsFormat (
124115 settings ,
125116 indexOptions ,
126117 randomGPUSupportedSimilarity (indexOptions .getType ())
127118 );
128- assertNotNull (format );
129- }
130-
131- public void testIndexSettingOnIndexTypeNotSupportedThrows () {
132- assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
133-
134- GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
135- VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
136-
137- createIndex ("index1" , Settings .builder ().put (GPUPlugin .VECTORS_INDEXING_USE_GPU_SETTING .getKey (), GPUPlugin .GpuMode .TRUE ).build ());
138- IndexSettings settings = getIndexSettings ();
139- final var indexOptions = DenseVectorFieldTypeTests .randomFlatIndexOptions ();
140-
141- var ex = expectThrows (
142- IllegalArgumentException .class ,
143- () -> vectorsFormatProvider .getKnnVectorsFormat (settings , indexOptions , randomGPUSupportedSimilarity (indexOptions .getType ()))
144- );
145- assertThat (ex .getMessage (), startsWith ("[index.vectors.indexing.use_gpu] doesn't support [index_options.type] of" ));
119+ assertNull (format );
146120 }
147121
148- public void testIndexSettingOnIndexLicenseNotSupportedThrows () {
149- assumeTrue ( "GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT . isEnabled ()) ;
122+ public void testAutoModeLicenseNotSupported () {
123+ gpuMode = GPUPlugin .GpuMode . AUTO ;
150124 isGpuIndexingFeatureAllowed = false ;
125+ assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
151126
152127 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
153128 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
154129
155- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . TRUE ). build () );
130+ createIndex ("index1" );
156131 IndexSettings settings = getIndexSettings ();
157132 final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
158133
159- var ex = expectThrows (
160- IllegalArgumentException .class ,
161- () -> vectorsFormatProvider .getKnnVectorsFormat (settings , indexOptions , randomGPUSupportedSimilarity (indexOptions .getType ()))
162- );
163- assertThat (
164- ex .getMessage (),
165- equalTo ("[index.vectors.indexing.use_gpu] was set to [true], but GPU indexing is a [ENTERPRISE] level feature" )
134+ var format = vectorsFormatProvider .getKnnVectorsFormat (
135+ settings ,
136+ indexOptions ,
137+ randomGPUSupportedSimilarity (indexOptions .getType ())
166138 );
139+ assertNull (format );
167140 }
168141
169- public void testIndexSettingAutoAllSupported () {
142+ // TRUE mode tests
143+ public void testTrueModeSupportedVectorType () {
144+ gpuMode = GPUPlugin .GpuMode .TRUE ;
170145 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
171146
172147 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
173148 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
174149
175- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . AUTO ). build () );
150+ createIndex ("index1" );
176151 IndexSettings settings = getIndexSettings ();
177152 final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
178153
@@ -184,16 +159,16 @@ public void testIndexSettingAutoAllSupported() {
184159 assertNotNull (format );
185160 }
186161
187- public void testIndexSettingAutoLicenseNotSupported () {
162+ public void testTrueModeUnsupportedVectorType () {
163+ gpuMode = GPUPlugin .GpuMode .TRUE ;
188164 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
189- isGpuIndexingFeatureAllowed = false ;
190165
191166 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
192167 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
193168
194- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . AUTO ). build () );
169+ createIndex ("index1" );
195170 IndexSettings settings = getIndexSettings ();
196- final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
171+ final var indexOptions = DenseVectorFieldTypeTests .randomFlatIndexOptions ();
197172
198173 var format = vectorsFormatProvider .getKnnVectorsFormat (
199174 settings ,
@@ -203,15 +178,17 @@ public void testIndexSettingAutoLicenseNotSupported() {
203178 assertNull (format );
204179 }
205180
206- public void testIndexSettingAutoIndexTypeNotSupported () {
181+ public void testTrueModeLicenseNotSupported () {
182+ gpuMode = GPUPlugin .GpuMode .TRUE ;
183+ isGpuIndexingFeatureAllowed = false ;
207184 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
208185
209186 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
210187 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
211188
212- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . AUTO ). build () );
189+ createIndex ("index1" );
213190 IndexSettings settings = getIndexSettings ();
214- final var indexOptions = DenseVectorFieldTypeTests .randomFlatIndexOptions ();
191+ final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
215192
216193 var format = vectorsFormatProvider .getKnnVectorsFormat (
217194 settings ,
@@ -221,13 +198,15 @@ public void testIndexSettingAutoIndexTypeNotSupported() {
221198 assertNull (format );
222199 }
223200
224- public void testIndexSettingOff () {
201+ // FALSE mode tests
202+ public void testFalseModeNeverUsesGpu () {
203+ gpuMode = GPUPlugin .GpuMode .FALSE ;
225204 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
226205
227206 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
228207 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
229208
230- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . FALSE ). build () );
209+ createIndex ("index1" );
231210 IndexSettings settings = getIndexSettings ();
232211 final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
233212
@@ -252,4 +231,5 @@ private IndexSettings getIndexSettings() {
252231 assertNotNull (settings );
253232 return settings ;
254233 }
234+
255235}
0 commit comments