Skip to content

Commit fa34213

Browse files
committed
enable score in WHERE with FTFs, more tests
1 parent d4c2392 commit fa34213

File tree

4 files changed

+281
-145
lines changed

4 files changed

+281
-145
lines changed

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/ScoreIT.java renamed to x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/ScoreFunctionIT.java

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@
1212
import org.elasticsearch.common.settings.Settings;
1313
import org.elasticsearch.common.util.CollectionUtils;
1414
import org.elasticsearch.plugins.Plugin;
15+
import org.elasticsearch.xpack.esql.VerificationException;
1516
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
1617
import org.elasticsearch.xpack.kql.KqlPlugin;
1718
import org.junit.Before;
19+
import org.junit.Ignore;
1820

1921
import java.util.Collection;
2022
import java.util.List;
2123

2224
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
25+
import static org.hamcrest.CoreMatchers.containsString;
2326

2427
//@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
25-
public class ScoreIT extends AbstractEsqlIntegTestCase {
28+
public class ScoreFunctionIT extends AbstractEsqlIntegTestCase {
2629

2730
@Before
2831
public void setupIndex() {
@@ -48,6 +51,136 @@ public void testScoreDifferentWhereMatch() {
4851
}
4952
}
5053

54+
public void testScoreDifferentWhereMatchNoMetadata() {
55+
var query = """
56+
FROM test
57+
| EVAL first_score = score(match(content, "brown"))
58+
| WHERE match(content, "fox")
59+
| KEEP id, first_score
60+
| SORT id
61+
""";
62+
63+
try (var resp = run(query)) {
64+
assertColumnNames(resp.columns(), List.of("id", "first_score"));
65+
assertColumnTypes(resp.columns(), List.of("integer", "double"));
66+
assertValues(resp.values(), List.of(List.of(1, 0.2708943784236908), List.of(6, 0.21347221732139587)));
67+
}
68+
}
69+
70+
public void testScoreInWhereWithMatch() {
71+
var query = """
72+
FROM test
73+
| WHERE score(match(content, "brown"))
74+
""";
75+
76+
var error = expectThrows(VerificationException.class, () -> run(query));
77+
assertThat(error.getMessage(), containsString("Condition expression needs to be boolean, found [DOUBLE]"));
78+
}
79+
80+
public void testScoreInWhereWithFilter() {
81+
var query = """
82+
FROM test
83+
| WHERE score(id > 0)
84+
""";
85+
86+
var error = expectThrows(VerificationException.class, () -> run(query));
87+
assertThat(error.getMessage(), containsString("Condition expression needs to be boolean, found [DOUBLE]"));
88+
}
89+
90+
public void testMatchScoreFilter() {
91+
var query = """
92+
FROM test
93+
| WHERE score(match(content, "brown")) > 0
94+
| KEEP id
95+
| SORT id
96+
""";
97+
98+
try (var resp = run(query)) {
99+
assertColumnNames(resp.columns(), List.of("id"));
100+
assertColumnTypes(resp.columns(), List.of("integer"));
101+
assertValues(resp.values(), List.of(List.of(1), List.of(2), List.of(3), List.of(4), List.of(6)));
102+
}
103+
}
104+
105+
public void testKqlScoreFilter() {
106+
var query = """
107+
FROM test
108+
| WHERE score(kql("brown")) > 0
109+
| KEEP id
110+
| SORT id
111+
""";
112+
113+
try (var resp = run(query)) {
114+
assertColumnNames(resp.columns(), List.of("id"));
115+
assertColumnTypes(resp.columns(), List.of("integer"));
116+
assertValues(resp.values(), List.of(List.of(1), List.of(2), List.of(3), List.of(4), List.of(6)));
117+
}
118+
}
119+
120+
public void testQstrScoreFilter() {
121+
var query = """
122+
FROM test
123+
| WHERE score(qstr("brown")) > 0
124+
| KEEP id
125+
| SORT id
126+
""";
127+
128+
try (var resp = run(query)) {
129+
assertColumnNames(resp.columns(), List.of("id"));
130+
assertColumnTypes(resp.columns(), List.of("integer"));
131+
assertValues(resp.values(), List.of(List.of(1), List.of(2), List.of(3), List.of(4), List.of(6)));
132+
}
133+
}
134+
135+
public void testMultipleFTFScoreFilter() {
136+
var query = """
137+
FROM test
138+
| WHERE score(match(content, "brown")) > 0.4 OR score(match(content, "fox")) > 0.2
139+
| KEEP id
140+
| SORT id
141+
""";
142+
143+
try (var resp = run(query)) {
144+
assertColumnNames(resp.columns(), List.of("id"));
145+
assertColumnTypes(resp.columns(), List.of("integer"));
146+
assertValues(resp.values(), List.of(List.of(1), List.of(6)));
147+
}
148+
}
149+
150+
public void testMultipleHybridScoreFilter() {
151+
var query = """
152+
FROM test
153+
| WHERE score(match(content, "brown")) > 0.2 AND id > 2
154+
| KEEP id
155+
| SORT id
156+
""";
157+
158+
try (var resp = run(query)) {
159+
assertColumnNames(resp.columns(), List.of("id"));
160+
assertColumnTypes(resp.columns(), List.of("integer"));
161+
assertValues(resp.values(), List.of(List.of(3), List.of(6)));
162+
}
163+
}
164+
165+
@Ignore("it's meaningless but it passes o_O")
166+
public void testScoreMeaninglessFunction() {
167+
var query = """
168+
FROM test
169+
| EVAL meaningless = score(abs(-0.1))
170+
| KEEP id, meaningless
171+
| SORT id
172+
""";
173+
174+
try (var resp = run(query)) {
175+
assertColumnNames(resp.columns(), List.of("id", "meaningless"));
176+
assertColumnTypes(resp.columns(), List.of("integer", "double"));
177+
assertValues(
178+
resp.values(),
179+
List.of(List.of(1, 0.0), List.of(2, 0.0), List.of(3, 0.0), List.of(4, 0.0), List.of(5, 0.0), List.of(6, 0.0))
180+
);
181+
}
182+
}
183+
51184
public void testScoreMultipleWhereMatch() {
52185
var query = """
53186
FROM test METADATA _score

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
3535
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
3636
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
37+
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
3738
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3839
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
3940
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
@@ -304,6 +305,8 @@ private static void checkFullTextFunctionsParents(Expression condition, Failures
304305
forEachFullTextFunctionParent(condition, (ftf, parent) -> {
305306
if ((parent instanceof FullTextFunction == false)
306307
&& (parent instanceof BinaryLogic == false)
308+
&& (parent instanceof EsqlBinaryComparison == false)
309+
&& (parent instanceof Score == false) // e.g., WHERE score($ftf) > ...
307310
&& (parent instanceof Not == false)) {
308311
failures.add(
309312
fail(

0 commit comments

Comments
 (0)