2727import org .elasticsearch .common .settings .Settings ;
2828import org .elasticsearch .common .xcontent .XContentHelper ;
2929import org .elasticsearch .core .Nullable ;
30+ import org .elasticsearch .inference .TaskType ;
3031import org .elasticsearch .logging .LogManager ;
3132import org .elasticsearch .logging .Logger ;
3233import org .elasticsearch .test .rest .ESRestTestCase ;
@@ -310,7 +311,7 @@ public static Set<TestDataset> availableDatasetsForEs(
310311 boolean supportsIndexModeLookup ,
311312 boolean supportsSourceFieldMapping
312313 ) throws IOException {
313- boolean inferenceEnabled = clusterHasInferenceEndpoint (client );
314+ boolean inferenceEnabled = clusterHasSparseEmbeddingInferenceEndpoint (client );
314315
315316 Set <TestDataset > testDataSets = new HashSet <>();
316317
@@ -372,77 +373,93 @@ private static void loadDataSetIntoEs(
372373 }
373374 }
374375
376+ public static void createInferenceEndpoints (RestClient client ) throws IOException {
377+ if (clusterHasSparseEmbeddingInferenceEndpoint (client ) == false ) {
378+ createSparseEmbeddingInferenceEndpoint (client );
379+ }
380+
381+ if (clusterHasRerankInferenceEndpoint (client ) == false ) {
382+ createRerankInferenceEndpoint (client );
383+ }
384+
385+ if (clusterHasCompletionInferenceEndpoint (client ) == false ) {
386+ createCompletionInferenceEndpoint (client );
387+ }
388+ }
389+
390+ public static void deleteInferenceEndpoints (RestClient client ) throws IOException {
391+ deleteSparseEmbeddingInferenceEndpoint (client );
392+ deleteRerankInferenceEndpoint (client );
393+ deleteCompletionInferenceEndpoint (client );
394+ }
395+
396+
375397 /** The semantic_text mapping type require an inference endpoint that needs to be setup before creating the index. */
376- public static void createInferenceEndpoint (RestClient client ) throws IOException {
377- Request request = new Request ( "PUT" , "_inference/sparse_embedding/ test_sparse_inference");
378- request . setJsonEntity ( """
398+ public static void createSparseEmbeddingInferenceEndpoint (RestClient client ) throws IOException {
399+ createInferenceEndpoint ( client , TaskType . SPARSE_EMBEDDING , " test_sparse_inference",
400+ """
379401 {
380402 "service": "test_service",
381- "service_settings": {
382- "model": "my_model",
383- "api_key": "abc64"
384- },
385- "task_settings": {
386- }
403+ "service_settings": { "model": "my_model", "api_key": "abc64" },
404+ "task_settings": { }
387405 }
388406 """ );
389- client .performRequest (request );
390407 }
391408
392- public static void deleteInferenceEndpoint (RestClient client ) throws IOException {
393- try {
394- client .performRequest (new Request ("DELETE" , "_inference/test_sparse_inference" ));
395- } catch (ResponseException e ) {
396- // 404 here means the endpoint was not created
397- if (e .getResponse ().getStatusLine ().getStatusCode () != 404 ) {
398- throw e ;
399- }
400- }
409+ public static void deleteSparseEmbeddingInferenceEndpoint (RestClient client ) throws IOException {
410+ deleteInferenceEndpoint (client , "test_sparse_inference" );
401411 }
402412
403- public static boolean clusterHasInferenceEndpoint (RestClient client ) throws IOException {
404- Request request = new Request ("GET" , "_inference/sparse_embedding/test_sparse_inference" );
405- try {
406- client .performRequest (request );
407- } catch (ResponseException e ) {
408- if (e .getResponse ().getStatusLine ().getStatusCode () == 404 ) {
409- return false ;
410- }
411- throw e ;
412- }
413- return true ;
413+ public static boolean clusterHasSparseEmbeddingInferenceEndpoint (RestClient client ) throws IOException {
414+ return clusterHasInferenceEndpoint (client , TaskType .SPARSE_EMBEDDING , "test_sparse_inference" );
414415 }
415416
416417 public static void createRerankInferenceEndpoint (RestClient client ) throws IOException {
417- Request request = new Request ("PUT" , "_inference/rerank/test_reranker" );
418- request .setJsonEntity ("""
418+ createInferenceEndpoint (client , TaskType .RERANK , "test_reranker" , """
419419 {
420420 "service": "test_reranking_service",
421- "service_settings": {
422- "model_id": "my_model",
423- "api_key": "abc64"
424- },
425- "task_settings": {
426- "use_text_length": true
427- }
421+ "service_settings": { "model_id": "my_model", "api_key": "abc64" },
422+ "task_settings": { "use_text_length": true }
428423 }
429424 """ );
430- client .performRequest (request );
431425 }
432426
433427 public static void deleteRerankInferenceEndpoint (RestClient client ) throws IOException {
434- try {
435- client .performRequest (new Request ("DELETE" , "_inference/rerank/test_reranker" ));
436- } catch (ResponseException e ) {
437- // 404 here means the endpoint was not created
438- if (e .getResponse ().getStatusLine ().getStatusCode () != 404 ) {
439- throw e ;
440- }
441- }
428+ deleteInferenceEndpoint (client , "test_reranker" );
442429 }
443430
444431 public static boolean clusterHasRerankInferenceEndpoint (RestClient client ) throws IOException {
445- Request request = new Request ("GET" , "_inference/rerank/test_reranker" );
432+ return clusterHasInferenceEndpoint (client , TaskType .RERANK , "test_reranker" );
433+ }
434+
435+ public static void createCompletionInferenceEndpoint (RestClient client ) throws IOException {
436+ createInferenceEndpoint (client , TaskType .COMPLETION , "test_completion" , """
437+ {
438+ "service": "streaming_completion_test_service",
439+ "service_settings": { "model": "my_model", "api_key": "abc64" },
440+ "task_settings": { "temperature": 3 }
441+ }
442+ """ );
443+ }
444+
445+ public static void deleteCompletionInferenceEndpoint (RestClient client ) throws IOException {
446+ deleteInferenceEndpoint (client , "test_completion" );
447+ }
448+
449+ public static boolean clusterHasCompletionInferenceEndpoint (RestClient client ) throws IOException {
450+ return clusterHasInferenceEndpoint (client , TaskType .COMPLETION , "test_completion" );
451+ }
452+
453+
454+ private static void createInferenceEndpoint (RestClient client , TaskType taskType , String inferenceId , String modelSettings ) throws IOException {
455+ Request request = new Request ("PUT" , "_inference/" + taskType .name () + "/" + inferenceId );
456+ request .setJsonEntity (modelSettings );
457+ client .performRequest (request );
458+ }
459+
460+
461+ private static boolean clusterHasInferenceEndpoint (RestClient client , TaskType taskType , String inferenceId ) throws IOException {
462+ Request request = new Request ("GET" , "_inference/" + taskType .name () + "/" + inferenceId );
446463 try {
447464 client .performRequest (request );
448465 } catch (ResponseException e ) {
@@ -454,6 +471,17 @@ public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throw
454471 return true ;
455472 }
456473
474+ private static void deleteInferenceEndpoint (RestClient client , String inferenceId ) throws IOException {
475+ try {
476+ client .performRequest (new Request ("DELETE" , "_inference/" + inferenceId ));
477+ } catch (ResponseException e ) {
478+ // 404 here means the endpoint was not created
479+ if (e .getResponse ().getStatusLine ().getStatusCode () != 404 ) {
480+ throw e ;
481+ }
482+ }
483+ }
484+
457485 private static void loadEnrichPolicy (RestClient client , String policyName , String policyFileName , Logger logger ) throws IOException {
458486 URL policyMapping = getResource ("/" + policyFileName );
459487 String entity = readTextFile (policyMapping );
0 commit comments