Skip to content

Commit 31ac987

Browse files
committed
Continue refactoring of the RowEncoder.
1 parent 588b53a commit 31ac987

File tree

5 files changed

+63
-92
lines changed

5 files changed

+63
-92
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,21 @@
99

1010
import org.apache.lucene.util.BytesRef;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.common.lucene.BytesRefs;
1213
import org.elasticsearch.compute.data.Block;
1314
import org.elasticsearch.compute.data.BlockFactory;
1415
import org.elasticsearch.compute.data.BytesRefBlock;
1516
import org.elasticsearch.compute.data.DoubleBlock;
1617
import org.elasticsearch.compute.data.Page;
1718
import org.elasticsearch.compute.operator.AsyncOperator;
1819
import org.elasticsearch.compute.operator.DriverContext;
19-
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
2020
import org.elasticsearch.compute.operator.Operator;
2121
import org.elasticsearch.core.Releasables;
2222
import org.elasticsearch.inference.TaskType;
2323
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2424
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2525

2626
import java.util.List;
27-
import java.util.Map;
28-
29-
import static org.elasticsearch.xpack.esql.inference.XContentRowEncoder.yamlRowEncoderFactory;
3027

3128
public class RerankOperator extends AsyncOperator<Page> {
3229

@@ -37,21 +34,13 @@ public record Factory(
3734
InferenceService inferenceService,
3835
String inferenceId,
3936
String queryText,
40-
Map<String, ExpressionEvaluator.Factory> fieldsEvaluatorFactories,
37+
RowEncoder.Factory<BytesRefBlock> rowEncoderFactory,
4138
int scoreChannel
4239
) implements OperatorFactory {
4340

4441
@Override
4542
public String describe() {
46-
return "RerankOperator[inference_id=["
47-
+ inferenceId
48-
+ "], query=["
49-
+ queryText
50-
+ "], rerank_fields="
51-
+ fieldsEvaluatorFactories.keySet()
52-
+ ", score_channel=["
53-
+ scoreChannel
54-
+ "]]";
43+
return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
5544
}
5645

5746
@Override
@@ -61,7 +50,7 @@ public Operator get(DriverContext driverContext) {
6150
inferenceService,
6251
inferenceId,
6352
queryText,
64-
yamlRowEncoderFactory(fieldsEvaluatorFactories).get(driverContext),
53+
rowEncoderFactory().get(driverContext),
6554
scoreChannel
6655
);
6756
}
@@ -71,15 +60,15 @@ public Operator get(DriverContext driverContext) {
7160
private final BlockFactory blockFactory;
7261
private final String inferenceId;
7362
private final String queryText;
74-
private final ExpressionEvaluator rowEncoder;
63+
private final RowEncoder<BytesRefBlock> rowEncoder;
7564
private final int scoreChannel;
7665

7766
public RerankOperator(
7867
DriverContext driverContext,
7968
InferenceService inferenceService,
8069
String inferenceId,
8170
String queryText,
82-
ExpressionEvaluator rowEncoder,
71+
RowEncoder<BytesRefBlock> rowEncoder,
8372
int scoreChannel
8473
) {
8574
super(driverContext, inferenceService.getThreadContext(), MAX_INFERENCE_WORKER);
@@ -129,15 +118,7 @@ public Page getOutput() {
129118

130119
@Override
131120
public String toString() {
132-
return "RerankOperator[inference_id=["
133-
+ inferenceId
134-
+ "], query=["
135-
+ queryText
136-
+ "], row_encoder=["
137-
+ rowEncoder
138-
+ "], score_channel=["
139-
+ scoreChannel
140-
+ "]]";
121+
return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
141122
}
142123

143124
private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
@@ -196,17 +177,17 @@ private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResult
196177
}
197178

198179
private InferenceAction.Request buildInferenceRequest(Page inputPage) {
199-
try (BytesRefBlock encodedRowBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) {
200-
assert (encodedRowBlock.getPositionCount() == inputPage.getPositionCount());
180+
try (BytesRefBlock encodedRowsBlock = rowEncoder.encodeRows(inputPage)) {
181+
assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount());
201182
String[] inputs = new String[inputPage.getPositionCount()];
202183
BytesRef buffer = new BytesRef();
203184

204185
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
205-
if (encodedRowBlock.isNull(pos)) {
186+
if (encodedRowsBlock.isNull(pos)) {
206187
inputs[pos] = "";
207188
} else {
208-
buffer = encodedRowBlock.getBytesRef(encodedRowBlock.getFirstValueIndex(pos), buffer);
209-
inputs[pos] = buffer.utf8ToString();
189+
buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer);
190+
inputs[pos] = BytesRefs.toString(buffer);
210191
}
211192
}
212193

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RowEncoder.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.Page;
12+
import org.elasticsearch.compute.operator.DriverContext;
13+
import org.elasticsearch.core.Releasable;
14+
15+
public interface RowEncoder<B extends Block> extends Releasable {
16+
17+
B encodeRows(Page page);
18+
19+
interface Factory<B extends Block> {
20+
RowEncoder<B> get(DriverContext context);
21+
}
22+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/XContentRowEncoder.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import java.util.List;
2727
import java.util.Map;
2828

29-
class XContentRowEncoder implements EvalOperator.ExpressionEvaluator {
29+
public class XContentRowEncoder implements RowEncoder<BytesRefBlock> {
3030
private final XContentType xContentType;
3131
private final BlockFactory blockFactory;
3232
private final String[] fieldNames;
@@ -55,7 +55,7 @@ public void close() {
5555
}
5656

5757
@Override
58-
public Block eval(Page page) {
58+
public BytesRefBlock encodeRows(Page page) {
5959
Block[] fieldValueBlocks = new Block[fieldsValueEvaluators.length];
6060
try (
6161
BytesRefStreamOutput outputStream = new BytesRefStreamOutput();
@@ -107,7 +107,7 @@ private Object toYamlValue(Object value) {
107107
}
108108
}
109109

110-
public static final class Factory implements EvalOperator.ExpressionEvaluator.Factory {
110+
public static final class Factory implements RowEncoder.Factory<BytesRefBlock> {
111111
private final XContentType xContentType;
112112
private final Map<String, EvalOperator.ExpressionEvaluator.Factory> fieldsEvaluatorFactories;
113113

@@ -116,8 +116,7 @@ private Factory(XContentType xContentType, Map<String, EvalOperator.ExpressionEv
116116
this.fieldsEvaluatorFactories = fieldsEvaluatorFactories;
117117
}
118118

119-
@Override
120-
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
119+
public RowEncoder<BytesRefBlock> get(DriverContext context) {
121120
return new XContentRowEncoder(xContentType, context.blockFactory(), fieldNames(), fieldsValueEvaluators(context));
122121
}
123122

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
import org.elasticsearch.xpack.esql.expression.Order;
8383
import org.elasticsearch.xpack.esql.inference.InferenceService;
8484
import org.elasticsearch.xpack.esql.inference.RerankOperator;
85+
import org.elasticsearch.xpack.esql.inference.XContentRowEncoder;
8586
import org.elasticsearch.xpack.esql.plan.logical.Fork;
8687
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
8788
import org.elasticsearch.xpack.esql.plan.physical.ChangePointExec;
@@ -566,19 +567,20 @@ private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerCon
566567
);
567568
}
568569

570+
XContentRowEncoder.Factory rowEncoderFactory = XContentRowEncoder.yamlRowEncoderFactory(rerankFieldsEvaluatorSuppliers);
571+
569572
String inferenceId = BytesRefs.toString(rerank.inferenceId().fold(context.foldCtx));
570573
String queryText = BytesRefs.toString(rerank.queryText().fold(context.foldCtx));
571574

572-
Layout.Builder layoutBuilder = source.layout.builder();
575+
Layout outputLayout = source.layout;
573576
if (source.layout.get(rerank.scoreAttribute().id()) == null) {
574-
layoutBuilder.append(rerank.scoreAttribute());
577+
outputLayout = source.layout.builder().append(rerank.scoreAttribute()).build();
575578
}
576-
Layout outputLayout = layoutBuilder.build();
577579

578580
int scoreChannel = outputLayout.get(rerank.scoreAttribute().id()).channel();
579581

580582
return source.with(
581-
new RerankOperator.Factory(inferenceService, inferenceId, queryText, rerankFieldsEvaluatorSuppliers, scoreChannel),
583+
new RerankOperator.Factory(inferenceService, inferenceId, queryText, rowEncoderFactory, scoreChannel),
582584
outputLayout
583585
);
584586
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java

Lines changed: 19 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.elasticsearch.compute.data.Page;
2525
import org.elasticsearch.compute.operator.AsyncOperator;
2626
import org.elasticsearch.compute.operator.DriverContext;
27-
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
2827
import org.elasticsearch.compute.operator.Operator;
2928
import org.elasticsearch.compute.operator.SourceOperator;
3029
import org.elasticsearch.compute.test.AbstractBlockSourceOperator;
@@ -43,13 +42,10 @@
4342
import java.io.IOException;
4443
import java.util.ArrayList;
4544
import java.util.List;
46-
import java.util.Map;
4745
import java.util.function.BiFunction;
4846
import java.util.function.Consumer;
49-
import java.util.function.Function;
5047
import java.util.stream.Collectors;
5148
import java.util.stream.IntStream;
52-
import java.util.stream.Stream;
5349

5450
import static org.hamcrest.Matchers.equalTo;
5551
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -66,15 +62,15 @@ public class RerankOperatorTests extends OperatorTestCase {
6662
private static final String SIMPLE_QUERY = "query text";
6763
private ThreadPool threadPool;
6864
private List<ElementType> inputChannelElementTypes;
69-
private Map<String, ExpressionEvaluator.Factory> rerankFieldsEvaluatorFactories;
65+
private RowEncoder.Factory<BytesRefBlock> rowEncoderFactory;
7066
private int scoreChannel;
7167

7268
@Before
7369
private void initChannels() {
7470
int channelCount = randomIntBetween(2, 10);
7571
scoreChannel = randomIntBetween(0, channelCount - 1);
7672
inputChannelElementTypes = IntStream.range(0, channelCount).sorted().mapToObj(this::randomElementType).collect(Collectors.toList());
77-
rerankFieldsEvaluatorFactories = randomFieldEvaluators().collect(Collectors.toMap((e) -> randomIdentifier(), Function.identity()));
73+
rowEncoderFactory = mockRowEncoderFactory();
7874
}
7975

8076
@Before
@@ -94,14 +90,7 @@ public void shutdownThreadPool() {
9490
@Override
9591
protected Operator.OperatorFactory simple() {
9692
InferenceService inferenceService = mockedSimpleInferenceService();
97-
98-
return new RerankOperator.Factory(
99-
inferenceService,
100-
SIMPLE_INFERENCE_ID,
101-
SIMPLE_QUERY,
102-
rerankFieldsEvaluatorFactories,
103-
scoreChannel
104-
);
93+
return new RerankOperator.Factory(inferenceService, SIMPLE_INFERENCE_ID, SIMPLE_QUERY, rowEncoderFactory, scoreChannel);
10594
}
10695

10796
private InferenceService mockedSimpleInferenceService() {
@@ -136,31 +125,13 @@ private RankedDocsResults mockedRankedDocResults(InferenceAction.Request request
136125

137126
@Override
138127
protected Matcher<String> expectedDescriptionOfSimple() {
139-
return equalTo(
140-
"RerankOperator[inference_id=["
141-
+ SIMPLE_INFERENCE_ID
142-
+ "], query=["
143-
+ SIMPLE_QUERY
144-
+ "], rerank_fields="
145-
+ rerankFieldsEvaluatorFactories.keySet()
146-
+ ", score_channel=["
147-
+ scoreChannel
148-
+ "]]"
149-
);
128+
return expectedToStringOfSimple();
150129
}
151130

152131
@Override
153132
protected Matcher<String> expectedToStringOfSimple() {
154133
return equalTo(
155-
"RerankOperator[inference_id=["
156-
+ SIMPLE_INFERENCE_ID
157-
+ "], query=["
158-
+ SIMPLE_QUERY
159-
+ "], row_encoder=[XContentRowEncoder[content_type=[YAML], field_names="
160-
+ rerankFieldsEvaluatorFactories.keySet()
161-
+ "]], score_channel=["
162-
+ scoreChannel
163-
+ "]]"
134+
"RerankOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "], query=[" + SIMPLE_QUERY + "], score_channel=[" + scoreChannel + "]]"
164135
);
165136
}
166137

@@ -257,33 +228,29 @@ private int inputChannelCount() {
257228
return inputChannelElementTypes.size();
258229
}
259230

260-
private int randomInputChannel() {
261-
return randomIntBetween(0, inputChannelCount() - 1);
262-
}
263-
264231
private ElementType randomElementType(int channel) {
265232
return channel == scoreChannel ? ElementType.DOUBLE : randomFrom(ElementType.FLOAT, ElementType.DOUBLE, ElementType.LONG);
266233
}
267234

268-
private Stream<ExpressionEvaluator.Factory> randomFieldEvaluators() {
269-
return Stream.generate(() -> randomFieldEvaluator(randomInputChannel())).limit(randomIntBetween(0, 20));
270-
}
271-
272-
private static ExpressionEvaluator.Factory randomFieldEvaluator(int channel) {
273-
return context -> new ExpressionEvaluator() {
235+
private RowEncoder.Factory<BytesRefBlock> mockRowEncoderFactory() {
236+
RowEncoder.Factory<BytesRefBlock> factory = new RowEncoder.Factory<>() {
274237
@Override
275-
public Block eval(Page page) {
276-
Block b = page.getBlock(channel);
277-
b.incRef();
278-
;
279-
return b;
280-
}
238+
public RowEncoder<BytesRefBlock> get(DriverContext context) {
239+
return new RowEncoder<BytesRefBlock>() {
240+
@Override
241+
public BytesRefBlock encodeRows(Page page) {
242+
return blockFactory().newConstantBytesRefBlockWith(new BytesRef(randomAlphaOfLength(100)), page.getPositionCount());
243+
}
281244

282-
@Override
283-
public void close() {
245+
@Override
246+
public void close() {
284247

248+
}
249+
};
285250
}
286251
};
252+
253+
return factory;
287254
}
288255

289256
private void assertExpectedScore(DoubleBlock scoreBlockResult) {

0 commit comments

Comments
 (0)