Skip to content

Commit f147a74

Browse files
committed
Improved XContentRowEncoder to support most element types correctly.
1 parent 330160d commit f147a74

File tree

5 files changed

+76
-60
lines changed

5 files changed

+76
-60
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/PositionToXContent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.spatialToString;
3434
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.versionToString;
3535

36-
abstract class PositionToXContent {
36+
public abstract class PositionToXContent {
3737
protected final Block block;
3838

3939
PositionToXContent(Block block) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.compute.data.Page;
1818
import org.elasticsearch.compute.operator.AsyncOperator;
1919
import org.elasticsearch.compute.operator.DriverContext;
20+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
2021
import org.elasticsearch.compute.operator.Operator;
2122
import org.elasticsearch.core.Releasables;
2223
import org.elasticsearch.inference.TaskType;
@@ -34,7 +35,7 @@ public record Factory(
3435
InferenceService inferenceService,
3536
String inferenceId,
3637
String queryText,
37-
RowEncoder.Factory<BytesRefBlock> rowEncoderFactory,
38+
ExpressionEvaluator.Factory rowEncoderFactory,
3839
int scoreChannel
3940
) implements OperatorFactory {
4041

@@ -60,15 +61,15 @@ public Operator get(DriverContext driverContext) {
6061
private final BlockFactory blockFactory;
6162
private final String inferenceId;
6263
private final String queryText;
63-
private final RowEncoder<BytesRefBlock> rowEncoder;
64+
private final ExpressionEvaluator rowEncoder;
6465
private final int scoreChannel;
6566

6667
public RerankOperator(
6768
DriverContext driverContext,
6869
InferenceService inferenceService,
6970
String inferenceId,
7071
String queryText,
71-
RowEncoder<BytesRefBlock> rowEncoder,
72+
ExpressionEvaluator rowEncoder,
7273
int scoreChannel
7374
) {
7475
super(driverContext, inferenceService.getThreadContext(), MAX_INFERENCE_WORKER);
@@ -177,7 +178,7 @@ private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResult
177178
}
178179

179180
private InferenceAction.Request buildInferenceRequest(Page inputPage) {
180-
try (BytesRefBlock encodedRowsBlock = rowEncoder.encodeRows(inputPage)) {
181+
try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) {
181182
assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount());
182183
String[] inputs = new String[inputPage.getPositionCount()];
183184
BytesRef buffer = new BytesRef();

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/XContentRowEncoder.java

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,41 +11,57 @@
1111
import org.elasticsearch.common.io.stream.BytesRefStreamOutput;
1212
import org.elasticsearch.compute.data.Block;
1313
import org.elasticsearch.compute.data.BlockFactory;
14-
import org.elasticsearch.compute.data.BlockUtils;
1514
import org.elasticsearch.compute.data.BytesRefBlock;
1615
import org.elasticsearch.compute.data.Page;
1716
import org.elasticsearch.compute.operator.DriverContext;
1817
import org.elasticsearch.compute.operator.EvalOperator;
18+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
1919
import org.elasticsearch.core.Releasables;
20+
import org.elasticsearch.xcontent.ToXContent;
2021
import org.elasticsearch.xcontent.XContentBuilder;
2122
import org.elasticsearch.xcontent.XContentFactory;
2223
import org.elasticsearch.xcontent.XContentType;
24+
import org.elasticsearch.xpack.esql.action.ColumnInfoImpl;
25+
import org.elasticsearch.xpack.esql.action.PositionToXContent;
2326

2427
import java.io.IOException;
2528
import java.io.UncheckedIOException;
29+
import java.util.Arrays;
2630
import java.util.List;
2731
import java.util.Map;
32+
import java.util.stream.Collectors;
2833

29-
public class XContentRowEncoder implements RowEncoder<BytesRefBlock> {
34+
35+
/**
36+
* Encodes rows into an XContent format (JSON,YAML,...) for further processing.
37+
* Extracted columns can be specified using {@link EvalOperator}ExpressionEvaluator}
38+
*/
39+
public class XContentRowEncoder implements ExpressionEvaluator {
3040
private final XContentType xContentType;
3141
private final BlockFactory blockFactory;
32-
private final String[] fieldNames;
33-
private final EvalOperator.ExpressionEvaluator[] fieldsValueEvaluators;
34-
35-
public static Factory yamlRowEncoderFactory(Map<String, EvalOperator.ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
42+
private final ColumnInfoImpl[] columnsInfo;
43+
private final ExpressionEvaluator[] fieldsValueEvaluators;
44+
45+
/**
46+
* Creates a factory for YAML XContent row encoding.
47+
*
48+
* @param fieldsEvaluatorFactories A map of column information to expression evaluators.
49+
* @return A Factory instance for creating YAML row encoder for the specified column.
50+
*/
51+
public static Factory yamlRowEncoderFactory(Map<ColumnInfoImpl, ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
3652
return new Factory(XContentType.YAML, fieldsEvaluatorFactories);
3753
}
3854

3955
private XContentRowEncoder(
4056
XContentType xContentType,
4157
BlockFactory blockFactory,
42-
String[] fieldNames,
43-
EvalOperator.ExpressionEvaluator[] fieldsValueEvaluators
58+
ColumnInfoImpl[] columnsInfo,
59+
ExpressionEvaluator[] fieldsValueEvaluators
4460
) {
45-
assert fieldNames.length == fieldsValueEvaluators.length;
61+
assert columnsInfo.length == fieldsValueEvaluators.length;
4662
this.xContentType = xContentType;
4763
this.blockFactory = blockFactory;
48-
this.fieldNames = fieldNames;
64+
this.columnsInfo = columnsInfo;
4965
this.fieldsValueEvaluators = fieldsValueEvaluators;
5066
}
5167

@@ -54,27 +70,36 @@ public void close() {
5470
Releasables.closeExpectNoException(fieldsValueEvaluators);
5571
}
5672

73+
/**
74+
* Process the provided Page and encode its rows into a BytesRefBlock containing XContent-formatted rows.
75+
*
76+
* @param page The input Page containing row data.
77+
* @return A BytesRefBlock containing the encoded rows.
78+
*/
5779
@Override
58-
public BytesRefBlock encodeRows(Page page) {
80+
public BytesRefBlock eval(Page page) {
5981
Block[] fieldValueBlocks = new Block[fieldsValueEvaluators.length];
6082
try (
6183
BytesRefStreamOutput outputStream = new BytesRefStreamOutput();
6284
XContentBuilder xContentBuilder = XContentFactory.contentBuilder(xContentType, outputStream);
6385
BytesRefBlock.Builder outputBlockBuilder = blockFactory.newBytesRefBlockBuilder(page.getPositionCount());
6486
) {
87+
88+
PositionToXContent[] toXContents = new PositionToXContent[fieldsValueEvaluators.length];
6589
for (int b = 0; b < fieldValueBlocks.length; b++) {
6690
fieldValueBlocks[b] = fieldsValueEvaluators[b].eval(page);
91+
toXContents[b] = PositionToXContent.positionToXContent(columnsInfo[b], fieldValueBlocks[b], new BytesRef());
6792
}
6893

6994
for (int pos = 0; pos < page.getPositionCount(); pos++) {
7095
xContentBuilder.startObject();
7196
for (int i = 0; i < fieldValueBlocks.length; i++) {
72-
String fieldName = fieldNames[i];
97+
String fieldName = columnsInfo[i].name();
7398
Block currentBlock = fieldValueBlocks[i];
74-
if (currentBlock.isNull(pos)) {
99+
if (currentBlock.isNull(pos) || currentBlock.getValueCount(pos) < 1) {
75100
continue;
76101
}
77-
xContentBuilder.field(fieldName, toYamlValue(BlockUtils.toJavaObject(currentBlock, pos)));
102+
toXContents[i].positionToXContent(xContentBuilder.field(fieldName), ToXContent.EMPTY_PARAMS, pos);
78103
}
79104
xContentBuilder.endObject().flush();
80105
outputBlockBuilder.appendBytesRef(outputStream.get());
@@ -89,46 +114,37 @@ public BytesRefBlock encodeRows(Page page) {
89114
}
90115
}
91116

92-
@Override
93-
public String toString() {
94-
return "XContentRowEncoder[content_type=[" + xContentType.toString() + "], field_names=" + List.of(fieldNames) + "]";
117+
public List<String> fieldNames() {
118+
return Arrays.stream(columnsInfo).map(ColumnInfoImpl::name).collect(Collectors.toList());
95119
}
96120

97-
private Object toYamlValue(Object value) {
98-
try {
99-
return switch (value) {
100-
case BytesRef b -> b.utf8ToString();
101-
case List<?> l -> l.stream().map(this::toYamlValue).toList();
102-
default -> value;
103-
};
104-
} catch (Error | Exception e) {
105-
// Swallow errors caused by invalid byteref.
106-
return "";
107-
}
121+
@Override
122+
public String toString() {
123+
return "XContentRowEncoder[content_type=[" + xContentType.toString() + "], field_names=" + fieldNames() + "]";
108124
}
109125

110-
public static final class Factory implements RowEncoder.Factory<BytesRefBlock> {
126+
public static class Factory implements ExpressionEvaluator.Factory {
111127
private final XContentType xContentType;
112-
private final Map<String, EvalOperator.ExpressionEvaluator.Factory> fieldsEvaluatorFactories;
128+
private final Map<ColumnInfoImpl, ExpressionEvaluator.Factory> fieldsEvaluatorFactories;
113129

114-
private Factory(XContentType xContentType, Map<String, EvalOperator.ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
130+
private Factory(XContentType xContentType, Map<ColumnInfoImpl, ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
115131
this.xContentType = xContentType;
116132
this.fieldsEvaluatorFactories = fieldsEvaluatorFactories;
117133
}
118134

119-
public RowEncoder<BytesRefBlock> get(DriverContext context) {
120-
return new XContentRowEncoder(xContentType, context.blockFactory(), fieldNames(), fieldsValueEvaluators(context));
135+
public XContentRowEncoder get(DriverContext context) {
136+
return new XContentRowEncoder(xContentType, context.blockFactory(), columnsInfo(), fieldsValueEvaluators(context));
121137
}
122138

123-
private String[] fieldNames() {
124-
return fieldsEvaluatorFactories.keySet().toArray(String[]::new);
139+
private ColumnInfoImpl[] columnsInfo() {
140+
return fieldsEvaluatorFactories.keySet().toArray(ColumnInfoImpl[]::new);
125141
}
126142

127-
private EvalOperator.ExpressionEvaluator[] fieldsValueEvaluators(DriverContext context) {
143+
private ExpressionEvaluator[] fieldsValueEvaluators(DriverContext context) {
128144
return fieldsEvaluatorFactories.values()
129145
.stream()
130146
.map(factory -> factory.get(context))
131-
.toArray(EvalOperator.ExpressionEvaluator[]::new);
147+
.toArray(ExpressionEvaluator[]::new);
132148
}
133149
}
134150
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.elasticsearch.node.Node;
6060
import org.elasticsearch.tasks.CancellableTask;
6161
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
62+
import org.elasticsearch.xpack.esql.action.ColumnInfoImpl;
6263
import org.elasticsearch.xpack.esql.core.expression.Alias;
6364
import org.elasticsearch.xpack.esql.core.expression.Attribute;
6465
import org.elasticsearch.xpack.esql.core.expression.Expression;
@@ -558,11 +559,11 @@ private PhysicalOperation planEnrich(EnrichExec enrich, LocalExecutionPlannerCon
558559
private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerContext context) {
559560
PhysicalOperation source = plan(rerank.child(), context);
560561

561-
Map<String, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers = new LinkedHashMap<>();
562+
Map<ColumnInfoImpl, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers = new LinkedHashMap<>();
562563

563564
for (var rerankField : rerank.rerankFields()) {
564565
rerankFieldsEvaluatorSuppliers.put(
565-
rerankField.name(),
566+
new ColumnInfoImpl(rerankField.name(), rerankField.dataType()),
566567
EvalMapper.toEvaluator(context.foldCtx(), rerankField.child(), source.layout)
567568
);
568569
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public class RerankOperatorTests extends OperatorTestCase {
6262
private static final String SIMPLE_QUERY = "query text";
6363
private ThreadPool threadPool;
6464
private List<ElementType> inputChannelElementTypes;
65-
private RowEncoder.Factory<BytesRefBlock> rowEncoderFactory;
65+
private XContentRowEncoder.Factory rowEncoderFactory;
6666
private int scoreChannel;
6767

6868
@Before
@@ -232,23 +232,21 @@ private ElementType randomElementType(int channel) {
232232
return channel == scoreChannel ? ElementType.DOUBLE : randomFrom(ElementType.FLOAT, ElementType.DOUBLE, ElementType.LONG);
233233
}
234234

235-
private RowEncoder.Factory<BytesRefBlock> mockRowEncoderFactory() {
236-
RowEncoder.Factory<BytesRefBlock> factory = new RowEncoder.Factory<>() {
237-
@Override
238-
public RowEncoder<BytesRefBlock> get(DriverContext context) {
239-
return new RowEncoder<BytesRefBlock>() {
240-
@Override
241-
public BytesRefBlock encodeRows(Page page) {
242-
return blockFactory().newConstantBytesRefBlockWith(new BytesRef(randomAlphaOfLength(100)), page.getPositionCount());
243-
}
235+
private XContentRowEncoder.Factory mockRowEncoderFactory() {
236+
XContentRowEncoder.Factory factory = mock(XContentRowEncoder.Factory.class);
237+
doAnswer(factoryInvocation -> {
238+
DriverContext driverContext = factoryInvocation.getArgument(0, DriverContext.class);
239+
XContentRowEncoder rowEncoder = mock(XContentRowEncoder.class);
240+
doAnswer(
241+
encoderInvocation -> {
242+
Page inputPage = encoderInvocation.getArgument(0, Page.class);
243+
return driverContext.blockFactory().newConstantBytesRefBlockWith(new BytesRef(randomRealisticUnicodeOfCodepointLength(4)), inputPage.getPositionCount());
244+
}
245+
).when(rowEncoder).eval(any(Page.class));
244246

245-
@Override
246-
public void close() {
247+
return rowEncoder;
248+
}).when(factory).get(any(DriverContext.class));
247249

248-
}
249-
};
250-
}
251-
};
252250

253251
return factory;
254252
}

0 commit comments

Comments
 (0)