Skip to content

Commit 5f5ffdc

Browse files
committed
CSV tests inference refactoring
1 parent 28cdd88 commit 5f5ffdc

File tree

2 files changed

+84
-64
lines changed

2 files changed

+84
-64
lines changed

x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,8 @@
6767
import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled;
6868
import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues;
6969
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.availableDatasetsForEs;
70-
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasInferenceEndpoint;
71-
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasRerankInferenceEndpoint;
72-
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoint;
73-
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint;
74-
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoint;
75-
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint;
70+
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoints;
71+
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoints;
7672
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs;
7773
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
7874
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_COMMAND;
@@ -138,12 +134,8 @@ protected EsqlSpecTestCase(
138134

139135
@Before
140136
public void setup() throws IOException {
141-
if (supportsInferenceTestService() && clusterHasInferenceEndpoint(client()) == false) {
142-
createInferenceEndpoint(client());
143-
}
144-
145-
if (supportsInferenceTestService() && clusterHasRerankInferenceEndpoint(client()) == false) {
146-
createRerankInferenceEndpoint(client());
137+
if (supportsInferenceTestService()) {
138+
createInferenceEndpoints(adminClient());
147139
}
148140

149141
boolean supportsLookup = supportsIndexModeLookup();
@@ -164,8 +156,8 @@ public static void wipeTestData() throws IOException {
164156
}
165157
}
166158

167-
deleteInferenceEndpoint(client());
168-
deleteRerankInferenceEndpoint(client());
159+
deleteInferenceEndpoints(adminClient());
160+
169161
}
170162

171163
public boolean logResults() {

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.common.settings.Settings;
2828
import org.elasticsearch.common.xcontent.XContentHelper;
2929
import org.elasticsearch.core.Nullable;
30+
import org.elasticsearch.inference.TaskType;
3031
import org.elasticsearch.logging.LogManager;
3132
import org.elasticsearch.logging.Logger;
3233
import org.elasticsearch.test.rest.ESRestTestCase;
@@ -310,7 +311,7 @@ public static Set<TestDataset> availableDatasetsForEs(
310311
boolean supportsIndexModeLookup,
311312
boolean supportsSourceFieldMapping
312313
) throws IOException {
313-
boolean inferenceEnabled = clusterHasInferenceEndpoint(client);
314+
boolean inferenceEnabled = clusterHasSparseEmbeddingInferenceEndpoint(client);
314315

315316
Set<TestDataset> testDataSets = new HashSet<>();
316317

@@ -372,77 +373,93 @@ private static void loadDataSetIntoEs(
372373
}
373374
}
374375

376+
public static void createInferenceEndpoints(RestClient client) throws IOException {
377+
if (clusterHasSparseEmbeddingInferenceEndpoint(client) == false) {
378+
createSparseEmbeddingInferenceEndpoint(client);
379+
}
380+
381+
if (clusterHasRerankInferenceEndpoint(client) == false) {
382+
createRerankInferenceEndpoint(client);
383+
}
384+
385+
if (clusterHasCompletionInferenceEndpoint(client) == false) {
386+
createCompletionInferenceEndpoint(client);
387+
}
388+
}
389+
390+
public static void deleteInferenceEndpoints(RestClient client) throws IOException {
391+
deleteSparseEmbeddingInferenceEndpoint(client);
392+
deleteRerankInferenceEndpoint(client);
393+
deleteCompletionInferenceEndpoint(client);
394+
}
395+
396+
375397
/** The semantic_text mapping type require an inference endpoint that needs to be setup before creating the index. */
376-
public static void createInferenceEndpoint(RestClient client) throws IOException {
377-
Request request = new Request("PUT", "_inference/sparse_embedding/test_sparse_inference");
378-
request.setJsonEntity("""
398+
public static void createSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
399+
createInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference",
400+
"""
379401
{
380402
"service": "test_service",
381-
"service_settings": {
382-
"model": "my_model",
383-
"api_key": "abc64"
384-
},
385-
"task_settings": {
386-
}
403+
"service_settings": { "model": "my_model", "api_key": "abc64" },
404+
"task_settings": { }
387405
}
388406
""");
389-
client.performRequest(request);
390407
}
391408

392-
public static void deleteInferenceEndpoint(RestClient client) throws IOException {
393-
try {
394-
client.performRequest(new Request("DELETE", "_inference/test_sparse_inference"));
395-
} catch (ResponseException e) {
396-
// 404 here means the endpoint was not created
397-
if (e.getResponse().getStatusLine().getStatusCode() != 404) {
398-
throw e;
399-
}
400-
}
409+
public static void deleteSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
410+
deleteInferenceEndpoint(client, "test_sparse_inference");
401411
}
402412

403-
public static boolean clusterHasInferenceEndpoint(RestClient client) throws IOException {
404-
Request request = new Request("GET", "_inference/sparse_embedding/test_sparse_inference");
405-
try {
406-
client.performRequest(request);
407-
} catch (ResponseException e) {
408-
if (e.getResponse().getStatusLine().getStatusCode() == 404) {
409-
return false;
410-
}
411-
throw e;
412-
}
413-
return true;
413+
public static boolean clusterHasSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
414+
return clusterHasInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference");
414415
}
415416

416417
public static void createRerankInferenceEndpoint(RestClient client) throws IOException {
417-
Request request = new Request("PUT", "_inference/rerank/test_reranker");
418-
request.setJsonEntity("""
418+
createInferenceEndpoint(client, TaskType.RERANK, "test_reranker", """
419419
{
420420
"service": "test_reranking_service",
421-
"service_settings": {
422-
"model_id": "my_model",
423-
"api_key": "abc64"
424-
},
425-
"task_settings": {
426-
"use_text_length": true
427-
}
421+
"service_settings": { "model_id": "my_model", "api_key": "abc64" },
422+
"task_settings": { "use_text_length": true }
428423
}
429424
""");
430-
client.performRequest(request);
431425
}
432426

433427
public static void deleteRerankInferenceEndpoint(RestClient client) throws IOException {
434-
try {
435-
client.performRequest(new Request("DELETE", "_inference/rerank/test_reranker"));
436-
} catch (ResponseException e) {
437-
// 404 here means the endpoint was not created
438-
if (e.getResponse().getStatusLine().getStatusCode() != 404) {
439-
throw e;
440-
}
441-
}
428+
deleteInferenceEndpoint(client, "test_reranker");
442429
}
443430

444431
public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throws IOException {
445-
Request request = new Request("GET", "_inference/rerank/test_reranker");
432+
return clusterHasInferenceEndpoint(client, TaskType.RERANK, "test_reranker");
433+
}
434+
435+
public static void createCompletionInferenceEndpoint(RestClient client) throws IOException {
436+
createInferenceEndpoint(client, TaskType.COMPLETION, "test_completion", """
437+
{
438+
"service": "streaming_completion_test_service",
439+
"service_settings": { "model": "my_model", "api_key": "abc64" },
440+
"task_settings": { "temperature": 3 }
441+
}
442+
""");
443+
}
444+
445+
public static void deleteCompletionInferenceEndpoint(RestClient client) throws IOException {
446+
deleteInferenceEndpoint(client, "test_completion");
447+
}
448+
449+
public static boolean clusterHasCompletionInferenceEndpoint(RestClient client) throws IOException {
450+
return clusterHasInferenceEndpoint(client, TaskType.COMPLETION, "test_completion");
451+
}
452+
453+
454+
private static void createInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId, String modelSettings) throws IOException {
455+
Request request = new Request("PUT", "_inference/" + taskType.name() + "/" + inferenceId);
456+
request.setJsonEntity(modelSettings);
457+
client.performRequest(request);
458+
}
459+
460+
461+
private static boolean clusterHasInferenceEndpoint(RestClient client, TaskType taskType, String inferenceId) throws IOException {
462+
Request request = new Request("GET", "_inference/" + taskType.name() + "/" + inferenceId);
446463
try {
447464
client.performRequest(request);
448465
} catch (ResponseException e) {
@@ -454,6 +471,17 @@ public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throw
454471
return true;
455472
}
456473

474+
private static void deleteInferenceEndpoint(RestClient client, String inferenceId) throws IOException {
475+
try {
476+
client.performRequest(new Request("DELETE", "_inference/" + inferenceId));
477+
} catch (ResponseException e) {
478+
// 404 here means the endpoint was not created
479+
if (e.getResponse().getStatusLine().getStatusCode() != 404) {
480+
throw e;
481+
}
482+
}
483+
}
484+
457485
private static void loadEnrichPolicy(RestClient client, String policyName, String policyFileName, Logger logger) throws IOException {
458486
URL policyMapping = getResource("/" + policyFileName);
459487
String entity = readTextFile(policyMapping);

0 commit comments

Comments
 (0)