diff --git a/docs/changelog/132167.yaml b/docs/changelog/132167.yaml new file mode 100644 index 0000000000000..ae3c724ecfa17 --- /dev/null +++ b/docs/changelog/132167.yaml @@ -0,0 +1,6 @@ +pr: 132167 +summary: Deal with internally created IN in a different way for EQL +area: EQL +type: bug +issues: + - 118621 diff --git a/x-pack/plugin/eql/qa/common/build.gradle b/x-pack/plugin/eql/qa/common/build.gradle index 5fe6e54a440a8..04cfb01f9376f 100644 --- a/x-pack/plugin/eql/qa/common/build.gradle +++ b/x-pack/plugin/eql/qa/common/build.gradle @@ -8,3 +8,10 @@ dependencies { // TOML parser for EqlActionIT tests api 'io.ous:jtoml:2.0.0' } + +tasks.register("loadTestData", JavaExec) { + group = "Execution" + description = "Loads EQL Spec Tests data on a running stand-alone instance" + classpath = sourceSets.main.runtimeClasspath + mainClass = "org.elasticsearch.test.eql.DataLoader" +} diff --git a/x-pack/plugin/eql/qa/common/src/main/java/org/elasticsearch/test/eql/DataLoader.java b/x-pack/plugin/eql/qa/common/src/main/java/org/elasticsearch/test/eql/DataLoader.java index 4618bd8f4ff3d..2794f514777a5 100644 --- a/x-pack/plugin/eql/qa/common/src/main/java/org/elasticsearch/test/eql/DataLoader.java +++ b/x-pack/plugin/eql/qa/common/src/main/java/org/elasticsearch/test/eql/DataLoader.java @@ -76,39 +76,60 @@ private static Map getReplacementPatterns() { public static void main(String[] args) throws IOException { main = true; try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200)).build()) { - loadDatasetIntoEs(client, DataLoader::createParser); + loadDatasetIntoEsWithIndexCreator(client, DataLoader::createParser, (restClient, indexName, indexMapping) -> { + // don't use ESRestTestCase methods here or, if you do, test running the main method before making the change + StringBuilder jsonBody = new StringBuilder("{"); + jsonBody.append("\"settings\":{\"number_of_shards\":1},"); + jsonBody.append("\"mappings\":"); + jsonBody.append(indexMapping); + 
jsonBody.append("}"); + + Request request = new Request("PUT", "/" + indexName); + request.setJsonEntity(jsonBody.toString()); + restClient.performRequest(request); + }); } } public static void loadDatasetIntoEs(RestClient client, CheckedBiFunction p) throws IOException { + loadDatasetIntoEsWithIndexCreator(client, p, (restClient, indexName, indexMapping) -> { + ESRestTestCase.createIndex(restClient, indexName, Settings.builder().put("number_of_shards", 1).build(), indexMapping, null); + }); + } + + private static void loadDatasetIntoEsWithIndexCreator( + RestClient client, + CheckedBiFunction p, + IndexCreator indexCreator + ) throws IOException { // // Main Index // - load(client, TEST_INDEX, null, DataLoader::timestampToUnixMillis, p); + load(client, TEST_INDEX, null, DataLoader::timestampToUnixMillis, p, indexCreator); // // Aux Index // - load(client, TEST_EXTRA_INDEX, null, null, p); + load(client, TEST_EXTRA_INDEX, null, null, p, indexCreator); // // Date_Nanos index // // The data for this index is loaded from the same endgame-140.data sample, only having the mapping for @timestamp changed: the // chosen Windows filetime timestamps (2017+) can coincidentally also be readily used as nano-resolution unix timestamps (1973+). // There are mixed values with and without nanos precision so that the filtering is properly tested for both cases. 
- load(client, TEST_NANOS_INDEX, TEST_INDEX, DataLoader::timestampToUnixNanos, p); - load(client, TEST_SAMPLE, null, null, p); + load(client, TEST_NANOS_INDEX, TEST_INDEX, DataLoader::timestampToUnixNanos, p, indexCreator); + load(client, TEST_SAMPLE, null, null, p, indexCreator); // // missing_events index // - load(client, TEST_MISSING_EVENTS_INDEX, null, null, p); - load(client, TEST_SAMPLE_MULTI, null, null, p); + load(client, TEST_MISSING_EVENTS_INDEX, null, null, p, indexCreator); + load(client, TEST_SAMPLE_MULTI, null, null, p, indexCreator); // // index with a runtime field ("broken", type long) that causes shard failures. // the rest of the mapping is the same as TEST_INDEX // - load(client, TEST_SHARD_FAILURES_INDEX, null, DataLoader::timestampToUnixMillis, p); + load(client, TEST_SHARD_FAILURES_INDEX, null, DataLoader::timestampToUnixMillis, p, indexCreator); } private static void load( @@ -116,7 +137,8 @@ private static void load( String indexNames, String dataName, Consumer> datasetTransform, - CheckedBiFunction p + CheckedBiFunction p, + IndexCreator indexCreator ) throws IOException { String[] splitNames = indexNames.split(","); for (String indexName : splitNames) { @@ -130,15 +152,11 @@ private static void load( if (data == null) { throw new IllegalArgumentException("Cannot find resource " + name); } - createTestIndex(client, indexName, readMapping(mapping)); + indexCreator.createIndex(client, indexName, readMapping(mapping)); loadData(client, indexName, datasetTransform, data, p); } } - private static void createTestIndex(RestClient client, String indexName, String mapping) throws IOException { - ESRestTestCase.createIndex(client, indexName, Settings.builder().put("number_of_shards", 1).build(), mapping, null); - } - /** * Reads the mapping file, ignoring comments and replacing placeholders for random types. 
*/ @@ -236,4 +254,8 @@ private static XContentParser createParser(XContent xContent, InputStream data) NamedXContentRegistry contentRegistry = new NamedXContentRegistry(ClusterModule.getNamedXWriteables()); return xContent.createParser(contentRegistry, LoggingDeprecationHandler.INSTANCE, data); } + + private interface IndexCreator { + void createIndex(RestClient client, String indexName, String mapping) throws IOException; + } } diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java index 5bccf013bc789..17f095db5033e 100644 --- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java +++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java @@ -44,7 +44,6 @@ import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BinaryComparisonSimplification; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanFunctionEqualsElimination; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanSimplification; -import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.CombineDisjunctionsToIn; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ConstantFolding; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.LiteralsOnTheRight; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerRule; @@ -252,6 +251,14 @@ protected Expression maybeSimplifyNegatable(Expression e) { } + static class CombineDisjunctionsToIn extends org.elasticsearch.xpack.ql.optimizer.OptimizerRules.CombineDisjunctionsToIn { + + @Override + protected boolean shouldValidateIn() { + return true; + } + } + static class PruneFilters extends org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneFilters { @Override diff --git a/x-pack/plugin/eql/src/test/resources/querytranslator_tests.txt b/x-pack/plugin/eql/src/test/resources/querytranslator_tests.txt index 
00c08096fd084..84ad63b964981 100644 --- a/x-pack/plugin/eql/src/test/resources/querytranslator_tests.txt +++ b/x-pack/plugin/eql/src/test/resources/querytranslator_tests.txt @@ -123,6 +123,90 @@ process where process_name in ("python.exe", "SMSS.exe", "explorer.exe") "terms":{"process_name":["python.exe","SMSS.exe","explorer.exe"], ; +mutipleOrEquals_As_InTranslation1 +process where process_name == "python.exe" or process_name == "SMSS.exe" or process_name == "explorer.exe" +; +"terms":{"process_name":["python.exe","SMSS.exe","explorer.exe"], +; + +multipleOrAndEquals_As_InTranslation +process where process_name == "python.exe" and process_name == "SMSS.exe" or process_name == "explorer.exe" or process_name == "test.exe" +; +{"bool":{"should":[{"bool":{"must":[{"term":{"process_name":{"value":"python.exe"}}},{"term":{"process_name":{"value":"SMSS.exe"}}}],"boost":1.0}},{"terms":{"process_name":["explorer.exe","test.exe"],"boost":1.0}}],"boost":1.0}} +; + +mutipleOrEquals_As_InTranslation2 +process where source_address == "123.12.1.1" or (opcode == 123 or opcode == 127) +; +{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"terms":{"opcode":[123,127],"boost":1.0}}],"boost":1.0}} +; + +mutipleOrEquals_As_InTranslation3 +process where (source_address == "123.12.1.1" or source_address == "127.0.0.1") and (opcode == 123 or opcode == 127) +; +{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"term":{"source_address":{"value":"127.0.0.1"}}}],"boost":1.0}},{"terms":{"opcode":[123,127],"boost":1.0}} +; + +mutipleOrEquals_As_InTranslation4 +process where (source_address == "123.12.1.1" or source_address == "127.0.0.1") and (opcode == 123 or opcode == 127) +; +"must":[{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"term":{"source_address":{"value":"127.0.0.1"}}}],"boost":1.0}},{"terms":{"opcode":[123,127],"boost":1.0}},{"term":{"event.category":{"value":"process"}}}] +; + +multipleOrIncompatibleTypes1 +process 
where process_name == "python.exe" or process_name == 2 or process_name == "3" +; +{"bool":{"should":[{"term":{"process_name":{"value":"python.exe"}}},{"term":{"process_name":{"value":2}}},{"term":{"process_name":{"value":"3"}}}],"boost":1.0}} +; + +multipleOrIncompatibleTypes2 +process where process_name == "1" or process_name == 2 or process_name == "3" +; +{"bool":{"should":[{"term":{"process_name":{"value":"1"}}},{"term":{"process_name":{"value":2}}},{"term":{"process_name":{"value":"3"}}}],"boost":1.0}} +; + +multipleOrIncompatibleTypes3 +process where process_name == 1.2 or process_name == 2 or process_name == "3" +; +{"bool":{"should":[{"term":{"process_name":{"value":1.2}}},{"term":{"process_name":{"value":2}}},{"term":{"process_name":{"value":"3"}}}],"boost":1.0}} +; + +// this query, when written in the equivalent form +// process where process_name in (1.2, 2, 3) +// will result in a user error: 1st argument of [process_name in (1.2, 2, 3)] must be [keyword], found value [1.2] type [double] +multipleOrIncompatibleTypes4 +process where process_name == 1.2 or process_name == 2 or process_name == 3 +; +{"bool":{"should":[{"term":{"process_name":{"value":1.2}}},{"term":{"process_name":{"value":2}}},{"term":{"process_name":{"value":3}}}],"boost":1.0}} +; + +// this query, when written in the equivalent form +// process where source_address in ("123.12.1.1", "123.12.1.2") +// will result in a user error: 1st argument of [source_address in ("123.12.1.1", "123.12.1.2")] must be [ip], found value ["123.12.1.1"] type [keyword] +multipleOrIncompatibleTypes5 +process where source_address == "123.12.1.1" or source_address == "123.12.1.2" +; +{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"term":{"source_address":{"value":"123.12.1.2"}}}],"boost":1.0}} +; + +multipleOrIncompatibleTypes6 +process where source_address == "123.12.1.1" or source_address == concat("123.12.","1.2") +; 
+{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"term":{"source_address":{"value":"123.12.1.2"}}}],"boost":1.0}} +; + +multipleOrIncompatibleTypes7 +process where source_address == "123.12.1.1" and (source_address == "123.12.1.2" or source_address >= "127.0.0.1") +; +"must":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.2"}}},{"range":{"source_address":{"gte":"127.0.0.1","boost":1.0}}}],"boost":1.0}},{"term":{"event.category":{"value":"process"}}}] +; + +multipleOrIncompatibleTypes8 +process where source_address == "123.12.1.1" and (source_address == "123.12.1.2" or source_address == "127.0.0.1") +; +"must":[{"term":{"source_address":{"value":"123.12.1.1"}}},{"bool":{"should":[{"term":{"source_address":{"value":"123.12.1.2"}}},{"term":{"source_address":{"value":"127.0.0.1"}}}],"boost":1.0}},{"term":{"event.category":{"value":"process"}}}] +; + inFilterWithScripting process where substring(command_line, 5) in ("test*","best") ; diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/In.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/In.java index f12f8edb71795..0398288960dc1 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/In.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/In.java @@ -177,6 +177,10 @@ protected TypeResolution resolveType() { return super.resolveType(); } + public TypeResolution validateInTypes() { + return resolveType(); + } + @Override public int hashCode() { return Objects.hash(value, list); diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java index 7625cbf3a56e5..a5e1d3bdbe620 100644 
--- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java @@ -1203,8 +1203,8 @@ private static boolean notEqualsIsRemovableFromConjunction(NotEquals notEquals, * 2. a == 1 OR a IN (2) becomes a IN (1, 2) * 3. a IN (1) OR a IN (2) becomes a IN (1, 2) * - * This rule does NOT check for type compatibility as that phase has been - * already be verified in the analyzer. + * By default (see {@link #shouldValidateIn()}), this rule does NOT check for type compatibility as that phase has + * already been verified in the analyzer, but this behavior can be changed by subclasses. */ public static class CombineDisjunctionsToIn extends OptimizerExpressionRule { public CombineDisjunctionsToIn() { @@ -1214,18 +1214,24 @@ public CombineDisjunctionsToIn() { @Override protected Expression rule(Or or) { Expression e = or; - // look only at equals and In + // look only at Equals and In List exps = splitOr(e); Map> found = new LinkedHashMap<>(); + Map> originalOrs = new LinkedHashMap<>(); ZoneId zoneId = null; List ors = new LinkedList<>(); for (Expression exp : exps) { if (exp instanceof Equals eq) { - // consider only equals against foldables + // consider only Equals against foldables if (eq.right().foldable()) { found.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(eq.right()); + if (shouldValidateIn()) { + // in case there is an optimized In being built and its validation fails, rebuild the original ORs + // so, keep around the original Expressions + originalOrs.computeIfAbsent(eq.left(), k -> new ArrayList<>()).add(eq); + } } else { ors.add(exp); } @@ -1234,6 +1240,11 @@ protected Expression rule(Or or) { } } else if (exp instanceof In in) { found.computeIfAbsent(in.value(), k -> new LinkedHashSet<>()).addAll(in.list()); + if (shouldValidateIn()) { + // in case there is an optimized In being built and its validation fails, rebuild the 
original ORs + // so, keep around the original Expressions + originalOrs.computeIfAbsent(in.value(), k -> new ArrayList<>()).add(in); + } if (zoneId == null) { zoneId = in.zoneId(); } @@ -1243,11 +1254,31 @@ protected Expression rule(Or or) { } if (found.isEmpty() == false) { - // combine equals alongside the existing ors + // combine Equals alongside the existing ORs final ZoneId finalZoneId = zoneId; - found.forEach( - (k, v) -> { ors.add(v.size() == 1 ? createEquals(k, v, finalZoneId) : createIn(k, new ArrayList<>(v), finalZoneId)); } - ); + found.forEach((k, v) -> { + if (v.size() == 1) { + ors.add(createEquals(k, v.iterator().next(), finalZoneId)); + } else { + In in = createIn(k, new ArrayList<>(v), finalZoneId); + // IN has its own particularities when it comes to type resolution and not all implementations + // double check the validity of an internally created IN (like the one created here). EQL is one where the IN + // implementation is like this, and the mechanism here has been specifically created for it + if (shouldValidateIn()) { + Expression.TypeResolution resolution = in.validateInTypes(); + if (resolution.unresolved()) { + // if the internally created In is not valid, fall back to the original ORs + assert originalOrs.containsKey(k); + assert originalOrs.get(k).isEmpty() == false; + ors.add(combineOr(originalOrs.get(k))); + } else { + ors.add(in); + } + } else { + ors.add(in); + } + } + }); Expression combineOr = combineOr(ors); // check the result semantically since the result might different in order @@ -1261,13 +1292,17 @@ protected Expression rule(Or or) { return e; } - protected Equals createEquals(Expression k, Set v, ZoneId finalZoneId) { - return new Equals(k.source(), k, v.iterator().next(), finalZoneId); - } - protected In createIn(Expression key, List values, ZoneId zoneId) { return new In(key.source(), key, values, zoneId); } + + protected boolean shouldValidateIn() { + return false; + } + + private Equals createEquals(Expression key, Expression 
value, ZoneId finalZoneId) { + return new Equals(key.source(), key, value, finalZoneId); + } } public static class PushDownAndCombineFilters extends OptimizerRule { diff --git a/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java index bc7e0b2a93bf5..9e60a60c4c3d3 100644 --- a/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java @@ -66,6 +66,8 @@ import java.time.ZoneId; import java.util.Collections; import java.util.List; +import java.util.Set; +import java.util.function.Consumer; import static java.util.Arrays.asList; import static java.util.Collections.emptyList; @@ -105,6 +107,9 @@ public class OptimizerRulesTests extends ESTestCase { private static final Literal FOUR = L(4); private static final Literal FIVE = L(5); private static final Literal SIX = L(6); + private static final Literal TEXT_A = L("A"); + private static final Literal TEXT_B = L("B"); + private static final Literal TEXT_C = L("C"); public static class DummyBooleanExpression extends Expression { @@ -1491,48 +1496,71 @@ public void testExactMatchRLike() throws Exception { // // CombineDisjunction in Equals // + + // CombineDisjunctionsToIn with shouldValidateIn as true + private final class ValidateableCombineDisjunctionsToIn extends CombineDisjunctionsToIn { + @Override + protected boolean shouldValidateIn() { + return true; + } + }; + + private void assertCombineDisjunctionsToIn(Consumer tester) { + for (CombineDisjunctionsToIn rule : Set.of(new CombineDisjunctionsToIn(), new ValidateableCombineDisjunctionsToIn())) { + tester.accept(rule); + } + } + public void testTwoEqualsWithOr() throws Exception { FieldAttribute fa = getFieldAttribute(); Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); - Expression e = new 
CombineDisjunctionsToIn().rule(or); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO)); + assertCombineDisjunctionsToIn((rule) -> { + Expression e = rule.rule(or); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO)); + }); } public void testTwoEqualsWithSameValue() throws Exception { FieldAttribute fa = getFieldAttribute(); Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, ONE)); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(Equals.class, e.getClass()); - Equals eq = (Equals) e; - assertEquals(fa, eq.left()); - assertEquals(ONE, eq.right()); + assertCombineDisjunctionsToIn((rule) -> { + Expression e = rule.rule(or); + assertEquals(Equals.class, e.getClass()); + Equals eq = (Equals) e; + assertEquals(fa, eq.left()); + assertEquals(ONE, eq.right()); + }); } public void testOneEqualsOneIn() throws Exception { FieldAttribute fa = getFieldAttribute(); Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, singletonList(TWO))); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO)); + assertCombineDisjunctionsToIn((rule) -> { + Expression e = rule.rule(or); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO)); + }); } public void testOneEqualsOneInWithSameValue() throws Exception { FieldAttribute fa = getFieldAttribute(); Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, asList(ONE, TWO))); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO)); + assertCombineDisjunctionsToIn((rule) -> { + Expression e = rule.rule(or); 
+ assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO)); + }); } public void testSingleValueInToEquals() throws Exception { @@ -1540,8 +1568,10 @@ public void testSingleValueInToEquals() throws Exception { Equals equals = equalsOf(fa, ONE); Or or = new Or(EMPTY, equals, new In(EMPTY, fa, singletonList(ONE))); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(equals, e); + assertCombineDisjunctionsToIn((rule) -> { + Expression e = rule.rule(or); + assertEquals(equals, e); + }); } public void testEqualsBehindAnd() throws Exception { @@ -1549,9 +1579,11 @@ public void testEqualsBehindAnd() throws Exception { And and = new And(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); Filter dummy = new Filter(EMPTY, relation(), and); - LogicalPlan transformed = new CombineDisjunctionsToIn().apply(dummy); - assertSame(dummy, transformed); - assertEquals(and, ((Filter) transformed).condition()); + assertCombineDisjunctionsToIn((rule) -> { + LogicalPlan transformed = rule.apply(dummy); + assertSame(dummy, transformed); + assertEquals(and, ((Filter) transformed).condition()); + }); } public void testTwoEqualsDifferentFields() throws Exception { @@ -1559,8 +1591,10 @@ public void testTwoEqualsDifferentFields() throws Exception { FieldAttribute fieldTwo = TestUtils.getFieldAttribute("TWO"); Or or = new Or(EMPTY, equalsOf(fieldOne, ONE), equalsOf(fieldTwo, TWO)); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(or, e); + assertCombineDisjunctionsToIn((rule) -> { + Expression e = rule.rule(or); + assertEquals(or, e); + }); } public void testMultipleIn() throws Exception { @@ -1568,11 +1602,13 @@ public void testMultipleIn() throws Exception { Or firstOr = new Or(EMPTY, new In(EMPTY, fa, singletonList(ONE)), new In(EMPTY, fa, singletonList(TWO))); Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, singletonList(THREE))); - Expression e = new 
CombineDisjunctionsToIn().rule(secondOr); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO, THREE)); + assertCombineDisjunctionsToIn((rule) -> { + Expression e = rule.rule(secondOr); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO, THREE)); + }); } public void testOrWithNonCombinableExpressions() throws Exception { @@ -1580,14 +1616,159 @@ public void testOrWithNonCombinableExpressions() throws Exception { Or firstOr = new Or(EMPTY, new In(EMPTY, fa, singletonList(ONE)), lessThanOf(fa, TWO)); Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, singletonList(THREE))); - Expression e = new CombineDisjunctionsToIn().rule(secondOr); + assertCombineDisjunctionsToIn((rule) -> { + Expression e = rule.rule(secondOr); + assertEquals(Or.class, e.getClass()); + Or or = (Or) e; + assertEquals(or.left(), firstOr.right()); + assertEquals(In.class, or.right().getClass()); + In in = (In) or.right(); + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, THREE)); + }); + } + + public void testDontCombineSimpleDifferentTypes() throws Exception { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, new Equals(EMPTY, fa, ONE), new Equals(EMPTY, fa, TEXT_A)); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + public void testDontCombineDifferentTypes() throws Exception { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, new Equals(EMPTY, fa, ONE), new Equals(EMPTY, fa, TEXT_A)); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + // See https://github.com/elastic/elasticsearch/issues/118621 + public void testDontCombineStringTypesForIPField() throws Exception { + FieldAttribute fa = TestUtils.getFieldAttribute("ip", DataTypes.IP); + + Or or = new Or(EMPTY, new Equals(EMPTY, 
fa, TEXT_A), new Equals(EMPTY, fa, TEXT_B)); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + public void testDontCombineForIncompatibleFieldType() throws Exception { + FieldAttribute fa = TestUtils.getFieldAttribute("boolean", BOOLEAN); + + Or or = new Or(EMPTY, new Equals(EMPTY, fa, ONE), new Equals(EMPTY, fa, TWO)); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + public void testDontCombineTwoCompatibleAndOneIncompatible() throws Exception { + FieldAttribute fa = getFieldAttribute(); + + Or firstOr = new Or(EMPTY, new Equals(EMPTY, fa, ONE), new Equals(EMPTY, fa, TWO)); + Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, fa, TEXT_A)); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(secondOr); + assertEquals(secondOr, e); + } + + public void testDontCombineOneIncompatibleEqualsWithCompatibleIn() throws Exception { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE, TWO)), new Equals(EMPTY, fa, TEXT_A)); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + public void testDontCombineTwoIncompatibleIns1() throws Exception { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE, TWO)), new In(EMPTY, fa, List.of(TEXT_A, TEXT_B, TEXT_C))); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + public void testDontCombineTwoIncompatibleIns2() throws Exception { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), new In(EMPTY, fa, List.of(TEXT_A))); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + public void testDontCombineTwoIncompatibleIns3() throws Exception { + FieldAttribute fa1 = TestUtils.getFieldAttribute("field1"); + FieldAttribute fa2 = 
TestUtils.getFieldAttribute("field2"); + + Or or = new Or(EMPTY, new In(EMPTY, fa1, List.of(ONE, TWO)), new In(EMPTY, fa2, List.of(THREE, FOUR))); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + public void testDontCombineIncompatibleInWithTwoCompatibleEquals() throws Exception { + FieldAttribute fa = getFieldAttribute(); + + Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(TEXT_A, TEXT_B)), new Equals(EMPTY, fa, THREE)); + Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, fa, FOUR)); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(secondOr); + assertEquals(secondOr, e); + } + + public void testCombineOnlyEqualsExpressions() throws Exception { + FieldAttribute faIn = TestUtils.getFieldAttribute("field_for_in"); + FieldAttribute faEquals = TestUtils.getFieldAttribute("field_for_equals"); + + Or firstOr = new Or(EMPTY, new In(EMPTY, faIn, List.of(ONE, TWO)), new Equals(EMPTY, faEquals, THREE)); + Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, faEquals, FOUR)); + Expression e = new ValidateableCombineDisjunctionsToIn().rule(secondOr); assertEquals(Or.class, e.getClass()); Or or = (Or) e; - assertEquals(or.left(), firstOr.right()); + assertEquals(or.left(), firstOr.left()); assertEquals(In.class, or.right().getClass()); In in = (In) or.right(); - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, THREE)); + assertEquals(faEquals, in.value()); + assertThat(in.list(), contains(THREE, FOUR)); + } + + public void testCombineOnlyCompatibleEqualsExpressions() throws Exception { + FieldAttribute faEquals1 = TestUtils.getFieldAttribute("field_for_equals1"); + FieldAttribute faEquals2 = TestUtils.getFieldAttribute("field_for_equals2"); + + Equals equalsA = new Equals(EMPTY, faEquals2, TEXT_A); + Equals equalsB = new Equals(EMPTY, faEquals2, TEXT_B); + Or firstOr = new Or(EMPTY, new Equals(EMPTY, faEquals1, ONE), equalsA); + Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, 
faEquals1, TWO)); + Or thirdOr = new Or(EMPTY, secondOr, equalsB); + + Expression e = new ValidateableCombineDisjunctionsToIn().rule(thirdOr); + assertEquals(Or.class, e.getClass()); + Or or = (Or) e; + assertEquals(In.class, or.left().getClass()); + In in = (In) or.left(); + assertThat(in.list(), contains(ONE, TWO)); + + assertEquals(Or.class, or.right().getClass()); + or = (Or) or.right(); + assertEquals(or.left(), equalsA); + assertEquals(or.right(), equalsB); + } + + public void testCombineTwoCompatiblePairsOrEqualsExpressions() throws Exception { + FieldAttribute faEquals1 = TestUtils.getFieldAttribute("field_for_equals1"); + FieldAttribute faEquals2 = TestUtils.getFieldAttribute("field_for_equals2"); + + Or firstOr = new Or(EMPTY, new Equals(EMPTY, faEquals1, ONE), new Equals(EMPTY, faEquals2, THREE)); + Or secondOr = new Or(EMPTY, firstOr, new Equals(EMPTY, faEquals1, TWO)); + Or thirdOr = new Or(EMPTY, secondOr, new Equals(EMPTY, faEquals2, FOUR)); + + Expression e = new ValidateableCombineDisjunctionsToIn().rule(thirdOr); + assertEquals(Or.class, e.getClass()); + Or or = (Or) e; + assertEquals(In.class, or.left().getClass()); + In in = (In) or.left(); + assertThat(in.list(), contains(ONE, TWO)); + + assertEquals(In.class, or.right().getClass()); + in = (In) or.right(); + assertThat(in.list(), contains(THREE, FOUR)); } // Null folding