Skip to content

Commit 88d7925

Browse files
WIP POC for expression join
1 parent d32cdc4 commit 88d7925

File tree

17 files changed

+405
-58
lines changed

17 files changed

+405
-58
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/LookupEnrichQueryGenerator.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import org.apache.lucene.search.Query;
1111
import org.elasticsearch.core.Nullable;
1212

13+
import java.io.IOException;
14+
1315
/**
1416
* An interface to generates queries for the lookup and enrich operators.
1517
* This interface is used to retrieve queries based on a position index.
@@ -20,7 +22,7 @@ public interface LookupEnrichQueryGenerator {
2022
* Returns the query at the given position.
2123
*/
2224
@Nullable
23-
Query getQuery(int position);
25+
Query getQuery(int position) throws IOException;
2426

2527
/**
2628
* Returns the number of queries in this generator

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/QueryList.java

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ public int getPositionCount() {
8989
public abstract QueryList onlySingleValues(Warnings warnings, String multiValueWarningMessage);
9090

9191
@Override
92-
public final Query getQuery(int position) {
92+
public final Query getQuery(int position) throws IOException {
9393
final int valueCount = block.getValueCount(position);
9494
if (onlySingleValueParams != null && valueCount != 1) {
9595
if (valueCount > 1) {
@@ -125,7 +125,7 @@ public final Query getQuery(int position) {
125125
* Returns the query at the given position.
126126
*/
127127
@Nullable
128-
abstract Query doGetQuery(int position, int firstValueIndex, int valueCount);
128+
public abstract Query doGetQuery(int position, int firstValueIndex, int valueCount) throws IOException;
129129

130130
private Query wrapSingleValueQuery(Query query) {
131131
assert onlySingleValueParams != null : "Requested to wrap single value query without single value params";
@@ -159,13 +159,8 @@ private Query wrapSingleValueQuery(Query query) {
159159
* using only the {@link ElementType} of the {@link Block} to determine the
160160
* query.
161161
*/
162-
public static QueryList rawTermQueryList(
163-
MappedFieldType field,
164-
SearchExecutionContext searchExecutionContext,
165-
AliasFilter aliasFilter,
166-
Block block
167-
) {
168-
IntFunction<Object> blockToJavaObject = switch (block.elementType()) {
162+
public static IntFunction<Object> createBlockValueReader(Block block) {
163+
return switch (block.elementType()) {
169164
case BOOLEAN -> {
170165
BooleanBlock booleanBlock = (BooleanBlock) block;
171166
yield booleanBlock::getBoolean;
@@ -196,7 +191,20 @@ public static QueryList rawTermQueryList(
196191
case AGGREGATE_METRIC_DOUBLE -> throw new IllegalArgumentException("can't read values from [aggregate metric double] block");
197192
case UNKNOWN -> throw new IllegalArgumentException("can't read values from [" + block + "]");
198193
};
199-
return new TermQueryList(field, searchExecutionContext, aliasFilter, block, null, blockToJavaObject);
194+
}
195+
196+
/**
197+
* Returns a list of term queries for the given field and the input block
198+
* using only the {@link ElementType} of the {@link Block} to determine the
199+
* query.
200+
*/
201+
public static QueryList rawTermQueryList(
202+
MappedFieldType field,
203+
SearchExecutionContext searchExecutionContext,
204+
AliasFilter aliasFilter,
205+
Block block
206+
) {
207+
return new TermQueryList(field, searchExecutionContext, aliasFilter, block, null, createBlockValueReader(block));
200208
}
201209

202210
/**
@@ -297,7 +305,7 @@ public TermQueryList onlySingleValues(Warnings warnings, String multiValueWarnin
297305
}
298306

299307
@Override
300-
Query doGetQuery(int position, int firstValueIndex, int valueCount) {
308+
public Query doGetQuery(int position, int firstValueIndex, int valueCount) {
301309
return switch (valueCount) {
302310
case 0 -> null;
303311
case 1 -> field.termQuery(blockValueReader.apply(firstValueIndex), searchExecutionContext);
@@ -360,7 +368,7 @@ public DateNanosQueryList onlySingleValues(Warnings warnings, String multiValueW
360368
}
361369

362370
@Override
363-
Query doGetQuery(int position, int firstValueIndex, int valueCount) {
371+
public Query doGetQuery(int position, int firstValueIndex, int valueCount) {
364372
return switch (valueCount) {
365373
case 0 -> null;
366374
case 1 -> dateFieldType.equalityQuery(blockValueReader.apply(firstValueIndex), searchExecutionContext);
@@ -412,7 +420,7 @@ public GeoShapeQueryList onlySingleValues(Warnings warnings, String multiValueWa
412420
}
413421

414422
@Override
415-
Query doGetQuery(int position, int firstValueIndex, int valueCount) {
423+
public Query doGetQuery(int position, int firstValueIndex, int valueCount) {
416424
return switch (valueCount) {
417425
case 0 -> null;
418426
case 1 -> shapeQuery.apply(firstValueIndex);

x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,6 +2006,100 @@ language_code_float:integer | language_name:keyword
20062006
127 | max_byte
20072007
;
20082008

2009+
lookupWithExpressionEquals
2010+
required_capability: join_lookup_v12
2011+
required_capability: lookup_join_on_expression
2012+
2013+
FROM employees
2014+
| WHERE emp_no == 10001
2015+
| LOOKUP JOIN languages_lookup ON languages == language_code
2016+
| KEEP emp_no, languages, language_code, language_name
2017+
;
2018+
2019+
emp_no:integer | languages:integer | language_code:integer | language_name:keyword
2020+
10001 | 2 | 2 | French
2021+
;
2022+
2023+
lookupWithExpressionNotEquals
2024+
required_capability: join_lookup_v12
2025+
required_capability: lookup_join_on_expression
2026+
2027+
FROM employees
2028+
| WHERE emp_no == 10001
2029+
| LOOKUP JOIN languages_lookup ON languages != language_code
2030+
| KEEP emp_no, languages, language_code, language_name
2031+
| SORT language_code
2032+
;
2033+
2034+
emp_no:integer | languages:integer | language_code:integer | language_name:keyword
2035+
10001 | 2 | 1 | English
2036+
10001 | 2 | 3 | Spanish
2037+
10001 | 2 | 4 | German
2038+
;
2039+
2040+
lookupWithExpressionGreater
2041+
required_capability: join_lookup_v12
2042+
required_capability: lookup_join_on_expression
2043+
2044+
FROM employees
2045+
| WHERE emp_no == 10001
2046+
| LOOKUP JOIN languages_lookup ON languages > language_code
2047+
| KEEP emp_no, languages, language_code, language_name
2048+
;
2049+
2050+
emp_no:integer | languages:integer | language_code:integer | language_name:keyword
2051+
10001 | 2 | 1 | English
2052+
;
2053+
2054+
lookupWithExpressionGreaterOrEquals
2055+
required_capability: join_lookup_v12
2056+
required_capability: lookup_join_on_expression
2057+
2058+
FROM employees
2059+
| WHERE emp_no == 10001
2060+
| LOOKUP JOIN languages_lookup ON languages >= language_code
2061+
| KEEP emp_no, languages, language_code, language_name
2062+
| SORT language_code
2063+
;
2064+
2065+
emp_no:integer | languages:integer | language_code:integer | language_name:keyword
2066+
10001 | 2 | 1 | English
2067+
10001 | 2 | 2 | French
2068+
;
2069+
2070+
lookupWithExpressionLess
2071+
required_capability: join_lookup_v12
2072+
required_capability: lookup_join_on_expression
2073+
2074+
FROM employees
2075+
| WHERE emp_no == 10001
2076+
| LOOKUP JOIN languages_lookup ON languages < language_code
2077+
| KEEP emp_no, languages, language_code, language_name
2078+
| SORT language_code
2079+
;
2080+
2081+
emp_no:integer | languages:integer | language_code:integer | language_name:keyword
2082+
10001 | 2 | 3 | Spanish
2083+
10001 | 2 | 4 | German
2084+
;
2085+
2086+
lookupWithExpressionLessOrEquals
2087+
required_capability: join_lookup_v12
2088+
required_capability: lookup_join_on_expression
2089+
2090+
FROM employees
2091+
| WHERE emp_no == 10001
2092+
| LOOKUP JOIN languages_lookup ON languages <= language_code
2093+
| KEEP emp_no, languages, language_code, language_name
2094+
| SORT language_code
2095+
;
2096+
2097+
emp_no:integer | languages:integer | language_code:integer | language_name:keyword
2098+
10001 | 2 | 2 | French
2099+
10001 | 2 | 3 | Spanish
2100+
10001 | 2 | 4 | German
2101+
;
2102+
20092103
byteJoinDouble
20102104
required_capability: join_lookup_v12
20112105
required_capability: lookup_join_on_mixed_numeric_fields

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;
8686
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DateTimeArithmeticOperation;
8787
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation;
88+
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
8889
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
8990
import org.elasticsearch.xpack.esql.index.EsIndex;
9091
import org.elasticsearch.xpack.esql.index.IndexResolution;
@@ -704,10 +705,25 @@ private LogicalPlan resolveLookup(Lookup l, List<Attribute> childrenOutput) {
704705
return l;
705706
}
706707

708+
private List<Expression> resolveJoinFilters(List<Expression> filters, List<Attribute> leftOutput, List<Attribute> rightOutput) {
709+
if (filters.isEmpty()) {
710+
return emptyList();
711+
}
712+
List<Attribute> childrenOutput = new ArrayList<>(leftOutput);
713+
childrenOutput.addAll(rightOutput);
714+
715+
List<Expression> resolvedFilters = new ArrayList<>(filters.size());
716+
for (Expression filter : filters) {
717+
resolvedFilters.add(filter.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput)));
718+
}
719+
return resolvedFilters;
720+
}
721+
707722
private Join resolveLookupJoin(LookupJoin join) {
708723
JoinConfig config = join.config();
709724
// for now, support only (LEFT) USING clauses
710725
JoinType type = config.type();
726+
711727
// rewrite the join into an equi-join between the field with the same name between left and right
712728
if (type instanceof UsingJoinType using) {
713729
List<Attribute> cols = using.columns();
@@ -727,12 +743,33 @@ private Join resolveLookupJoin(LookupJoin join) {
727743
);
728744
return join.withConfig(new JoinConfig(type, singletonList(errorAttribute), emptyList(), emptyList()));
729745
}
730-
// resolve the using columns against the left and the right side then assemble the new join config
731-
List<Attribute> leftKeys = resolveUsingColumns(cols, join.left().output(), "left");
732-
List<Attribute> rightKeys = resolveUsingColumns(cols, join.right().output(), "right");
746+
List<Attribute> leftKeys = new ArrayList<>();
747+
List<Attribute> rightKeys = new ArrayList<>();
748+
List<Expression> resolvedFilters = new ArrayList<>();
749+
List<Attribute> matchKeys;
750+
if (join.candidateRightHandFilters().isEmpty() == false) {
751+
resolvedFilters = resolveJoinFilters(join.candidateRightHandFilters(), join.left().output(), join.right().output());
752+
// build leftKeys and rightKeys using the left side of the resolvedFilters.
753+
for (Expression expression : resolvedFilters) {
754+
if (expression instanceof EsqlBinaryComparison binaryComparison) {
755+
leftKeys.add((Attribute) binaryComparison.left());
756+
rightKeys.add((Attribute) binaryComparison.right());
757+
} else {
758+
throw new EsqlIllegalArgumentException("Unsupported join filter expression: " + expression);
759+
}
760+
}
761+
Set<Attribute> matchKeysSet = new HashSet<>(leftKeys);
762+
matchKeysSet.addAll(rightKeys);
763+
matchKeys = new ArrayList<>(matchKeysSet);
764+
} else {
765+
// resolve the using columns against the left and the right side then assemble the new join config
766+
leftKeys = resolveUsingColumns(cols, join.left().output(), "left");
767+
rightKeys = resolveUsingColumns(cols, join.right().output(), "right");
768+
matchKeys = leftKeys;
769+
}
733770

734-
config = new JoinConfig(coreJoin, leftKeys, leftKeys, rightKeys);
735-
join = new LookupJoin(join.source(), join.left(), join.right(), config, join.isRemote());
771+
config = new JoinConfig(coreJoin, matchKeys, leftKeys, rightKeys);
772+
return new LookupJoin(join.source(), join.left(), join.right(), config, join.isRemote(), resolvedFilters);
736773
} else if (type != JoinTypes.LEFT) {
737774
// everything else is unsupported for now
738775
// LEFT can only happen by being mapped from a USING above. So we need to exclude this as well because this rule can be run
@@ -741,6 +778,7 @@ private Join resolveLookupJoin(LookupJoin join) {
741778
// add error message
742779
return join.withConfig(new JoinConfig(type, singletonList(errorAttribute), emptyList(), emptyList()));
743780
}
781+
744782
return join;
745783
}
746784

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,10 @@ abstract static class TransportRequest extends AbstractTransportRequest implemen
558558
this.source = source;
559559
}
560560

561+
public Page getInputPage() {
562+
return inputPage;
563+
}
564+
561565
@Override
562566
public final String[] indices() {
563567
return new String[] { indexPattern };
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.enrich;
9+
10+
import org.apache.lucene.search.Query;
11+
import org.elasticsearch.cluster.service.ClusterService;
12+
import org.elasticsearch.compute.data.Block;
13+
import org.elasticsearch.compute.operator.Warnings;
14+
import org.elasticsearch.compute.operator.lookup.QueryList;
15+
import org.elasticsearch.index.mapper.MappedFieldType;
16+
import org.elasticsearch.index.query.SearchExecutionContext;
17+
import org.elasticsearch.xpack.esql.core.expression.Literal;
18+
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
19+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
20+
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
21+
import org.elasticsearch.xpack.esql.plugin.EsqlFlags;
22+
import org.elasticsearch.xpack.esql.stats.SearchContextStats;
23+
24+
import java.io.IOException;
25+
import java.util.List;
26+
import java.util.function.IntFunction;
27+
28+
public class BinaryComparisonQueryList extends QueryList {
29+
private final EsqlBinaryComparison binaryComparison;
30+
private final IntFunction<Object> blockValueReader;
31+
private final ClusterService clusterService;
32+
private final SearchExecutionContext searchExecutionContext;
33+
34+
public BinaryComparisonQueryList(
35+
MappedFieldType field,
36+
SearchExecutionContext searchExecutionContext,
37+
Block block,
38+
EsqlBinaryComparison binaryComparison,
39+
ClusterService clusterService
40+
) {
41+
super(field, searchExecutionContext, null, block, null);
42+
this.binaryComparison = binaryComparison;
43+
this.blockValueReader = QueryList.createBlockValueReader(block);
44+
this.clusterService = clusterService;
45+
this.searchExecutionContext = searchExecutionContext;
46+
}
47+
48+
@Override
49+
public QueryList onlySingleValues(Warnings warnings, String multiValueWarningMessage) {
50+
throw new UnsupportedOperationException();
51+
}
52+
53+
@Override
54+
public Query doGetQuery(int position, int firstValueIndex, int valueCount) throws IOException {
55+
if (valueCount == 0) {
56+
return null;
57+
}
58+
Object value = blockValueReader.apply(firstValueIndex);
59+
// create a new comparison with the value from the block as a literal
60+
EsqlBinaryComparison swapped = (EsqlBinaryComparison) binaryComparison.swapLeftAndRight();
61+
EsqlBinaryComparison comparison = swapped.getFunctionType()
62+
.buildNewInstance(swapped.source(), swapped.left(), new Literal(swapped.right().source(), value, swapped.right().dataType()));
63+
LucenePushdownPredicates lucenePushdownPredicates = LucenePushdownPredicates.from(
64+
SearchContextStats.from(List.of(searchExecutionContext)),
65+
new EsqlFlags(clusterService.getClusterSettings())
66+
);
67+
return comparison.asQuery(lucenePushdownPredicates, TranslatorHandler.TRANSLATOR_HANDLER)
68+
.toQueryBuilder()
69+
.toQuery(searchExecutionContext);
70+
}
71+
}

0 commit comments

Comments
 (0)