1616import org .elasticsearch .index .mapper .vectors .DenseVectorFieldTypeTests ;
1717import org .elasticsearch .index .mapper .vectors .VectorsFormatProvider ;
1818import org .elasticsearch .indices .IndicesService ;
19+ import org .elasticsearch .injection .guice .Inject ;
1920import org .elasticsearch .plugins .Plugin ;
2021import org .elasticsearch .test .ESIntegTestCase ;
2122import org .junit .After ;
2425import java .util .List ;
2526
2627import static org .elasticsearch .xpack .gpu .TestVectorsFormatUtils .randomGPUSupportedSimilarity ;
27- import static org .hamcrest .Matchers .containsString ;
28- import static org .hamcrest .Matchers .equalTo ;
29- import static org .hamcrest .Matchers .startsWith ;
3028
3129public class GPUPluginInitializationWithGPUIT extends ESIntegTestCase {
3230
@@ -47,11 +45,13 @@ public class GPUPluginInitializationWithGPUIT extends ESIntegTestCase {
4745 }
4846
4947 private static boolean isGpuIndexingFeatureAllowed = true ;
48+ private static GPUPlugin .GpuMode gpuMode = GPUPlugin .GpuMode .AUTO ;
5049
5150 public static class TestGPUPlugin extends GPUPlugin {
5251
53- public TestGPUPlugin () {
54- super ();
52+ @ Inject
53+ public TestGPUPlugin (Settings settings ) {
54+ super (Settings .builder ().put (settings ).put ("vectors.indexing.use_gpu" , gpuMode .name ()).build ());
5555 }
5656
5757 @ Override
@@ -63,13 +63,15 @@ protected boolean isGpuIndexingFeatureAllowed() {
6363 @ After
6464 public void reset () {
6565 isGpuIndexingFeatureAllowed = true ;
66+ gpuMode = GPUPlugin .GpuMode .AUTO ;
6667 }
6768
6869 @ Override
6970 protected Collection <Class <? extends Plugin >> nodePlugins () {
7071 return List .of (TestGPUPlugin .class );
7172 }
7273
74+ // Feature flag disabled tests
7375 public void testFFOff () {
7476 assumeFalse ("GPU_FORMAT feature flag disabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
7577
@@ -80,25 +82,15 @@ public void testFFOff() {
8082 assertNull (format );
8183 }
8284
83- public void testFFOffIndexSettingNotSupported () {
84- assumeFalse ("GPU_FORMAT feature flag disabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
85- IllegalArgumentException exception = expectThrows (
86- IllegalArgumentException .class ,
87- () -> createIndex (
88- "index1" ,
89- Settings .builder ().put (GPUPlugin .VECTORS_INDEXING_USE_GPU_SETTING .getKey (), GPUPlugin .GpuMode .TRUE ).build ()
90- )
91- );
92- assertThat (exception .getMessage (), containsString ("unknown setting [index.vectors.indexing.use_gpu]" ));
93- }
94-
95- public void testFFOffGPUFormatNull () {
96- assumeFalse ("GPU_FORMAT feature flag disabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
85+ // AUTO mode tests
86+ public void testAutoModeSupportedVectorType () {
87+ gpuMode = GPUPlugin .GpuMode .AUTO ;
88+ assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
9789
9890 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
9991 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
10092
101- createIndex ("index1" , Settings . EMPTY );
93+ createIndex ("index1" );
10294 IndexSettings settings = getIndexSettings ();
10395 final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
10496
@@ -107,72 +99,57 @@ public void testFFOffGPUFormatNull() {
10799 indexOptions ,
108100 randomGPUSupportedSimilarity (indexOptions .getType ())
109101 );
110- assertNull (format );
102+ assertNotNull (format );
111103 }
112104
113- public void testIndexSettingOnIndexAllSupported () {
105+ public void testAutoModeUnsupportedVectorType () {
106+ gpuMode = GPUPlugin .GpuMode .AUTO ;
114107 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
115108
116109 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
117110 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
118111
119- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . TRUE ). build () );
112+ createIndex ("index1" );
120113 IndexSettings settings = getIndexSettings ();
121- final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
114+ final var indexOptions = DenseVectorFieldTypeTests .randomFlatIndexOptions ();
122115
123116 var format = vectorsFormatProvider .getKnnVectorsFormat (
124117 settings ,
125118 indexOptions ,
126119 randomGPUSupportedSimilarity (indexOptions .getType ())
127120 );
128- assertNotNull (format );
129- }
130-
131- public void testIndexSettingOnIndexTypeNotSupportedThrows () {
132- assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
133-
134- GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
135- VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
136-
137- createIndex ("index1" , Settings .builder ().put (GPUPlugin .VECTORS_INDEXING_USE_GPU_SETTING .getKey (), GPUPlugin .GpuMode .TRUE ).build ());
138- IndexSettings settings = getIndexSettings ();
139- final var indexOptions = DenseVectorFieldTypeTests .randomFlatIndexOptions ();
140-
141- var ex = expectThrows (
142- IllegalArgumentException .class ,
143- () -> vectorsFormatProvider .getKnnVectorsFormat (settings , indexOptions , randomGPUSupportedSimilarity (indexOptions .getType ()))
144- );
145- assertThat (ex .getMessage (), startsWith ("[index.vectors.indexing.use_gpu] doesn't support [index_options.type] of" ));
121+ assertNull (format );
146122 }
147123
148- public void testIndexSettingOnIndexLicenseNotSupportedThrows () {
149- assumeTrue ( "GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT . isEnabled ()) ;
124+ public void testAutoModeLicenseNotSupported () {
125+ gpuMode = GPUPlugin .GpuMode . AUTO ;
150126 isGpuIndexingFeatureAllowed = false ;
127+ assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
151128
152129 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
153130 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
154131
155- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . TRUE ). build () );
132+ createIndex ("index1" );
156133 IndexSettings settings = getIndexSettings ();
157134 final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
158135
159- var ex = expectThrows (
160- IllegalArgumentException .class ,
161- () -> vectorsFormatProvider .getKnnVectorsFormat (settings , indexOptions , randomGPUSupportedSimilarity (indexOptions .getType ()))
162- );
163- assertThat (
164- ex .getMessage (),
165- equalTo ("[index.vectors.indexing.use_gpu] was set to [true], but GPU indexing is a [ENTERPRISE] level feature" )
136+ var format = vectorsFormatProvider .getKnnVectorsFormat (
137+ settings ,
138+ indexOptions ,
139+ randomGPUSupportedSimilarity (indexOptions .getType ())
166140 );
141+ assertNull (format );
167142 }
168143
169- public void testIndexSettingAutoAllSupported () {
144+ // TRUE mode tests
145+ public void testTrueModeSupportedVectorType () {
146+ gpuMode = GPUPlugin .GpuMode .TRUE ;
170147 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
171148
172149 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
173150 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
174151
175- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . AUTO ). build () );
152+ createIndex ("index1" );
176153 IndexSettings settings = getIndexSettings ();
177154 final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
178155
@@ -184,16 +161,16 @@ public void testIndexSettingAutoAllSupported() {
184161 assertNotNull (format );
185162 }
186163
187- public void testIndexSettingAutoLicenseNotSupported () {
164+ public void testTrueModeUnsupportedVectorType () {
165+ gpuMode = GPUPlugin .GpuMode .TRUE ;
188166 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
189- isGpuIndexingFeatureAllowed = false ;
190167
191168 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
192169 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
193170
194- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . AUTO ). build () );
171+ createIndex ("index1" );
195172 IndexSettings settings = getIndexSettings ();
196- final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
173+ final var indexOptions = DenseVectorFieldTypeTests .randomFlatIndexOptions ();
197174
198175 var format = vectorsFormatProvider .getKnnVectorsFormat (
199176 settings ,
@@ -203,15 +180,17 @@ public void testIndexSettingAutoLicenseNotSupported() {
203180 assertNull (format );
204181 }
205182
206- public void testIndexSettingAutoIndexTypeNotSupported () {
183+ public void testTrueModeLicenseNotSupported () {
184+ gpuMode = GPUPlugin .GpuMode .TRUE ;
185+ isGpuIndexingFeatureAllowed = false ;
207186 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
208187
209188 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
210189 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
211190
212- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . AUTO ). build () );
191+ createIndex ("index1" );
213192 IndexSettings settings = getIndexSettings ();
214- final var indexOptions = DenseVectorFieldTypeTests .randomFlatIndexOptions ();
193+ final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
215194
216195 var format = vectorsFormatProvider .getKnnVectorsFormat (
217196 settings ,
@@ -221,13 +200,15 @@ public void testIndexSettingAutoIndexTypeNotSupported() {
221200 assertNull (format );
222201 }
223202
224- public void testIndexSettingOff () {
203+ // FALSE mode tests
204+ public void testFalseModeNeverUsesGpu () {
205+ gpuMode = GPUPlugin .GpuMode .FALSE ;
225206 assumeTrue ("GPU_FORMAT feature flag enabled" , GPUPlugin .GPU_FORMAT .isEnabled ());
226207
227208 GPUPlugin gpuPlugin = internalCluster ().getInstance (TestGPUPlugin .class );
228209 VectorsFormatProvider vectorsFormatProvider = gpuPlugin .getVectorsFormatProvider ();
229210
230- createIndex ("index1" , Settings . builder (). put ( GPUPlugin . VECTORS_INDEXING_USE_GPU_SETTING . getKey (), GPUPlugin . GpuMode . FALSE ). build () );
211+ createIndex ("index1" );
231212 IndexSettings settings = getIndexSettings ();
232213 final var indexOptions = DenseVectorFieldTypeTests .randomGpuSupportedIndexOptions ();
233214
@@ -252,4 +233,5 @@ private IndexSettings getIndexSettings() {
252233 assertNotNull (settings );
253234 return settings ;
254235 }
236+
255237}
0 commit comments