Skip to content

Commit 0e918df

Browse files
committed
Better handling of the YAML encoding.
1 parent 78a7efb commit 0e918df

File tree

4 files changed

+189
-81
lines changed

4 files changed

+189
-81
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/rerank.csv-spec

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ FROM books METADATA _score
1414
;
1515

1616
book_no:keyword | title:text | author:text | _score:double
17-
5327 | War and Peace | Leo Tolstoy | 0.03703703731298447
17+
5327 | War and Peace | Leo Tolstoy | 0.03846153989434242
1818
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222222276031971
1919
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083333395421505
20-
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.014925372786819935
20+
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01515151560306549
2121
;
2222

2323

@@ -32,9 +32,9 @@ FROM books METADATA _score
3232
;
3333

3434
book_no:keyword | title:text | author:text | _score:double
35-
5327 | War and Peace | Leo Tolstoy | 0.020408162847161293
35+
5327 | War and Peace | Leo Tolstoy | 0.02083333395421505
3636
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.014285714365541935
37-
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01123595517128706
37+
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.011363636702299118
3838
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.009523809887468815
3939
;
4040

@@ -52,7 +52,7 @@ FROM books METADATA _score
5252
;
5353

5454
book_no:keyword | title:text | author:text | _score:double
55-
5327 | War and Peace | Leo Tolstoy | 0.03703703731298447
55+
5327 | War and Peace | Leo Tolstoy | 0.03846153989434242
5656
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222222276031971
5757
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083333395421505
5858
;
@@ -70,7 +70,7 @@ FROM books METADATA _score
7070
;
7171

7272
book_no:keyword | title:text | author:text | _score:double
73-
5327 | War and Peace | Leo Tolstoy | 0.03703703731298447
73+
5327 | War and Peace | Leo Tolstoy | 0.03846153989434242
7474
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222222276031971
7575
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083333395421505
7676
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 29 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99

1010
import org.apache.lucene.util.BytesRef;
1111
import org.elasticsearch.action.ActionListener;
12-
import org.elasticsearch.common.Strings;
1312
import org.elasticsearch.compute.data.Block;
1413
import org.elasticsearch.compute.data.BlockFactory;
15-
import org.elasticsearch.compute.data.BlockUtils;
14+
import org.elasticsearch.compute.data.BytesRefBlock;
1615
import org.elasticsearch.compute.data.DoubleBlock;
1716
import org.elasticsearch.compute.data.Page;
1817
import org.elasticsearch.compute.operator.AsyncOperator;
@@ -21,16 +20,14 @@
2120
import org.elasticsearch.compute.operator.Operator;
2221
import org.elasticsearch.core.Releasables;
2322
import org.elasticsearch.inference.TaskType;
24-
import org.elasticsearch.xcontent.XContentBuilder;
25-
import org.elasticsearch.xcontent.XContentFactory;
2623
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2724
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2825

29-
import java.io.IOException;
30-
import java.io.UncheckedIOException;
3126
import java.util.List;
3227
import java.util.Map;
3328

29+
import static org.elasticsearch.xpack.esql.inference.XContentRowEncoder.yamlRowEncoderFactory;
30+
3431
public class RerankOperator extends AsyncOperator<Page> {
3532

3633
// Move to a setting.
@@ -46,15 +43,15 @@ public record Factory(
4643

4744
@Override
4845
public String describe() {
49-
return "RerankOperator[inference_id="
46+
return "RerankOperator[inference_id=["
5047
+ inferenceId
51-
+ " query="
48+
+ "], query=["
5249
+ queryText
53-
+ " rerank_fields="
50+
+ "], rerank_fields="
5451
+ fieldsEvaluatorFactories.keySet()
55-
+ " score_channel="
52+
+ ", score_channel=["
5653
+ scoreChannel
57-
+ "]";
54+
+ "]]";
5855
}
5956

6057
@Override
@@ -64,56 +61,43 @@ public Operator get(DriverContext driverContext) {
6461
inferenceService,
6562
inferenceId,
6663
queryText,
67-
fieldNames(),
68-
fieldsEvaluators(driverContext),
64+
yamlRowEncoderFactory(fieldsEvaluatorFactories).get(driverContext),
6965
scoreChannel
7066
);
7167
}
72-
73-
private String[] fieldNames() {
74-
return fieldsEvaluatorFactories.keySet().toArray(String[]::new);
75-
}
76-
77-
private ExpressionEvaluator[] fieldsEvaluators(DriverContext context) {
78-
return fieldsEvaluatorFactories.values().stream().map(factory -> factory.get(context)).toArray(ExpressionEvaluator[]::new);
79-
}
8068
}
8169

8270
private final InferenceService inferenceService;
8371
private final BlockFactory blockFactory;
8472
private final String inferenceId;
8573
private final String queryText;
86-
private final String[] fieldNames;
87-
private final ExpressionEvaluator[] fieldsEvaluators;
74+
private final ExpressionEvaluator rowEncoder;
8875
private final int scoreChannel;
8976

9077
public RerankOperator(
9178
DriverContext driverContext,
9279
InferenceService inferenceService,
9380
String inferenceId,
9481
String queryText,
95-
String[] fieldNames,
96-
ExpressionEvaluator[] fieldsEvaluators,
82+
ExpressionEvaluator rowEncoder,
9783
int scoreChannel
9884
) {
9985
super(driverContext, inferenceService.getThreadContext(), MAX_INFERENCE_WORKER);
10086

10187
assert inferenceService.getThreadContext() != null;
102-
assert fieldNames.length == fieldsEvaluators.length;
10388

10489
this.blockFactory = driverContext.blockFactory();
10590
this.inferenceService = inferenceService;
10691
this.inferenceId = inferenceId;
10792
this.queryText = queryText;
108-
this.fieldNames = fieldNames;
109-
this.fieldsEvaluators = fieldsEvaluators;
93+
this.rowEncoder = rowEncoder;
11094
this.scoreChannel = scoreChannel;
11195
}
11296

11397
@Override
11498
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
11599
// Ensure input page blocks are released when the listener is called.
116-
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { inputPage.releaseBlocks(); });
100+
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); });
117101

118102
try {
119103
inferenceService.doInference(
@@ -130,7 +114,7 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
130114

131115
@Override
132116
protected void doClose() {
133-
Releasables.closeExpectNoException(this.fieldsEvaluators);
117+
Releasables.closeExpectNoException(rowEncoder);
134118
}
135119

136120
@Override
@@ -145,15 +129,15 @@ public Page getOutput() {
145129

146130
@Override
147131
public String toString() {
148-
return "RerankOperator[inference_id="
132+
return "RerankOperator[inference_id=["
149133
+ inferenceId
150-
+ " query="
134+
+ "], query=["
151135
+ queryText
152-
+ " rerank_fields="
153-
+ List.of(fieldNames)
154-
+ " score_channel="
136+
+ "], row_encoder=["
137+
+ rowEncoder
138+
+ "], score_channel=["
155139
+ scoreChannel
156-
+ "]";
140+
+ "]]";
157141
}
158142

159143
private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
@@ -212,46 +196,21 @@ private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResult
212196
}
213197

214198
private InferenceAction.Request buildInferenceRequest(Page inputPage) {
215-
Block[] inputBlocks = new Block[fieldsEvaluators.length];
216-
217-
try {
218-
for (int b = 0; b < inputBlocks.length; b++) {
219-
inputBlocks[b] = fieldsEvaluators[b].eval(inputPage);
220-
}
221-
199+
try (BytesRefBlock encodedRowBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) {
200+
assert (encodedRowBlock.getPositionCount() == inputPage.getPositionCount());
222201
String[] inputs = new String[inputPage.getPositionCount()];
202+
BytesRef buffer = new BytesRef();
203+
223204
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
224-
try (XContentBuilder yamlBuilder = XContentFactory.yamlBuilder().startObject()) {
225-
for (int i = 0; i < inputBlocks.length; i++) {
226-
String fieldName = fieldNames[i];
227-
Block currentBlock = inputBlocks[i];
228-
if (currentBlock.isNull(pos)) {
229-
continue;
230-
}
231-
yamlBuilder.field(fieldName, toYaml(BlockUtils.toJavaObject(currentBlock, pos)));
232-
}
233-
inputs[pos] = Strings.toString(yamlBuilder.endObject());
234-
} catch (IOException e) {
235-
throw new UncheckedIOException(e);
205+
if (encodedRowBlock.isNull(pos)) {
206+
inputs[pos] = "";
207+
} else {
208+
buffer = encodedRowBlock.getBytesRef(encodedRowBlock.getFirstValueIndex(pos), buffer);
209+
inputs[pos] = buffer.utf8ToString();
236210
}
237211
}
238212

239213
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
240-
} finally {
241-
Releasables.closeExpectNoException(inputBlocks);
242-
}
243-
}
244-
245-
private Object toYaml(Object value) {
246-
try {
247-
return switch (value) {
248-
case BytesRef b -> b.utf8ToString();
249-
case List<?> l -> l.stream().map(this::toYaml).toList();
250-
default -> value;
251-
};
252-
} catch (Error | Exception e) {
253-
// Swallow errors caused by invalid byteref.
254-
return "";
255214
}
256215
}
257216
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.common.io.stream.BytesRefStreamOutput;
12+
import org.elasticsearch.compute.data.Block;
13+
import org.elasticsearch.compute.data.BlockFactory;
14+
import org.elasticsearch.compute.data.BlockUtils;
15+
import org.elasticsearch.compute.data.BytesRefBlock;
16+
import org.elasticsearch.compute.data.Page;
17+
import org.elasticsearch.compute.operator.DriverContext;
18+
import org.elasticsearch.compute.operator.EvalOperator;
19+
import org.elasticsearch.core.Releasables;
20+
import org.elasticsearch.xcontent.XContentBuilder;
21+
import org.elasticsearch.xcontent.XContentFactory;
22+
import org.elasticsearch.xcontent.XContentType;
23+
24+
import java.io.IOException;
25+
import java.io.UncheckedIOException;
26+
import java.util.List;
27+
import java.util.Map;
28+
29+
class XContentRowEncoder implements EvalOperator.ExpressionEvaluator {
30+
private final XContentType xContentType;
31+
private final BlockFactory blockFactory;
32+
private final String[] fieldNames;
33+
private final EvalOperator.ExpressionEvaluator[] fieldsValueEvaluators;
34+
35+
public static Factory yamlRowEncoderFactory(Map<String, EvalOperator.ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
36+
return new Factory(XContentType.YAML, fieldsEvaluatorFactories);
37+
}
38+
39+
private XContentRowEncoder(
40+
XContentType xContentType,
41+
BlockFactory blockFactory,
42+
String[] fieldNames,
43+
EvalOperator.ExpressionEvaluator[] fieldsValueEvaluators
44+
) {
45+
assert fieldNames.length == fieldsValueEvaluators.length;
46+
this.xContentType = xContentType;
47+
this.blockFactory = blockFactory;
48+
this.fieldNames = fieldNames;
49+
this.fieldsValueEvaluators = fieldsValueEvaluators;
50+
}
51+
52+
@Override
53+
public void close() {
54+
Releasables.closeExpectNoException(fieldsValueEvaluators);
55+
}
56+
57+
@Override
58+
public Block eval(Page page) {
59+
Block[] fieldValueBlocks = new Block[fieldsValueEvaluators.length];
60+
try (
61+
BytesRefStreamOutput outputStream = new BytesRefStreamOutput();
62+
XContentBuilder xContentBuilder = XContentFactory.contentBuilder(xContentType, outputStream);
63+
BytesRefBlock.Builder outputBlockBuilder = blockFactory.newBytesRefBlockBuilder(page.getPositionCount());
64+
) {
65+
for (int b = 0; b < fieldValueBlocks.length; b++) {
66+
fieldValueBlocks[b] = fieldsValueEvaluators[b].eval(page);
67+
}
68+
69+
for (int pos = 0; pos < page.getPositionCount(); pos++) {
70+
xContentBuilder.startObject();
71+
for (int i = 0; i < fieldValueBlocks.length; i++) {
72+
String fieldName = fieldNames[i];
73+
Block currentBlock = fieldValueBlocks[i];
74+
if (currentBlock.isNull(pos)) {
75+
continue;
76+
}
77+
xContentBuilder.field(fieldName, toYamlValue(BlockUtils.toJavaObject(currentBlock, pos)));
78+
}
79+
xContentBuilder.endObject().flush();
80+
outputBlockBuilder.appendBytesRef(outputStream.get());
81+
outputStream.reset();
82+
}
83+
84+
return outputBlockBuilder.build();
85+
} catch (IOException e) {
86+
throw new UncheckedIOException(e);
87+
} finally {
88+
Releasables.closeExpectNoException(fieldValueBlocks);
89+
}
90+
}
91+
92+
@Override
93+
public String toString() {
94+
return "XContentRowEncoder[content_type=[" + xContentType.toString() + "], field_names=" + List.of(fieldNames) + "]";
95+
}
96+
97+
private Object toYamlValue(Object value) {
98+
try {
99+
return switch (value) {
100+
case BytesRef b -> b.utf8ToString();
101+
case List<?> l -> l.stream().map(this::toYamlValue).toList();
102+
default -> value;
103+
};
104+
} catch (Error | Exception e) {
105+
// Swallow errors caused by invalid byteref.
106+
return "";
107+
}
108+
}
109+
110+
public static final class Factory implements EvalOperator.ExpressionEvaluator.Factory {
111+
private final XContentType xContentType;
112+
private final Map<String, EvalOperator.ExpressionEvaluator.Factory> fieldsEvaluatorFactories;
113+
114+
private Factory(XContentType xContentType, Map<String, EvalOperator.ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
115+
this.xContentType = xContentType;
116+
this.fieldsEvaluatorFactories = fieldsEvaluatorFactories;
117+
}
118+
119+
@Override
120+
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
121+
return new XContentRowEncoder(xContentType, context.blockFactory(), fieldNames(), fieldsValueEvaluators(context));
122+
}
123+
124+
private String[] fieldNames() {
125+
return fieldsEvaluatorFactories.keySet().toArray(String[]::new);
126+
}
127+
128+
private EvalOperator.ExpressionEvaluator[] fieldsValueEvaluators(DriverContext context) {
129+
return fieldsEvaluatorFactories.values()
130+
.stream()
131+
.map(factory -> factory.get(context))
132+
.toArray(EvalOperator.ExpressionEvaluator[]::new);
133+
}
134+
}
135+
}

0 commit comments

Comments
 (0)