Skip to content

Commit 5c19b29

Browse files
committed
Inference execution.
1 parent 89bdda4 commit 5c19b29

File tree

4 files changed

+107
-3
lines changed

4 files changed

+107
-3
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ public static class Request extends BaseInferenceActionRequest {
6060
public static final ParseField QUERY = new ParseField("query");
6161
public static final ParseField TIMEOUT = new ParseField("timeout");
6262

63+
public static Builder builder(String inferenceEntityId, TaskType taskType) {
64+
return new Builder().setInferenceEntityId(inferenceEntityId).setTaskType(taskType);
65+
}
66+
6367
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
6468
static {
6569
PARSER.declareStringArray(Request.Builder::setInput, INPUT);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.common.util.concurrent.ThreadContext;
1414
import org.elasticsearch.inference.TaskType;
1515
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
16+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1617
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1718
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
1819

@@ -72,4 +73,8 @@ public void resolveInferences(List<InferencePlan> plans, ActionListener<Inferenc
7273

7374
listener.onResponse(inferenceResolution);
7475
}
76+
77+
public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
78+
client.execute(InferenceAction.INSTANCE, request, listener);
79+
}
7580
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/RerankOperator.java

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,31 @@
77

88
package org.elasticsearch.xpack.esql.plan.physical.inference;
99

10+
import org.apache.lucene.util.BytesRef;
1011
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.compute.data.Block;
1114
import org.elasticsearch.compute.data.BlockFactory;
15+
import org.elasticsearch.compute.data.BlockUtils;
16+
import org.elasticsearch.compute.data.DoubleBlock;
17+
import org.elasticsearch.compute.data.ElementType;
1218
import org.elasticsearch.compute.data.Page;
1319
import org.elasticsearch.compute.operator.AsyncOperator;
1420
import org.elasticsearch.compute.operator.DriverContext;
1521
import org.elasticsearch.compute.operator.EvalOperator;
1622
import org.elasticsearch.compute.operator.Operator;
23+
import org.elasticsearch.inference.TaskType;
1724
import org.elasticsearch.logging.LogManager;
1825
import org.elasticsearch.logging.Logger;
26+
import org.elasticsearch.xcontent.XContentBuilder;
27+
import org.elasticsearch.xcontent.XContentFactory;
28+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
29+
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
1930
import org.elasticsearch.xpack.esql.inference.InferenceService;
2031

32+
import java.io.IOException;
2133
import java.util.HashMap;
34+
import java.util.List;
2235
import java.util.Map;
2336

2437
public class RerankOperator extends AsyncOperator<Page> {
@@ -101,7 +114,17 @@ public RerankOperator(
101114

102115
@Override
103116
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
104-
listener.onResponse(inputPage);
117+
try {
118+
inferenceService.doInference(
119+
buildInferenceRequest(inputPage),
120+
ActionListener.wrap(
121+
(inferenceResponse) -> listener.onResponse(buildOutput(inputPage, inferenceResponse)),
122+
listener::onFailure
123+
)
124+
);
125+
} catch (IOException e) {
126+
listener.onFailure(e);
127+
}
105128
}
106129

107130
@Override
@@ -118,4 +141,78 @@ protected void releaseFetchedOnAnyThread(Page page) {
118141
public Page getOutput() {
119142
return fetchFromBuffer();
120143
}
144+
145+
private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
146+
int blockCount = inputPage.getBlockCount();
147+
Block.Builder[] blocksBuilders = new Block.Builder[blockCount];
148+
149+
for (int b = 0; b < blockCount; b++) {
150+
if (b == scoreChannel) {
151+
blocksBuilders[b] = ElementType.DOUBLE.newBlockBuilder(inputPage.getPositionCount(), blockFactory);
152+
} else {
153+
blocksBuilders[b] = inputPage.getBlock(b).elementType().newBlockBuilder(inputPage.getPositionCount(), blockFactory);
154+
}
155+
}
156+
157+
if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) {
158+
for (var rankedDoc : rankedDocsResults.getRankedDocs()) {
159+
for (int b = 0; b < blockCount; b++) {
160+
if (b == scoreChannel) {
161+
if (blocksBuilders[b] instanceof DoubleBlock.Builder scoreBlockBuilder) {
162+
scoreBlockBuilder.beginPositionEntry().appendDouble(rankedDoc.relevanceScore()).endPositionEntry();
163+
}
164+
} else {
165+
blocksBuilders[b].copyFrom(inputPage.getBlock(b), rankedDoc.index(), rankedDoc.index() + 1);
166+
}
167+
}
168+
}
169+
170+
return new Page(Block.Builder.buildAll(blocksBuilders));
171+
}
172+
173+
throw new IllegalStateException(
174+
"Inference result has wrong type. Got ["
175+
+ inferenceResponse.getResults().getClass()
176+
+ "] while expecting ["
177+
+ RankedDocsResults.class
178+
+ "]"
179+
);
180+
}
181+
182+
private InferenceAction.Request buildInferenceRequest(Page inputPage) throws IOException {
183+
String[] inputs = new String[inputPage.getPositionCount()];
184+
Map<String, Block> inputBlocks = new HashMap<>();
185+
186+
for (var entry : rerankFieldsEvaluator.entrySet()) {
187+
inputBlocks.put(entry.getKey(), entry.getValue().eval(inputPage));
188+
}
189+
190+
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
191+
inputs[pos] = toYaml(inputBlocks, pos);
192+
}
193+
194+
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
195+
}
196+
197+
private String toYaml(Map<String, Block> inputBlocks, int position) throws IOException {
198+
try (XContentBuilder yamlBuilder = XContentFactory.yamlBuilder().startObject()) {
199+
for (var blockEntry : inputBlocks.entrySet()) {
200+
String fieldName = blockEntry.getKey();
201+
Block currentBlock = blockEntry.getValue();
202+
if (currentBlock.isNull(position)) {
203+
continue;
204+
}
205+
yamlBuilder.field(fieldName, toYaml(BlockUtils.toJavaObject(currentBlock, position)));
206+
}
207+
return Strings.toString(yamlBuilder.endObject());
208+
}
209+
}
210+
211+
private Object toYaml(Object value) {
212+
return switch (value) {
213+
case BytesRef b -> b.utf8ToString();
214+
case List<?> l -> l.stream().map(this::toYaml).toList();
215+
default -> value;
216+
};
217+
}
121218
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,6 @@ private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerCon
537537
}
538538
}
539539

540-
logger.warn("layout {}", source.layout);
541-
542540
return source.with(
543541
new RerankOperator.Factory(inferenceService, inferenceId, queryText, rerankFieldsEvaluatorSuppliers, scoreChannel),
544542
source.layout

0 commit comments

Comments
 (0)