2727import org .elasticsearch .common .settings .Settings ;
2828import org .elasticsearch .common .xcontent .XContentHelper ;
2929import org .elasticsearch .core .Nullable ;
30+ import org .elasticsearch .inference .TaskType ;
3031import org .elasticsearch .logging .LogManager ;
3132import org .elasticsearch .logging .Logger ;
3233import org .elasticsearch .test .rest .ESRestTestCase ;
@@ -320,7 +321,7 @@ public static Set<TestDataset> availableDatasetsForEs(
320321 boolean supportsIndexModeLookup ,
321322 boolean supportsSourceFieldMapping
322323 ) throws IOException {
323- boolean inferenceEnabled = clusterHasInferenceEndpoint (client );
324+ boolean inferenceEnabled = clusterHasSparseEmbeddingInferenceEndpoint (client );
324325
325326 Set <TestDataset > testDataSets = new HashSet <>();
326327
@@ -382,77 +383,90 @@ private static void loadDataSetIntoEs(
382383 }
383384 }
384385
386+ public static void createInferenceEndpoints (RestClient client ) throws IOException {
387+ if (clusterHasSparseEmbeddingInferenceEndpoint (client ) == false ) {
388+ createSparseEmbeddingInferenceEndpoint (client );
389+ }
390+
391+ if (clusterHasRerankInferenceEndpoint (client ) == false ) {
392+ createRerankInferenceEndpoint (client );
393+ }
394+
395+ if (clusterHasCompletionInferenceEndpoint (client ) == false ) {
396+ createCompletionInferenceEndpoint (client );
397+ }
398+ }
399+
400+ public static void deleteInferenceEndpoints (RestClient client ) throws IOException {
401+ deleteSparseEmbeddingInferenceEndpoint (client );
402+ deleteRerankInferenceEndpoint (client );
403+ deleteCompletionInferenceEndpoint (client );
404+ }
405+
385406 /** The semantic_text mapping type require an inference endpoint that needs to be setup before creating the index. */
386- public static void createInferenceEndpoint (RestClient client ) throws IOException {
387- Request request = new Request ("PUT" , "_inference/sparse_embedding/test_sparse_inference" );
388- request .setJsonEntity ("""
407+ public static void createSparseEmbeddingInferenceEndpoint (RestClient client ) throws IOException {
408+ createInferenceEndpoint (client , TaskType .SPARSE_EMBEDDING , "test_sparse_inference" , """
389409 {
390410 "service": "test_service",
391- "service_settings": {
392- "model": "my_model",
393- "api_key": "abc64"
394- },
395- "task_settings": {
396- }
411+ "service_settings": { "model": "my_model", "api_key": "abc64" },
412+ "task_settings": { }
397413 }
398414 """ );
399- client .performRequest (request );
400415 }
401416
402- public static void deleteInferenceEndpoint (RestClient client ) throws IOException {
403- try {
404- client .performRequest (new Request ("DELETE" , "_inference/test_sparse_inference" ));
405- } catch (ResponseException e ) {
406- // 404 here means the endpoint was not created
407- if (e .getResponse ().getStatusLine ().getStatusCode () != 404 ) {
408- throw e ;
409- }
410- }
417+ public static void deleteSparseEmbeddingInferenceEndpoint (RestClient client ) throws IOException {
418+ deleteInferenceEndpoint (client , "test_sparse_inference" );
411419 }
412420
413- public static boolean clusterHasInferenceEndpoint (RestClient client ) throws IOException {
414- Request request = new Request ("GET" , "_inference/sparse_embedding/test_sparse_inference" );
415- try {
416- client .performRequest (request );
417- } catch (ResponseException e ) {
418- if (e .getResponse ().getStatusLine ().getStatusCode () == 404 ) {
419- return false ;
420- }
421- throw e ;
422- }
423- return true ;
421+ public static boolean clusterHasSparseEmbeddingInferenceEndpoint (RestClient client ) throws IOException {
422+ return clusterHasInferenceEndpoint (client , TaskType .SPARSE_EMBEDDING , "test_sparse_inference" );
424423 }
425424
426425 public static void createRerankInferenceEndpoint (RestClient client ) throws IOException {
427- Request request = new Request ("PUT" , "_inference/rerank/test_reranker" );
428- request .setJsonEntity ("""
426+ createInferenceEndpoint (client , TaskType .RERANK , "test_reranker" , """
429427 {
430428 "service": "test_reranking_service",
431- "service_settings": {
432- "model_id": "my_model",
433- "api_key": "abc64"
434- },
435- "task_settings": {
436- "use_text_length": true
437- }
429+ "service_settings": { "model_id": "my_model", "api_key": "abc64" },
430+ "task_settings": { "use_text_length": true }
438431 }
439432 """ );
440- client .performRequest (request );
441433 }
442434
443435 public static void deleteRerankInferenceEndpoint (RestClient client ) throws IOException {
444- try {
445- client .performRequest (new Request ("DELETE" , "_inference/rerank/test_reranker" ));
446- } catch (ResponseException e ) {
447- // 404 here means the endpoint was not created
448- if (e .getResponse ().getStatusLine ().getStatusCode () != 404 ) {
449- throw e ;
450- }
451- }
436+ deleteInferenceEndpoint (client , "test_reranker" );
452437 }
453438
454439 public static boolean clusterHasRerankInferenceEndpoint (RestClient client ) throws IOException {
455- Request request = new Request ("GET" , "_inference/rerank/test_reranker" );
440+ return clusterHasInferenceEndpoint (client , TaskType .RERANK , "test_reranker" );
441+ }
442+
443+ public static void createCompletionInferenceEndpoint (RestClient client ) throws IOException {
444+ createInferenceEndpoint (client , TaskType .COMPLETION , "test_completion" , """
445+ {
446+ "service": "completion_test_service",
447+ "service_settings": { "model": "my_model", "api_key": "abc64" },
448+ "task_settings": { "temperature": 3 }
449+ }
450+ """ );
451+ }
452+
453+ public static void deleteCompletionInferenceEndpoint (RestClient client ) throws IOException {
454+ deleteInferenceEndpoint (client , "test_completion" );
455+ }
456+
457+ public static boolean clusterHasCompletionInferenceEndpoint (RestClient client ) throws IOException {
458+ return clusterHasInferenceEndpoint (client , TaskType .COMPLETION , "test_completion" );
459+ }
460+
461+ private static void createInferenceEndpoint (RestClient client , TaskType taskType , String inferenceId , String modelSettings )
462+ throws IOException {
463+ Request request = new Request ("PUT" , "_inference/" + taskType .name () + "/" + inferenceId );
464+ request .setJsonEntity (modelSettings );
465+ client .performRequest (request );
466+ }
467+
468+ private static boolean clusterHasInferenceEndpoint (RestClient client , TaskType taskType , String inferenceId ) throws IOException {
469+ Request request = new Request ("GET" , "_inference/" + taskType .name () + "/" + inferenceId );
456470 try {
457471 client .performRequest (request );
458472 } catch (ResponseException e ) {
@@ -464,6 +478,17 @@ public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throw
464478 return true ;
465479 }
466480
481+ private static void deleteInferenceEndpoint (RestClient client , String inferenceId ) throws IOException {
482+ try {
483+ client .performRequest (new Request ("DELETE" , "_inference/" + inferenceId ));
484+ } catch (ResponseException e ) {
485+ // 404 here means the endpoint was not created
486+ if (e .getResponse ().getStatusLine ().getStatusCode () != 404 ) {
487+ throw e ;
488+ }
489+ }
490+ }
491+
467492 private static void loadEnrichPolicy (RestClient client , String policyName , String policyFileName , Logger logger ) throws IOException {
468493 URL policyMapping = getResource ("/" + policyFileName );
469494 String entity = readTextFile (policyMapping );
0 commit comments