Commit c173522

Fix for FORK branches with mixed outputs and unsupported field types (#129636)
1 parent 7f12f80 commit c173522
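
With this change, FORK branches can produce mixed outputs over indices that contain fields with unsupported mapped types. As a quick illustration (the query is reproduced from the new ForkIT tests below, where the test-other index maps embedding as sparse_vector, a type ES|QL does not support):

    FROM test-other
    | FORK
       ( STATS x = count(*))
       ( WHERE id == "2" )
    | SORT _fork

now returns the columns x, _fork, content, embedding, id, with embedding surfaced as type unsupported and filled with nulls, instead of failing.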

File tree: 4 files changed (+152, -19 lines)

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java

Lines changed: 88 additions & 6 deletions
@@ -11,14 +11,14 @@
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.compute.operator.DriverProfile;
-import org.elasticsearch.test.junit.annotations.TestLogging;
 import org.elasticsearch.xpack.esql.VerificationException;
 import org.elasticsearch.xpack.esql.parser.ParsingException;
 import org.junit.Before;

 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
@@ -27,7 +27,7 @@
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList;
 import static org.hamcrest.Matchers.equalTo;

-@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
+// @TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
 public class ForkIT extends AbstractEsqlIntegTestCase {

     @Before
@@ -800,6 +800,79 @@ public void testWithKeep() {
         }
     }

+    public void testWithUnsupportedFieldsWithSameBranches() {
+        var query = """
+            FROM test-other
+            | FORK
+               ( WHERE id == "3")
+               ( WHERE id == "2" )
+            | SORT _fork
+            """;
+
+        try (var resp = run(query)) {
+            assertColumnNames(resp.columns(), List.of("content", "embedding", "id", "_fork"));
+            assertColumnTypes(resp.columns(), List.of("keyword", "unsupported", "keyword", "keyword"));
+            Iterable<Iterable<Object>> expectedValues = List.of(
+                Arrays.stream(new Object[] { "This dog is really brown", null, "3", "fork1" }).toList(),
+                Arrays.stream(new Object[] { "This is a brown dog", null, "2", "fork2" }).toList()
+            );
+            assertValues(resp.values(), expectedValues);
+        }
+    }
+
+    public void testWithUnsupportedFieldsWithDifferentBranches() {
+        var query = """
+            FROM test-other
+            | FORK
+               ( STATS x = count(*))
+               ( WHERE id == "2" )
+            | SORT _fork
+            """;
+
+        try (var resp = run(query)) {
+            assertColumnNames(resp.columns(), List.of("x", "_fork", "content", "embedding", "id"));
+            assertColumnTypes(resp.columns(), List.of("long", "keyword", "keyword", "unsupported", "keyword"));
+            Iterable<Iterable<Object>> expectedValues = List.of(
+                Arrays.stream(new Object[] { 3L, "fork1", null, null, null }).toList(),
+                Arrays.stream(new Object[] { null, "fork2", "This is a brown dog", null, "2" }).toList()
+            );
+            assertValues(resp.values(), expectedValues);
+        }
+    }
+
+    public void testWithUnsupportedFieldsAndConflicts() {
+        var firstQuery = """
+            FROM test-other
+            | FORK
+               ( STATS embedding = count(*))
+               ( WHERE id == "2" )
+            | SORT _fork
+            """;
+        var e = expectThrows(VerificationException.class, () -> run(firstQuery));
+        assertTrue(e.getMessage().contains("Column [embedding] has conflicting data types"));
+
+        var secondQuery = """
+            FROM test-other
+            | FORK
+               ( WHERE id == "2" )
+               ( STATS embedding = count(*))
+            | SORT _fork
+            """;
+        e = expectThrows(VerificationException.class, () -> run(secondQuery));
+        assertTrue(e.getMessage().contains("Column [embedding] has conflicting data types"));
+
+        var thirdQuery = """
+            FROM test-other
+            | FORK
+               ( WHERE id == "2" )
+               ( WHERE id == "3" )
+               ( STATS embedding = count(*))
+            | SORT _fork
+            """;
+        e = expectThrows(VerificationException.class, () -> run(thirdQuery));
+        assertTrue(e.getMessage().contains("Column [embedding] has conflicting data types"));
+    }
+
     public void testWithEvalWithConflictingTypes() {
         var query = """
             FROM test
@@ -976,12 +1049,21 @@ private void createAndPopulateIndices() {

         createRequest = client.prepareCreate(otherTestIndex)
             .setSettings(Settings.builder().put("index.number_of_shards", 1))
-            .setMapping("id", "type=keyword", "content", "type=keyword");
+            .setMapping("id", "type=keyword", "content", "type=keyword", "embedding", "type=sparse_vector");
         assertAcked(createRequest);
         client().prepareBulk()
-            .add(new IndexRequest(otherTestIndex).id("1").source("id", "1", "content", "This is a brown fox"))
-            .add(new IndexRequest(otherTestIndex).id("2").source("id", "2", "content", "This is a brown dog"))
-            .add(new IndexRequest(otherTestIndex).id("3").source("id", "3", "content", "This dog is really brown"))
+            .add(
+                new IndexRequest(otherTestIndex).id("1")
+                    .source("id", "1", "content", "This is a brown fox", "embedding", Map.of("abc", 1.0))
+            )
+            .add(
+                new IndexRequest(otherTestIndex).id("2")
+                    .source("id", "2", "content", "This is a brown dog", "embedding", Map.of("def", 2.0))
+            )
+            .add(
+                new IndexRequest(otherTestIndex).id("3")
+                    .source("id", "3", "content", "This dog is really brown", "embedding", Map.of("ghi", 1.0))
+            )
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
             .get();
         ensureYellow(indexName);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 8 additions & 3 deletions
@@ -760,6 +760,7 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
         List<LogicalPlan> newSubPlans = new ArrayList<>();
         List<Attribute> outputUnion = Fork.outputUnion(fork.children());
         List<String> forkColumns = outputUnion.stream().map(Attribute::name).toList();
+        Set<String> unsupportedAttributeNames = Fork.outputUnsupportedAttributeNames(fork.children());

         for (LogicalPlan logicalPlan : fork.children()) {
             Source source = logicalPlan.source();
@@ -773,7 +774,12 @@
                 }
             }

-            List<Alias> aliases = missing.stream().map(attr -> new Alias(source, attr.name(), Literal.of(attr, null))).toList();
+            List<Alias> aliases = missing.stream().map(attr -> {
+                // We cannot assign an alias with an UNSUPPORTED data type, so we use another type that is
+                // supported. This way we can add this missing column containing only null values to the fork branch output.
+                var attrType = attr.dataType() == UNSUPPORTED ? KEYWORD : attr.dataType();
+                return new Alias(source, attr.name(), new Literal(attr.source(), null, attrType));
+            }).toList();

             // add the missing columns
             if (aliases.size() > 0) {
@@ -785,7 +791,6 @@
             // We need to add an explicit Keep even if the outputs align
             // This is because at the moment the sub plans are executed and optimized separately and the output might change
             // during optimizations. Once we add streaming we might not need to add a Keep when the outputs already align.
-            // Note that until we add explicit support for KEEP in FORK branches, this condition will always be true.
             if (logicalPlan instanceof Keep == false || subPlanColumns.equals(forkColumns) == false) {
                 changed = true;
                 List<Attribute> newOutput = new ArrayList<>();
@@ -810,7 +815,7 @@

         // We don't want to keep the same attributes that are outputted by the FORK branches.
         // Keeping the same attributes can have unintended side effects when applying optimizations like constant folding.
-        for (Attribute attr : newSubPlans.getFirst().output()) {
+        for (Attribute attr : outputUnion) {
             newOutput.add(new ReferenceAttribute(attr.source(), attr.name(), attr.dataType()));
         }
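
To make the alias fallback concrete, here is a minimal standalone sketch of the idea, using simplified stand-in types (the Attribute, Alias, and DataType below are hypothetical, not Elasticsearch's planner classes): a branch that lacks a column gets a null literal for it, and because a literal cannot carry the UNSUPPORTED type, KEYWORD is substituted.

import java.util.List;

// Hypothetical sketch of the Analyzer change above: missing FORK-branch columns are padded
// with null literals, and UNSUPPORTED falls back to a supported type (KEYWORD). Simplified
// stand-ins, not Elasticsearch's real planner classes.
class MissingColumnPadding {
    enum DataType { KEYWORD, LONG, UNSUPPORTED }

    record Attribute(String name, DataType type) {}

    record Alias(String name, Object value, DataType type) {}

    // Build a null-valued alias for each column the branch does not produce.
    static List<Alias> padMissing(List<Attribute> missing) {
        return missing.stream().map(attr -> {
            // A null literal cannot carry the UNSUPPORTED type, so substitute KEYWORD;
            // the padded column only ever holds nulls in this branch anyway.
            DataType type = attr.type() == DataType.UNSUPPORTED ? DataType.KEYWORD : attr.type();
            return new Alias(attr.name(), null, type);
        }).toList();
    }

    public static void main(String[] args) {
        var missing = List.of(new Attribute("embedding", DataType.UNSUPPORTED), new Attribute("x", DataType.LONG));
        // Prints aliases typed KEYWORD and LONG respectively, both null-valued.
        System.out.println(padMissing(missing));
    }
}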

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PlanConsistencyChecker.java

Lines changed: 8 additions & 3 deletions
@@ -77,14 +77,14 @@ private static void checkMissingFork(QueryPlan<?> plan, Failures failures) {

     private static void checkMissingForkBranch(QueryPlan<?> plan, AttributeSet forkOutputSet, Failures failures) {
         Map<String, DataType> attributeTypes = forkOutputSet.stream().collect(Collectors.toMap(Attribute::name, Attribute::dataType));
-        AttributeSet missing = AttributeSet.of();
+        Set<Attribute> missing = new HashSet<>();

         Set<String> commonAttrs = new HashSet<>();

         // get the missing attributes from the sub plan
         plan.output().forEach(attribute -> {
             var attrType = attributeTypes.get(attribute.name());
-            if (attrType == null || attrType != attribute.dataType()) {
+            if (attrType == null || (attrType != attribute.dataType() && attrType != DataType.UNSUPPORTED)) {
                 missing.add(attribute);
             }
             commonAttrs.add(attribute.name());
@@ -99,7 +99,12 @@ private static void checkMissingForkBranch(QueryPlan<?> plan, AttributeSet forkOutputSet, Failures failures) {

         if (missing.isEmpty() == false) {
             failures.add(
-                fail(plan, "Plan [{}] optimized incorrectly due to missing attributes in subplans", plan.nodeString(), missing.toString())
+                fail(
+                    plan,
+                    "Plan [{}] optimized incorrectly due to missing attributes in subplans: [{}]",
+                    plan.nodeString(),
+                    missing.toString()
+                )
             );
         }
     }
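
The relaxed check can be sketched the same way (stand-in types again, not the real PlanConsistencyChecker API): a branch attribute counts as missing only when the FORK output lacks its name, or the types disagree and the FORK output does not already declare the column UNSUPPORTED.

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical sketch of the missing-attribute rule above; simplified stand-ins.
class ForkBranchCheck {
    enum DataType { KEYWORD, LONG, UNSUPPORTED }

    record Attribute(String name, DataType type) {}

    static List<Attribute> missing(List<Attribute> forkOutput, List<Attribute> branchOutput) {
        Map<String, DataType> expectedTypes = forkOutput.stream()
            .collect(Collectors.toMap(Attribute::name, Attribute::type));
        return branchOutput.stream().filter(attr -> {
            DataType expected = expectedTypes.get(attr.name());
            // UNSUPPORTED in the FORK output is an agreed placeholder, never a mismatch.
            return expected == null || (expected != attr.type() && expected != DataType.UNSUPPORTED);
        }).toList();
    }

    public static void main(String[] args) {
        var forkOutput = List.of(new Attribute("embedding", DataType.UNSUPPORTED), new Attribute("id", DataType.KEYWORD));
        // The branch pads "embedding" with a KEYWORD null literal; the check accepts it.
        var branch = List.of(new Attribute("embedding", DataType.KEYWORD), new Attribute("id", DataType.KEYWORD));
        System.out.println(missing(forkOutput, branch)); // []
    }
}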

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Fork.java

Lines changed: 48 additions & 7 deletions
@@ -103,9 +103,18 @@ public List<Attribute> output() {
     public static List<Attribute> outputUnion(List<LogicalPlan> subplans) {
         List<Attribute> output = new ArrayList<>();
         Set<String> names = new HashSet<>();
+        // these are attribute names we know should have an UNSUPPORTED data type in the FORK output
+        Set<String> unsupportedAttributesNames = outputUnsupportedAttributeNames(subplans);

         for (var subPlan : subplans) {
             for (var attr : subPlan.output()) {
+                // When we have multiple attributes with the same name, the ones that have a supported data type take priority.
+                // We only add an attribute with an unsupported data type if we know that in the output of the rest of the FORK branches
+                // there exists no attribute with the same name and with a supported data type.
+                if (attr.dataType() == DataType.UNSUPPORTED && unsupportedAttributesNames.contains(attr.name()) == false) {
+                    continue;
+                }
+
                 if (names.contains(attr.name()) == false && attr.name().equals(Analyzer.NO_FIELDS_NAME) == false) {
                     names.add(attr.name());
                     output.add(attr);
@@ -115,6 +124,34 @@ public static List<Attribute> outputUnion(List<LogicalPlan> subplans) {
         return output;
     }

+    /**
+     * Returns the set of attribute names that need to have the {@code UNSUPPORTED} data type in the FORK output.
+     * These are attributes that are either {@code UNSUPPORTED} or missing in every FORK branch.
+     * If two branches have the same attribute name, but the data type is {@code UNSUPPORTED} in only one of them, this constitutes
+     * a data type conflict, so this attribute name will not be returned by this function.
+     * Data type conflicts are checked later in {@code postAnalysisPlanVerification}.
+     */
+    public static Set<String> outputUnsupportedAttributeNames(List<LogicalPlan> subplans) {
+        Set<String> unsupportedAttributes = new HashSet<>();
+        Set<String> names = new HashSet<>();
+
+        for (var subPlan : subplans) {
+            for (var attr : subPlan.output()) {
+                var attrName = attr.name();
+                if (unsupportedAttributes.contains(attrName) == false
+                    && attr.dataType() == DataType.UNSUPPORTED
+                    && names.contains(attrName) == false) {
+                    unsupportedAttributes.add(attrName);
+                } else if (unsupportedAttributes.contains(attrName) && attr.dataType() != DataType.UNSUPPORTED) {
+                    unsupportedAttributes.remove(attrName);
+                }
+                names.add(attrName);
+            }
+        }
+
+        return unsupportedAttributes;
+    }
+
     @Override
     public int hashCode() {
         return Objects.hash(Fork.class, children());
@@ -152,16 +189,20 @@ private static void checkFork(LogicalPlan plan, Failures failures) {
             failures.add(Failure.fail(otherFork, "Only a single FORK command is allowed, but found multiple"));
         });

-        Map<String, DataType> outputTypes = fork.children()
-            .getFirst()
-            .output()
-            .stream()
-            .collect(Collectors.toMap(Attribute::name, Attribute::dataType));
+        Map<String, DataType> outputTypes = fork.output().stream().collect(Collectors.toMap(Attribute::name, Attribute::dataType));

-        fork.children().stream().skip(1).forEach(subPlan -> {
+        fork.children().forEach(subPlan -> {
             for (Attribute attr : subPlan.output()) {
-                var actual = attr.dataType();
                 var expected = outputTypes.get(attr.name());
+
+                // If the FORK output has an UNSUPPORTED data type, we know there is no conflict.
+                // We only assign an UNSUPPORTED attribute in the FORK output when there exists no attribute with the
+                // same name and supported data type in any of the FORK branches.
+                if (expected == DataType.UNSUPPORTED) {
+                    continue;
+                }
+
+                var actual = attr.dataType();
                 if (actual != expected) {
                     failures.add(
                         Failure.fail(
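
Finally, a worked sketch of the union semantics (hypothetical stand-ins once more, not the real Fork class): an attribute name stays UNSUPPORTED in the FORK output only if no branch offers the same name with a supported type.

import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical sketch mirroring outputUnsupportedAttributeNames; simplified stand-ins.
class ForkOutputUnion {
    enum DataType { KEYWORD, LONG, UNSUPPORTED }

    record Attribute(String name, DataType type) {}

    // Names that are UNSUPPORTED in every branch where they appear at all.
    static Set<String> unsupportedNames(List<List<Attribute>> branches) {
        Set<String> unsupported = new HashSet<>();
        Set<String> seen = new HashSet<>();
        for (var branch : branches) {
            for (var attr : branch) {
                if (unsupported.contains(attr.name()) == false
                    && attr.type() == DataType.UNSUPPORTED
                    && seen.contains(attr.name()) == false) {
                    unsupported.add(attr.name());
                } else if (unsupported.contains(attr.name()) && attr.type() != DataType.UNSUPPORTED) {
                    // A duplicate with a supported type wins; the name is dropped from the set.
                    unsupported.remove(attr.name());
                }
                seen.add(attr.name());
            }
        }
        return unsupported;
    }

    public static void main(String[] args) {
        var statsBranch = List.of(new Attribute("x", DataType.LONG));
        var whereBranch = List.of(new Attribute("id", DataType.KEYWORD), new Attribute("embedding", DataType.UNSUPPORTED));
        System.out.println(unsupportedNames(List.of(statsBranch, whereBranch))); // [embedding]
    }
}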
