Skip to content

Commit bcde6be

Browse files
LookupJoin prejoin filter POC WIP
1 parent a995a12 commit bcde6be

File tree

7 files changed

+159
-76
lines changed

7 files changed

+159
-76
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/EnrichQuerySourceOperator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ Page buildPage(int positions, IntVector.Builder positionsBuilder, IntVector.Buil
159159
return page;
160160
}
161161

162-
private Query nextQuery() {
162+
private Query nextQuery() throws IOException {
163163
++queryPosition;
164164
while (isFinished() == false) {
165165
Query query = queryList.getQuery(queryPosition);

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/ExpressionQueryList.java

Lines changed: 0 additions & 61 deletions
This file was deleted.

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/LookupEnrichQueryGenerator.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import org.apache.lucene.search.Query;
1111
import org.elasticsearch.core.Nullable;
1212

13+
import java.io.IOException;
14+
1315
/**
1416
* An interface to generate queries for the lookup and enrich operators.
1517
* This interface is used to retrieve queries based on a position index.
@@ -20,7 +22,7 @@ public interface LookupEnrichQueryGenerator {
2022
* Returns the query at the given position.
2123
*/
2224
@Nullable
23-
Query getQuery(int position);
25+
Query getQuery(int position) throws IOException;
2426

2527
/**
2628
* Returns the number of queries in this generator
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.enrich;
9+
10+
import org.apache.lucene.search.BooleanClause;
11+
import org.apache.lucene.search.BooleanQuery;
12+
import org.apache.lucene.search.Query;
13+
import org.elasticsearch.compute.operator.lookup.LookupEnrichQueryGenerator;
14+
import org.elasticsearch.compute.operator.lookup.QueryList;
15+
import org.elasticsearch.index.query.QueryBuilder;
16+
import org.elasticsearch.index.query.SearchExecutionContext;
17+
import org.elasticsearch.xpack.esql.core.expression.Literal;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
22+
/**
23+
* A {@link LookupEnrichQueryGenerator} that combines multiple {@link QueryList}s into a single query.
24+
* Each query in the resulting query will be a conjunction of all queries from the input lists at the same position.
25+
* In the future we can extend this to support more complex expressions, such as disjunctions or negations.
26+
*/
27+
public class ExpressionQueryList implements LookupEnrichQueryGenerator {
28+
private final List<QueryList> queryLists;
29+
private final QueryBuilder preJoinFilter;
30+
private final SearchExecutionContext context;
31+
32+
public ExpressionQueryList(List<QueryList> queryLists, SearchExecutionContext context, QueryBuilder preJoinFilter) {
33+
if (queryLists.size() < 2 && Literal.TRUE.equals(preJoinFilter)) {
34+
throw new IllegalArgumentException("ExpressionQueryList must have at least two QueryLists");
35+
}
36+
this.queryLists = queryLists;
37+
this.preJoinFilter = preJoinFilter;
38+
this.context = context;
39+
}
40+
41+
@Override
42+
public Query getQuery(int position) throws IOException {
43+
BooleanQuery.Builder builder = new BooleanQuery.Builder();
44+
for (QueryList queryList : queryLists) {
45+
Query q = queryList.getQuery(position);
46+
if (q == null) {
47+
// if any of the matchFields are null, it means there is no match for this position
48+
// A AND NULL is always NULL, so we can skip this position
49+
return null;
50+
}
51+
builder.add(q, BooleanClause.Occur.FILTER);
52+
}
53+
// also attach the pre-join filter if it exists
54+
/*if (Literal.TRUE.equals(preJoinFilter) == false) {
55+
if (preJoinFilter instanceof TranslationAware translationAware) {
56+
Query preJoinQuery = tryToGetAsLuceneQuery(translationAware);
57+
if (preJoinQuery == null) {
58+
preJoinQuery = tryToGetThroughQueryBuilder(translationAware);
59+
}
60+
if (preJoinQuery == null) {
61+
throw new UnsupportedOperationException("Cannot translate pre-join filter to Lucene query: " + preJoinFilter);
62+
}
63+
builder.add(preJoinQuery, BooleanClause.Occur.FILTER);
64+
}
65+
}*/
66+
if (preJoinFilter != null) {
67+
// JULIAN TO DO: Can we precompile the query? I don't want to call toQuery for every row
68+
builder.add(preJoinFilter.toQuery(context), BooleanClause.Occur.FILTER);
69+
}
70+
return builder.build();
71+
}
72+
73+
/*private Query tryToGetThroughQueryBuilder(TranslationAware translationAware) {
74+
// it seems I might need to pass a QueryBuilder, instead of Expression directly????
75+
// can a QueryBuilder support nested complex expressions with AND, OR, NOT?
76+
return translationAware.asQuery(WHAT_GOES_HERE, WHAT_GOES_HERE).toQueryBuilder().toQuery(queryLists.get(0).searchExecutionContext);
77+
}
78+
79+
private Query tryToGetAsLuceneQuery(TranslationAware translationAware) {
80+
// attempt to translate directly to a Lucene Query
81+
// not sure how to get the field name from the expression
82+
MappedFieldType fieldType = context.getFieldType(WHAT_GOES_HERE.fieldName().string());
83+
try {
84+
return translationAware.asLuceneQuery(fieldType, CONSTANT_SCORE_REWRITE, context);
85+
} catch (Exception e) {}
86+
// only a few expression types support asLuceneQuery, it is OK to fail here and we will try a different approach
87+
return null;
88+
}
89+
*/
90+
@Override
91+
public int getPositionCount() {
92+
int positionCount = queryLists.get(0).getPositionCount();
93+
for (QueryList queryList : queryLists) {
94+
if (queryList.getPositionCount() != positionCount) {
95+
throw new IllegalArgumentException(
96+
"All QueryLists must have the same position count, expected: "
97+
+ positionCount
98+
+ ", but got: "
99+
+ queryList.getPositionCount()
100+
);
101+
}
102+
}
103+
return positionCount;
104+
}
105+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/LookupFromIndexOperator.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.compute.operator.lookup.RightChunkedLeftJoin;
2222
import org.elasticsearch.core.Releasable;
2323
import org.elasticsearch.core.Releasables;
24+
import org.elasticsearch.index.query.QueryBuilder;
2425
import org.elasticsearch.tasks.CancellableTask;
2526
import org.elasticsearch.xcontent.XContentBuilder;
2627
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
@@ -46,7 +47,8 @@ public record Factory(
4647
String lookupIndexPattern,
4748
String lookupIndex,
4849
List<NamedExpression> loadFields,
49-
Source source
50+
Source source,
51+
QueryBuilder preJoinFilter
5052
) implements OperatorFactory {
5153
@Override
5254
public String describe() {
@@ -60,6 +62,7 @@ public String describe() {
6062
.append(" inputChannel=")
6163
.append(matchField.channel());
6264
}
65+
stringBuilder.append(" pre_join_filter=").append(preJoinFilter);
6366
stringBuilder.append("]");
6467
return stringBuilder.toString();
6568
}
@@ -76,7 +79,8 @@ public Operator get(DriverContext driverContext) {
7679
lookupIndexPattern,
7780
lookupIndex,
7881
loadFields,
79-
source
82+
source,
83+
preJoinFilter
8084
);
8185
}
8286
}
@@ -90,6 +94,7 @@ public Operator get(DriverContext driverContext) {
9094
private final Source source;
9195
private long totalRows = 0L;
9296
private List<MatchConfig> matchFields;
97+
private QueryBuilder preJoinFilter;
9398
/**
9499
* Total number of pages emitted by this {@link Operator}.
95100
*/
@@ -109,7 +114,8 @@ public LookupFromIndexOperator(
109114
String lookupIndexPattern,
110115
String lookupIndex,
111116
List<NamedExpression> loadFields,
112-
Source source
117+
Source source,
118+
QueryBuilder preJoinFilter
113119
) {
114120
super(driverContext, lookupService.getThreadContext(), maxOutstandingRequests);
115121
this.matchFields = matchFields;
@@ -120,6 +126,7 @@ public LookupFromIndexOperator(
120126
this.lookupIndex = lookupIndex;
121127
this.loadFields = loadFields;
122128
this.source = source;
129+
this.preJoinFilter = preJoinFilter;
123130
}
124131

125132
@Override
@@ -146,7 +153,8 @@ protected void performAsync(Page inputPage, ActionListener<OngoingJoin> listener
146153
newMatchFields,
147154
new Page(inputBlockArray),
148155
loadFields,
149-
source
156+
source,
157+
preJoinFilter
150158
);
151159
lookupService.lookupAsync(
152160
request,
@@ -203,6 +211,7 @@ public String toString() {
203211
.append(" inputChannel=")
204212
.append(matchField.channel());
205213
}
214+
stringBuilder.append(" pre_join_filter=").append(preJoinFilter);
206215
stringBuilder.append("]");
207216
return stringBuilder.toString();
208217
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/LookupFromIndexService.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
import org.elasticsearch.compute.data.BlockStreamInput;
2121
import org.elasticsearch.compute.data.Page;
2222
import org.elasticsearch.compute.operator.Warnings;
23-
import org.elasticsearch.compute.operator.lookup.ExpressionQueryList;
2423
import org.elasticsearch.compute.operator.lookup.LookupEnrichQueryGenerator;
2524
import org.elasticsearch.compute.operator.lookup.QueryList;
2625
import org.elasticsearch.core.Releasables;
26+
import org.elasticsearch.index.query.QueryBuilder;
2727
import org.elasticsearch.index.query.SearchExecutionContext;
2828
import org.elasticsearch.index.shard.ShardId;
2929
import org.elasticsearch.indices.IndicesService;
@@ -88,7 +88,8 @@ protected TransportRequest transportRequest(LookupFromIndexService.Request reque
8888
null,
8989
request.extractFields,
9090
request.matchFields,
91-
request.source
91+
request.source,
92+
request.preJoinFilter
9293
);
9394
}
9495

@@ -112,10 +113,10 @@ protected LookupEnrichQueryGenerator queryList(
112113
).onlySingleValues(warnings, "LOOKUP JOIN encountered multi-value");
113114
queryLists.add(q);
114115
}
115-
if (queryLists.size() == 1) {
116+
if (queryLists.size() == 1 && request.preJoinFilter == null) {
116117
return queryLists.getFirst();
117118
}
118-
return new ExpressionQueryList(queryLists);
119+
return new ExpressionQueryList(queryLists, context, request.preJoinFilter);
119120
}
120121

121122
@Override
@@ -130,6 +131,7 @@ protected AbstractLookupService.LookupResponse readLookupResponse(StreamInput in
130131

131132
public static class Request extends AbstractLookupService.Request {
132133
private final List<MatchConfig> matchFields;
134+
private final QueryBuilder preJoinFilter;
133135

134136
Request(
135137
String sessionId,
@@ -138,15 +140,18 @@ public static class Request extends AbstractLookupService.Request {
138140
List<MatchConfig> matchFields,
139141
Page inputPage,
140142
List<NamedExpression> extractFields,
141-
Source source
143+
Source source,
144+
QueryBuilder preJoinFilter
142145
) {
143146
super(sessionId, index, indexPattern, matchFields.get(0).type(), inputPage, extractFields, source);
144147
this.matchFields = matchFields;
148+
this.preJoinFilter = preJoinFilter;
145149
}
146150
}
147151

148152
protected static class TransportRequest extends AbstractLookupService.TransportRequest {
149153
private final List<MatchConfig> matchFields;
154+
private final QueryBuilder preJoinFilter;
150155

151156
// Right now we assume that the page contains the same number of blocks as matchFields and that the blocks are in the same order
152157
// The channel information inside the MatchConfig, should say the same thing
@@ -158,10 +163,12 @@ protected static class TransportRequest extends AbstractLookupService.TransportR
158163
Page toRelease,
159164
List<NamedExpression> extractFields,
160165
List<MatchConfig> matchFields,
161-
Source source
166+
Source source,
167+
QueryBuilder preJoinFilter
162168
) {
163169
super(sessionId, shardId, indexPattern, inputPage, toRelease, extractFields, source);
164170
this.matchFields = matchFields;
171+
this.preJoinFilter = preJoinFilter;
165172
}
166173

167174
static TransportRequest readFrom(StreamInput in, BlockFactory blockFactory) throws IOException {
@@ -207,6 +214,10 @@ static TransportRequest readFrom(StreamInput in, BlockFactory blockFactory) thro
207214
String sourceText = in.readString();
208215
source = new Source(source.source(), sourceText);
209216
}
217+
QueryBuilder preJoinFilter = null;
218+
if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_LOOKUP_JOIN_ON_MANY_FIELDS)) {
219+
preJoinFilter = planIn.readOptionalNamedWriteable(QueryBuilder.class);
220+
}
210221
TransportRequest result = new TransportRequest(
211222
sessionId,
212223
shardId,
@@ -215,7 +226,8 @@ static TransportRequest readFrom(StreamInput in, BlockFactory blockFactory) thro
215226
inputPage,
216227
extractFields,
217228
matchFields,
218-
source
229+
source,
230+
preJoinFilter
219231
);
220232
result.setParentTask(parentTaskId);
221233
return result;
@@ -258,11 +270,19 @@ public void writeTo(StreamOutput out) throws IOException {
258270
if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_LOOKUP_JOIN_SOURCE_TEXT)) {
259271
out.writeString(source.text());
260272
}
273+
if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_LOOKUP_JOIN_ON_MANY_FIELDS)) {
274+
planOut.writeOptionalNamedWriteable(preJoinFilter);
275+
} else if (preJoinFilter != null) {
276+
throw new EsqlIllegalArgumentException("LOOKUP JOIN with pre-join filter is not supported on remote node");
277+
}
261278
}
262279

263280
@Override
264281
protected String extraDescription() {
265-
return " ,match_fields=" + matchFields.stream().map(x -> x.fieldName().string()).collect(Collectors.joining(", "));
282+
return " ,match_fields="
283+
+ matchFields.stream().map(x -> x.fieldName().string()).collect(Collectors.joining(", "))
284+
+ ", pre_join_filter="
285+
+ preJoinFilter;
266286
}
267287
}
268288

0 commit comments

Comments
 (0)