Skip to content

Commit 606a130

Browse files
committed
Refactored inference operator.
1 parent 3bb8ff5 commit 606a130

19 files changed

+703
-339
lines changed

muted-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ tests:
315315
- class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT
316316
method: testSearchWithRandomDisconnects
317317
issue: https://github.com/elastic/elasticsearch/issues/122707
318-
- class: org.elasticsearch.xpack.esql.inference.RerankOperatorTests
318+
- class: org.elasticsearch.xpack.esql.inference.rerank.RerankOperatorTests
319319
method: testSimpleCircuitBreaking
320320
issue: https://github.com/elastic/elasticsearch/issues/124337
321321
- class: org.elasticsearch.index.engine.ThreadPoolMergeSchedulerTests

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/CompletionOperator.java

Lines changed: 0 additions & 125 deletions
This file was deleted.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,55 +8,55 @@
88
package org.elasticsearch.xpack.esql.inference;
99

1010
import org.elasticsearch.action.ActionListener;
11-
import org.elasticsearch.common.util.concurrent.ThreadContext;
11+
import org.elasticsearch.compute.data.Page;
1212
import org.elasticsearch.compute.operator.AsyncOperator;
1313
import org.elasticsearch.compute.operator.DriverContext;
1414
import org.elasticsearch.inference.InferenceServiceResults;
15+
import org.elasticsearch.inference.TaskType;
1516
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
17+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOperation;
18+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
19+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
1620

17-
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
18-
19-
abstract public class InferenceOperator<Fetched, InferenceResult extends InferenceServiceResults> extends AsyncOperator<Fetched> {
21+
public abstract class InferenceOperator<InferenceResult extends InferenceServiceResults> extends AsyncOperator<Page> {
2022

2123
// TODO: make the maximum number of concurrent inference workers configurable through a setting.
2224
private static final int MAX_INFERENCE_WORKER = 10;
2325
private final InferenceRunner inferenceRunner;
2426
private final String inferenceId;
25-
private final Class<InferenceResult> inferenceResultClass;
26-
27-
public InferenceOperator(
28-
DriverContext driverContext,
29-
ThreadContext threadContext,
30-
InferenceRunner inferenceRunner,
31-
String inferenceId,
32-
Class<InferenceResult> inferenceResultClass
33-
) {
34-
super(driverContext, threadContext, MAX_INFERENCE_WORKER);
27+
28+
public InferenceOperator(DriverContext driverContext, InferenceRunner inferenceRunner, String inferenceId) {
29+
super(driverContext, inferenceRunner.threadContext(), MAX_INFERENCE_WORKER);
3530
this.inferenceRunner = inferenceRunner;
3631
this.inferenceId = inferenceId;
37-
this.inferenceResultClass = inferenceResultClass;
32+
}
3833

39-
assert inferenceRunner.getThreadContext() != null;
34+
protected String inferenceId() {
35+
return inferenceId;
4036
}
4137

42-
protected final void doInference(InferenceAction.Request inferenceRequest, ActionListener<InferenceResult> listener) {
43-
inferenceRunner.doInference(inferenceRequest, listener.map(this::checkedInferenceResults));
38+
@Override
39+
protected void releaseFetchedOnAnyThread(Page page) {
40+
releasePageOnAnyThread(page);
4441
}
4542

46-
protected String inferenceId() {
47-
return inferenceId;
43+
@Override
44+
public Page getOutput() {
45+
return fetchFromBuffer();
46+
}
47+
48+
@Override
49+
protected void performAsync(Page input, ActionListener<Page> listener) {
50+
new BulkInferenceOperation<>(bulkInferenceRequestIterator(input), bulkOutputBuilder(input)).execute(inferenceRunner, listener);
4851
}
4952

50-
private InferenceResult checkedInferenceResults(InferenceAction.Response inferenceResponse) {
51-
if (inferenceResultClass.isInstance(inferenceResponse.getResults())) {
52-
return inferenceResultClass.cast(inferenceResponse.getResults());
53-
}
54-
throw new IllegalStateException(
55-
format(
56-
"Inference result has wrong type. Got [{}] while expecting [{}]",
57-
inferenceResponse.getResults().getClass().getName(),
58-
inferenceResultClass.getName()
59-
)
60-
);
53+
protected InferenceAction.Request.Builder inferenceRequestBuilder() {
54+
return InferenceAction.Request.builder(inferenceId, taskType());
6155
}
56+
57+
protected abstract TaskType taskType();
58+
59+
protected abstract BulkInferenceRequestIterator bulkInferenceRequestIterator(Page input);
60+
61+
protected abstract BulkInferenceOutputBuilder<InferenceResult, Page> bulkOutputBuilder(Page input);
6262
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
11+
12+
import java.util.function.Supplier;
13+
14+
/**
 * A {@link Supplier} of {@link InferenceAction.Request.Builder} instances.
 * <p>
 * This is a named specialization of {@code Supplier<InferenceAction.Request.Builder>}; it declares no
 * additional methods.
 */
public interface InferenceRequestBuilderSupplier extends Supplier<InferenceAction.Request.Builder> {}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.common.CheckedSupplier;
11+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
12+
13+
/**
 * A {@link CheckedSupplier} of {@link InferenceAction.Request} whose {@code get()} may throw a checked
 * {@link Exception}.
 * <p>
 * This is a named specialization of {@code CheckedSupplier<InferenceAction.Request, Exception>}; it declares
 * no additional methods.
 */
public interface InferenceRequestSupplier extends CheckedSupplier<InferenceAction.Request, Exception> {}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,32 @@
1212
import org.elasticsearch.client.internal.Client;
1313
import org.elasticsearch.common.util.concurrent.ThreadContext;
1414
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.threadpool.ThreadPool;
1516
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
1617
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1718
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1819
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
1920

2021
import java.util.List;
2122
import java.util.Set;
23+
import java.util.concurrent.ExecutorService;
2224
import java.util.stream.Collectors;
2325

2426
public class InferenceRunner {
2527

2628
private final Client client;
2729

30+
private final ExecutorService executorService;
31+
private final ThreadContext threadContext;
32+
2833
public InferenceRunner(Client client) {
2934
this.client = client;
35+
// TODO: revisit the executor service instantiation and thread pool choice.
36+
this.executorService = client.threadPool().executor(ThreadPool.Names.SEARCH);
37+
this.threadContext = client.threadPool().getThreadContext();
3038
}
3139

32-
public ThreadContext getThreadContext() {
40+
public ThreadContext threadContext() {
3341
return client.threadPool().getThreadContext();
3442
}
3543

@@ -72,7 +80,20 @@ private static String planInferenceId(InferencePlan<?> plan) {
7280
return plan.inferenceId().fold(FoldContext.small()).toString();
7381
}
7482

75-
public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
76-
client.execute(InferenceAction.INSTANCE, request, listener);
83+
public void doInference(InferenceRequestSupplier request, ActionListener<InferenceAction.Response> listener) {
84+
try {
85+
if (request == null) {
86+
listener.onResponse(null);
87+
}
88+
executorService.submit(() -> {
89+
try {
90+
client.execute(InferenceAction.INSTANCE, request.get(), listener);
91+
} catch (Exception e) {
92+
listener.onFailure(e);
93+
}
94+
});
95+
} catch (Exception e) {
96+
listener.onFailure(e);
97+
}
7798
}
7899
}

0 commit comments

Comments
 (0)