2727import org .elasticsearch .common .settings .Settings ;
2828import org .elasticsearch .common .xcontent .XContentHelper ;
2929import org .elasticsearch .core .Nullable ;
30+ import org .elasticsearch .inference .TaskType ;
3031import org .elasticsearch .logging .LogManager ;
3132import org .elasticsearch .logging .Logger ;
3233import org .elasticsearch .test .rest .ESRestTestCase ;
@@ -278,7 +279,7 @@ public static void main(String[] args) throws IOException {
278279 }
279280
280281 public static Set <TestDataset > availableDatasetsForEs (RestClient client , boolean supportsIndexModeLookup ) throws IOException {
281- boolean inferenceEnabled = clusterHasInferenceEndpoint (client );
282+ boolean inferenceEnabled = clusterHasSparseEmbeddingInferenceEndpoint (client );
282283
283284 Set <TestDataset > testDataSets = new HashSet <>();
284285
@@ -319,79 +320,90 @@ private static void loadDataSetIntoEs(RestClient client, boolean supportsIndexMo
319320 }
320321 }
321322
322- /**
323- * The semantic_text mapping type require an inference endpoint that needs to be setup before creating the index.
324- */
325- public static void createInferenceEndpoint (RestClient client ) throws IOException {
326- Request request = new Request ("PUT" , "_inference/sparse_embedding/test_sparse_inference" );
327- request .setJsonEntity ("""
323+ public static void createInferenceEndpoints (RestClient client ) throws IOException {
324+ if (clusterHasSparseEmbeddingInferenceEndpoint (client ) == false ) {
325+ createSparseEmbeddingInferenceEndpoint (client );
326+ }
327+
328+ if (clusterHasRerankInferenceEndpoint (client ) == false ) {
329+ createRerankInferenceEndpoint (client );
330+ }
331+
332+ if (clusterHasCompletionInferenceEndpoint (client ) == false ) {
333+ createCompletionInferenceEndpoint (client );
334+ }
335+ }
336+
337+ public static void deleteInferenceEndpoints (RestClient client ) throws IOException {
338+ deleteSparseEmbeddingInferenceEndpoint (client );
339+ deleteRerankInferenceEndpoint (client );
340+ deleteCompletionInferenceEndpoint (client );
341+ }
342+
343+ /** The semantic_text mapping type require an inference endpoint that needs to be setup before creating the index. */
344+ public static void createSparseEmbeddingInferenceEndpoint (RestClient client ) throws IOException {
345+ createInferenceEndpoint (client , TaskType .SPARSE_EMBEDDING , "test_sparse_inference" , """
328346 {
329347 "service": "test_service",
330- "service_settings": {
331- "model": "my_model",
332- "api_key": "abc64"
333- },
334- "task_settings": {
335- }
348+ "service_settings": { "model": "my_model", "api_key": "abc64" },
349+ "task_settings": { }
336350 }
337351 """ );
338- client .performRequest (request );
339352 }
340353
341- public static void deleteInferenceEndpoint (RestClient client ) throws IOException {
342- try {
343- client .performRequest (new Request ("DELETE" , "_inference/sparse_embedding/test_sparse_inference" ));
344- } catch (ResponseException e ) {
345- // 404 here means the endpoint was not created
346- if (e .getResponse ().getStatusLine ().getStatusCode () != 404 ) {
347- throw e ;
348- }
349- }
354+ public static void deleteSparseEmbeddingInferenceEndpoint (RestClient client ) throws IOException {
355+ deleteInferenceEndpoint (client , "test_sparse_inference" );
350356 }
351357
352- public static boolean clusterHasInferenceEndpoint (RestClient client ) throws IOException {
353- Request request = new Request ("GET" , "_inference/sparse_embedding/test_sparse_inference" );
354- try {
355- client .performRequest (request );
356- } catch (ResponseException e ) {
357- if (e .getResponse ().getStatusLine ().getStatusCode () == 404 ) {
358- return false ;
359- }
360- throw e ;
361- }
362- return true ;
358+ public static boolean clusterHasSparseEmbeddingInferenceEndpoint (RestClient client ) throws IOException {
359+ return clusterHasInferenceEndpoint (client , TaskType .SPARSE_EMBEDDING , "test_sparse_inference" );
363360 }
364361
365362 public static void createRerankInferenceEndpoint (RestClient client ) throws IOException {
366- Request request = new Request ("PUT" , "_inference/rerank/test_reranker" );
367- request .setJsonEntity ("""
363+ createInferenceEndpoint (client , TaskType .RERANK , "test_reranker" , """
368364 {
369365 "service": "test_reranking_service",
370- "service_settings": {
371- "model_id": "my_model",
372- "api_key": "abc64"
373- },
374- "task_settings": {
375- "use_text_length": true
376- }
366+ "service_settings": { "model_id": "my_model", "api_key": "abc64" },
367+ "task_settings": { "use_text_length": true }
377368 }
378369 """ );
379- client .performRequest (request );
380370 }
381371
382372 public static void deleteRerankInferenceEndpoint (RestClient client ) throws IOException {
383- try {
384- client .performRequest (new Request ("DELETE" , "_inference/rerank/test_reranker" ));
385- } catch (ResponseException e ) {
386- // 404 here means the endpoint was not created
387- if (e .getResponse ().getStatusLine ().getStatusCode () != 404 ) {
388- throw e ;
389- }
390- }
373+ deleteInferenceEndpoint (client , "test_reranker" );
391374 }
392375
393376 public static boolean clusterHasRerankInferenceEndpoint (RestClient client ) throws IOException {
394- Request request = new Request ("GET" , "_inference/rerank/test_reranker" );
377+ return clusterHasInferenceEndpoint (client , TaskType .RERANK , "test_reranker" );
378+ }
379+
380+ public static void createCompletionInferenceEndpoint (RestClient client ) throws IOException {
381+ createInferenceEndpoint (client , TaskType .COMPLETION , "test_completion" , """
382+ {
383+ "service": "completion_test_service",
384+ "service_settings": { "model": "my_model", "api_key": "abc64" },
385+ "task_settings": { "temperature": 3 }
386+ }
387+ """ );
388+ }
389+
390+ public static void deleteCompletionInferenceEndpoint (RestClient client ) throws IOException {
391+ deleteInferenceEndpoint (client , "test_completion" );
392+ }
393+
394+ public static boolean clusterHasCompletionInferenceEndpoint (RestClient client ) throws IOException {
395+ return clusterHasInferenceEndpoint (client , TaskType .COMPLETION , "test_completion" );
396+ }
397+
398+ private static void createInferenceEndpoint (RestClient client , TaskType taskType , String inferenceId , String modelSettings )
399+ throws IOException {
400+ Request request = new Request ("PUT" , "_inference/" + taskType .name () + "/" + inferenceId );
401+ request .setJsonEntity (modelSettings );
402+ client .performRequest (request );
403+ }
404+
405+ private static boolean clusterHasInferenceEndpoint (RestClient client , TaskType taskType , String inferenceId ) throws IOException {
406+ Request request = new Request ("GET" , "_inference/" + taskType .name () + "/" + inferenceId );
395407 try {
396408 client .performRequest (request );
397409 } catch (ResponseException e ) {
@@ -403,6 +415,17 @@ public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throw
403415 return true ;
404416 }
405417
418+ private static void deleteInferenceEndpoint (RestClient client , String inferenceId ) throws IOException {
419+ try {
420+ client .performRequest (new Request ("DELETE" , "_inference/" + inferenceId ));
421+ } catch (ResponseException e ) {
422+ // 404 here means the endpoint was not created
423+ if (e .getResponse ().getStatusLine ().getStatusCode () != 404 ) {
424+ throw e ;
425+ }
426+ }
427+ }
428+
406429 private static void loadEnrichPolicy (RestClient client , String policyName , String policyFileName , Logger logger ) throws IOException {
407430 URL policyMapping = CsvTestsDataLoader .class .getResource ("/" + policyFileName );
408431 if (policyMapping == null ) {
0 commit comments