11
11
import org .elasticsearch .action .support .PlainActionFuture ;
12
12
import org .elasticsearch .common .settings .Settings ;
13
13
import org .elasticsearch .core .TimeValue ;
14
+ import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
14
15
import org .elasticsearch .inference .InferenceService ;
15
16
import org .elasticsearch .inference .MinimalServiceSettings ;
16
17
import org .elasticsearch .inference .Model ;
43
44
import static org .elasticsearch .xpack .inference .external .http .Utils .getUrl ;
44
45
import static org .elasticsearch .xpack .inference .services .ServiceComponentsTests .createWithEmptySettings ;
45
46
import static org .hamcrest .CoreMatchers .is ;
47
+ import static org .hamcrest .Matchers .containsInAnyOrder ;
46
48
import static org .mockito .Mockito .mock ;
47
49
48
50
public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
@@ -94,7 +96,6 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
94
96
try (var service = createElasticInferenceService ()) {
95
97
ensureAuthorizationCallFinished (service );
96
98
assertThat (service .supportedStreamingTasks (), is (EnumSet .of (TaskType .CHAT_COMPLETION )));
97
-
98
99
assertThat (
99
100
service .defaultConfigIds (),
100
101
is (
@@ -191,13 +192,21 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
191
192
String responseJson = """
192
193
{
193
194
"models": [
195
+ {
196
+ "model_name": "elser-v2",
197
+ "task_types": ["embed/text/sparse"]
198
+ },
194
199
{
195
200
"model_name": "rainbow-sprinkles",
196
201
"task_types": ["chat"]
197
202
},
198
203
{
199
- "model_name": "elser-v2",
200
- "task_types": ["embed/text/sparse"]
204
+ "model_name": "multilingual-embed-v1",
205
+ "task_types": ["embed/text/dense"]
206
+ },
207
+ {
208
+ "model_name": "rerank-v1",
209
+ "task_types": ["rerank/text/text-similarity"]
201
210
}
202
211
]
203
212
}
@@ -211,27 +220,48 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
211
220
assertThat (service .supportedStreamingTasks (), is (EnumSet .of (TaskType .CHAT_COMPLETION )));
212
221
assertThat (
213
222
service .defaultConfigIds (),
214
- is (
215
- List .of (
216
- new InferenceService .DefaultConfigId (
217
- ".elser-v2-elastic" ,
218
- MinimalServiceSettings .sparseEmbedding (ElasticInferenceService .NAME ),
219
- service
223
+ containsInAnyOrder (
224
+ new InferenceService .DefaultConfigId (
225
+ ".elser-v2-elastic" ,
226
+ MinimalServiceSettings .sparseEmbedding (ElasticInferenceService .NAME ),
227
+ service
228
+ ),
229
+ new InferenceService .DefaultConfigId (
230
+ ".rainbow-sprinkles-elastic" ,
231
+ MinimalServiceSettings .chatCompletion (ElasticInferenceService .NAME ),
232
+ service
233
+ ),
234
+ new InferenceService .DefaultConfigId (
235
+ ".multilingual-embed-v1-elastic" ,
236
+ MinimalServiceSettings .textEmbedding (
237
+ ElasticInferenceService .NAME ,
238
+ ElasticInferenceService .DENSE_TEXT_EMBEDDINGS_DIMENSIONS ,
239
+ ElasticInferenceService .defaultDenseTextEmbeddingsSimilarity (),
240
+ DenseVectorFieldMapper .ElementType .FLOAT
220
241
),
221
- new InferenceService .DefaultConfigId (
222
- ".rainbow-sprinkles-elastic" ,
223
- MinimalServiceSettings .chatCompletion (ElasticInferenceService .NAME ),
224
- service
225
- )
242
+ service
243
+ ),
244
+ new InferenceService .DefaultConfigId (
245
+ ".rerank-v1-elastic" ,
246
+ MinimalServiceSettings .rerank (ElasticInferenceService .NAME ),
247
+ service
226
248
)
227
249
)
228
250
);
229
- assertThat (service .supportedTaskTypes (), is (EnumSet .of (TaskType .CHAT_COMPLETION , TaskType .SPARSE_EMBEDDING )));
251
+ assertThat (
252
+ service .supportedTaskTypes (),
253
+ is (EnumSet .of (TaskType .CHAT_COMPLETION , TaskType .SPARSE_EMBEDDING , TaskType .RERANK , TaskType .TEXT_EMBEDDING ))
254
+ );
230
255
231
256
PlainActionFuture <List <Model >> listener = new PlainActionFuture <>();
232
257
service .defaultConfigs (listener );
233
258
assertThat (listener .actionGet (TIMEOUT ).get (0 ).getConfigurations ().getInferenceEntityId (), is (".elser-v2-elastic" ));
234
- assertThat (listener .actionGet (TIMEOUT ).get (1 ).getConfigurations ().getInferenceEntityId (), is (".rainbow-sprinkles-elastic" ));
259
+ assertThat (
260
+ listener .actionGet (TIMEOUT ).get (1 ).getConfigurations ().getInferenceEntityId (),
261
+ is (".multilingual-embed-v1-elastic" )
262
+ );
263
+ assertThat (listener .actionGet (TIMEOUT ).get (2 ).getConfigurations ().getInferenceEntityId (), is (".rainbow-sprinkles-elastic" ));
264
+ assertThat (listener .actionGet (TIMEOUT ).get (3 ).getConfigurations ().getInferenceEntityId (), is (".rerank-v1-elastic" ));
235
265
236
266
var getModelListener = new PlainActionFuture <UnparsedModel >();
237
267
// persists the default endpoints
@@ -249,6 +279,14 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
249
279
{
250
280
"model_name": "elser-v2",
251
281
"task_types": ["embed/text/sparse"]
282
+ },
283
+ {
284
+ "model_name": "rerank-v1",
285
+ "task_types": ["rerank/text/text-similarity"]
286
+ },
287
+ {
288
+ "model_name": "multilingual-embed-v1",
289
+ "task_types": ["embed/text/dense"]
252
290
}
253
291
]
254
292
}
@@ -262,17 +300,33 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
262
300
assertThat (service .supportedStreamingTasks (), is (EnumSet .noneOf (TaskType .class )));
263
301
assertThat (
264
302
service .defaultConfigIds (),
265
- is (
266
- List .of (
267
- new InferenceService .DefaultConfigId (
268
- ".elser-v2-elastic" ,
269
- MinimalServiceSettings .sparseEmbedding (ElasticInferenceService .NAME ),
270
- service
271
- )
303
+ containsInAnyOrder (
304
+ new InferenceService .DefaultConfigId (
305
+ ".elser-v2-elastic" ,
306
+ MinimalServiceSettings .sparseEmbedding (ElasticInferenceService .NAME ),
307
+ service
308
+ ),
309
+ new InferenceService .DefaultConfigId (
310
+ ".multilingual-embed-v1-elastic" ,
311
+ MinimalServiceSettings .textEmbedding (
312
+ ElasticInferenceService .NAME ,
313
+ ElasticInferenceService .DENSE_TEXT_EMBEDDINGS_DIMENSIONS ,
314
+ ElasticInferenceService .defaultDenseTextEmbeddingsSimilarity (),
315
+ DenseVectorFieldMapper .ElementType .FLOAT
316
+ ),
317
+ service
318
+ ),
319
+ new InferenceService .DefaultConfigId (
320
+ ".rerank-v1-elastic" ,
321
+ MinimalServiceSettings .rerank (ElasticInferenceService .NAME ),
322
+ service
272
323
)
273
324
)
274
325
);
275
- assertThat (service .supportedTaskTypes (), is (EnumSet .of (TaskType .SPARSE_EMBEDDING )));
326
+ assertThat (
327
+ service .supportedTaskTypes (),
328
+ is (EnumSet .of (TaskType .TEXT_EMBEDDING , TaskType .SPARSE_EMBEDDING , TaskType .RERANK ))
329
+ );
276
330
277
331
var getModelListener = new PlainActionFuture <UnparsedModel >();
278
332
modelRegistry .getModel (".rainbow-sprinkles-elastic" , getModelListener );
0 commit comments