Skip to content

Commit 5fd5e1c

Browse files
author
Srikanth Padakanti
committed
PPL: add union command with type coercion and UNION ALL semantics
Signed-off-by: Srikanth Padakanti <srikanth_padakanti@apple.com>
1 parent 90393bf commit 5fd5e1c

File tree

21 files changed

+1574
-18
lines changed

21 files changed

+1574
-18
lines changed

core/src/main/java/org/opensearch/sql/analysis/Analyzer.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
import org.opensearch.sql.ast.tree.TableFunction;
107107
import org.opensearch.sql.ast.tree.Transpose;
108108
import org.opensearch.sql.ast.tree.Trendline;
109+
import org.opensearch.sql.ast.tree.Union;
109110
import org.opensearch.sql.ast.tree.UnresolvedPlan;
110111
import org.opensearch.sql.ast.tree.Values;
111112
import org.opensearch.sql.ast.tree.Window;
@@ -897,6 +898,11 @@ public LogicalPlan visitMultisearch(Multisearch node, AnalysisContext context) {
897898
throw getOnlyForCalciteException("Multisearch");
898899
}
899900

901+
/**
 * Handles a {@link Union} node in this analyzer by unconditionally throwing the exception
 * produced by {@code getOnlyForCalciteException("Union")}; no {@link LogicalPlan} is ever
 * returned from this method.
 */
@Override
902+
public LogicalPlan visitUnion(Union node, AnalysisContext context) {
903+
throw getOnlyForCalciteException("Union");
904+
}
905+
900906
private LogicalSort buildSort(
901907
LogicalPlan child, AnalysisContext context, Integer count, List<Field> sortFields) {
902908
ExpressionReferenceOptimizer optimizer =

core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
import org.opensearch.sql.ast.tree.TableFunction;
9494
import org.opensearch.sql.ast.tree.Transpose;
9595
import org.opensearch.sql.ast.tree.Trendline;
96+
import org.opensearch.sql.ast.tree.Union;
9697
import org.opensearch.sql.ast.tree.Values;
9798
import org.opensearch.sql.ast.tree.Window;
9899

@@ -472,6 +473,10 @@ public T visitMultisearch(Multisearch node, C context) {
472473
return visitChildren(node, context);
473474
}
474475

476+
/**
 * Visits a {@link Union} node. This default implementation delegates to {@code
 * visitChildren(node, context)} and returns its result.
 */
public T visitUnion(Union node, C context) {
477+
return visitChildren(node, context);
478+
}
479+
475480
public T visitAddTotals(AddTotals node, C context) {
476481
return visitChildren(node, context);
477482
}
core/src/main/java/org/opensearch/sql/ast/tree/Union.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.ast.tree;
7+
8+
import com.google.common.collect.ImmutableList;
9+
import java.util.List;
10+
import lombok.AllArgsConstructor;
11+
import lombok.EqualsAndHashCode;
12+
import lombok.Getter;
13+
import lombok.RequiredArgsConstructor;
14+
import lombok.ToString;
15+
import org.opensearch.sql.ast.AbstractNodeVisitor;
16+
17+
/** Logical plan node for Union operation. Combines results from multiple datasets (UNION ALL). */
18+
@Getter
19+
@ToString
20+
@EqualsAndHashCode(callSuper = false)
21+
@RequiredArgsConstructor
22+
@AllArgsConstructor
23+
public class Union extends UnresolvedPlan {
24+
private final List<UnresolvedPlan> datasets;
25+
26+
private Integer maxout;
27+
28+
@Override
29+
public UnresolvedPlan attach(UnresolvedPlan child) {
30+
List<UnresolvedPlan> newDatasets =
31+
ImmutableList.<UnresolvedPlan>builder().add(child).addAll(datasets).build();
32+
return new Union(newDatasets, maxout);
33+
}
34+
35+
@Override
36+
public List<? extends UnresolvedPlan> getChild() {
37+
return datasets;
38+
}
39+
40+
@Override
41+
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
42+
return nodeVisitor.visitUnion(this, context);
43+
}
44+
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
import org.opensearch.sql.ast.tree.TableFunction;
153153
import org.opensearch.sql.ast.tree.Trendline;
154154
import org.opensearch.sql.ast.tree.Trendline.TrendlineType;
155+
import org.opensearch.sql.ast.tree.Union;
155156
import org.opensearch.sql.ast.tree.UnresolvedPlan;
156157
import org.opensearch.sql.ast.tree.Values;
157158
import org.opensearch.sql.ast.tree.Window;
@@ -2445,6 +2446,40 @@ private String findTimestampField(RelDataType rowType) {
24452446
return null;
24462447
}
24472448

2449+
@Override
2450+
public RelNode visitUnion(Union node, CalcitePlanContext context) {
2451+
List<RelNode> inputNodes = new ArrayList<>();
2452+
2453+
for (UnresolvedPlan dataset : node.getDatasets()) {
2454+
UnresolvedPlan prunedDataset = dataset.accept(new EmptySourcePropagateVisitor(), null);
2455+
prunedDataset.accept(this, context);
2456+
inputNodes.add(context.relBuilder.build());
2457+
}
2458+
2459+
if (inputNodes.size() < 2) {
2460+
throw new IllegalArgumentException(
2461+
"Union command requires at least two datasets. Provided: " + inputNodes.size());
2462+
}
2463+
2464+
List<RelNode> unifiedInputs =
2465+
SchemaUnifier.buildUnifiedSchemaWithTypeCoercion(inputNodes, context);
2466+
2467+
for (RelNode input : unifiedInputs) {
2468+
context.relBuilder.push(input);
2469+
}
2470+
context.relBuilder.union(true, unifiedInputs.size()); // true = UNION ALL
2471+
2472+
if (node.getMaxout() != null) {
2473+
context.relBuilder.push(
2474+
LogicalSystemLimit.create(
2475+
LogicalSystemLimit.SystemLimitType.SUBSEARCH_MAXOUT,
2476+
context.relBuilder.build(),
2477+
context.relBuilder.literal(node.getMaxout())));
2478+
}
2479+
2480+
return context.relBuilder.peek();
2481+
}
2482+
24482483
/*
24492484
* Unsupported Commands of PPL with Calcite for OpenSearch 3.0.0-beta
24502485
*/

core/src/main/java/org/opensearch/sql/calcite/SchemaUnifier.java

Lines changed: 240 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,16 @@
1414
import org.apache.calcite.rel.type.RelDataType;
1515
import org.apache.calcite.rel.type.RelDataTypeField;
1616
import org.apache.calcite.rex.RexNode;
17+
import org.apache.calcite.sql.type.SqlTypeName;
1718

1819
/**
19-
* Utility class for unifying schemas across multiple RelNodes. Throws an exception when type
20-
* conflicts are detected.
20+
* Utility class for unifying schemas across multiple RelNodes. Supports two strategies:
21+
*
22+
* <ul>
23+
* <li>Conflict resolution (multisearch): throws on type mismatch, fills missing fields with NULL
24+
* <li>Type coercion (union): widens compatible types (e.g. INTEGER→BIGINT), falls back to VARCHAR
25+
* for incompatible types, fills missing fields with NULL
26+
* </ul>
2127
*/
2228
public class SchemaUnifier {
2329

@@ -147,4 +153,236 @@ RelDataType getType() {
147153
return type;
148154
}
149155
}
156+
157+
/**
158+
* Builds unified schema with type coercion for UNION command. Coerces compatible types to a
159+
* common supertype (e.g. int+float→float), falls back to VARCHAR for incompatible types, and
160+
* fills missing fields with NULL.
161+
*/
162+
public static List<RelNode> buildUnifiedSchemaWithTypeCoercion(
163+
List<RelNode> inputs, CalcitePlanContext context) {
164+
if (inputs.isEmpty() || inputs.size() == 1) {
165+
return inputs;
166+
}
167+
168+
List<RelNode> coercedInputs = coerceUnionTypes(inputs, context);
169+
return unifySchemasForUnion(coercedInputs, context);
170+
}
171+
172+
/**
173+
* Aligns schemas by projecting NULL for missing fields and CAST for type mismatches. Uses
174+
* force=true to clear collation traits and prevent EnumerableMergeUnion cast exception.
175+
*/
176+
private static List<RelNode> unifySchemasForUnion(
177+
List<RelNode> inputs, CalcitePlanContext context) {
178+
List<SchemaField> unifiedSchema = buildUnifiedSchemaForUnion(inputs);
179+
List<String> fieldNames =
180+
unifiedSchema.stream().map(SchemaField::getName).collect(Collectors.toList());
181+
182+
List<RelNode> projectedNodes = new ArrayList<>();
183+
for (RelNode node : inputs) {
184+
List<RexNode> projection = buildProjectionForUnion(node, unifiedSchema, context);
185+
RelNode projectedNode =
186+
context.relBuilder.push(node).project(projection, fieldNames, true).build();
187+
projectedNodes.add(projectedNode);
188+
}
189+
return projectedNodes;
190+
}
191+
192+
private static List<SchemaField> buildUnifiedSchemaForUnion(List<RelNode> nodes) {
193+
List<SchemaField> schema = new ArrayList<>();
194+
Map<String, RelDataType> seenFields = new HashMap<>();
195+
196+
for (RelNode node : nodes) {
197+
for (RelDataTypeField field : node.getRowType().getFieldList()) {
198+
if (!seenFields.containsKey(field.getName())) {
199+
schema.add(new SchemaField(field.getName(), field.getType()));
200+
seenFields.put(field.getName(), field.getType());
201+
}
202+
}
203+
}
204+
return schema;
205+
}
206+
207+
private static List<RexNode> buildProjectionForUnion(
208+
RelNode node, List<SchemaField> unifiedSchema, CalcitePlanContext context) {
209+
Map<String, RelDataTypeField> nodeFieldMap =
210+
node.getRowType().getFieldList().stream()
211+
.collect(Collectors.toMap(RelDataTypeField::getName, field -> field));
212+
213+
List<RexNode> projection = new ArrayList<>();
214+
for (SchemaField schemaField : unifiedSchema) {
215+
RelDataTypeField nodeField = nodeFieldMap.get(schemaField.getName());
216+
217+
if (nodeField != null) {
218+
RexNode fieldRef = context.rexBuilder.makeInputRef(node, nodeField.getIndex());
219+
if (!nodeField.getType().equals(schemaField.getType())) {
220+
projection.add(context.rexBuilder.makeCast(schemaField.getType(), fieldRef));
221+
} else {
222+
projection.add(fieldRef);
223+
}
224+
} else {
225+
projection.add(context.rexBuilder.makeNullLiteral(schemaField.getType()));
226+
}
227+
}
228+
return projection;
229+
}
230+
231+
/** Casts fields to their common supertypes across all inputs when types differ. */
232+
private static List<RelNode> coerceUnionTypes(List<RelNode> inputs, CalcitePlanContext context) {
233+
Map<String, List<SqlTypeName>> fieldTypeMap = new HashMap<>();
234+
for (RelNode input : inputs) {
235+
for (RelDataTypeField field : input.getRowType().getFieldList()) {
236+
String fieldName = field.getName();
237+
SqlTypeName typeName = field.getType().getSqlTypeName();
238+
if (typeName != null) {
239+
fieldTypeMap.computeIfAbsent(fieldName, k -> new ArrayList<>()).add(typeName);
240+
}
241+
}
242+
}
243+
244+
Map<String, SqlTypeName> targetTypeMap = new HashMap<>();
245+
for (Map.Entry<String, List<SqlTypeName>> entry : fieldTypeMap.entrySet()) {
246+
String fieldName = entry.getKey();
247+
List<SqlTypeName> types = entry.getValue();
248+
249+
SqlTypeName commonType = types.getFirst();
250+
for (int i = 1; i < types.size(); i++) {
251+
commonType = findCommonTypeForUnion(commonType, types.get(i));
252+
}
253+
targetTypeMap.put(fieldName, commonType);
254+
}
255+
256+
boolean needsCoercion = false;
257+
for (RelNode input : inputs) {
258+
for (RelDataTypeField field : input.getRowType().getFieldList()) {
259+
SqlTypeName targetType = targetTypeMap.get(field.getName());
260+
if (targetType != null && field.getType().getSqlTypeName() != targetType) {
261+
needsCoercion = true;
262+
break;
263+
}
264+
}
265+
if (needsCoercion) break;
266+
}
267+
268+
if (!needsCoercion) {
269+
return inputs;
270+
}
271+
272+
List<RelNode> coercedInputs = new ArrayList<>();
273+
for (RelNode input : inputs) {
274+
List<RexNode> projections = new ArrayList<>();
275+
List<String> projectionNames = new ArrayList<>();
276+
boolean needsProjection = false;
277+
278+
for (RelDataTypeField field : input.getRowType().getFieldList()) {
279+
String fieldName = field.getName();
280+
SqlTypeName currentType = field.getType().getSqlTypeName();
281+
SqlTypeName targetType = targetTypeMap.get(fieldName);
282+
283+
RexNode fieldRef = context.rexBuilder.makeInputRef(input, field.getIndex());
284+
285+
if (currentType != targetType && targetType != null) {
286+
projections.add(context.relBuilder.cast(fieldRef, targetType));
287+
needsProjection = true;
288+
} else {
289+
projections.add(fieldRef);
290+
}
291+
projectionNames.add(fieldName);
292+
}
293+
294+
if (needsProjection) {
295+
context.relBuilder.push(input);
296+
context.relBuilder.project(projections, projectionNames, true);
297+
coercedInputs.add(context.relBuilder.build());
298+
} else {
299+
coercedInputs.add(input);
300+
}
301+
}
302+
303+
return coercedInputs;
304+
}
305+
306+
/**
307+
* Returns the wider type for two SqlTypeNames. Within the same family, returns the wider type
308+
* (e.g. INTEGER+BIGINT-->BIGINT). Across families, falls back to VARCHAR.
309+
*/
310+
private static SqlTypeName findCommonTypeForUnion(SqlTypeName type1, SqlTypeName type2) {
311+
if (type1 == type2) {
312+
return type1;
313+
}
314+
315+
if (type1 == SqlTypeName.NULL) {
316+
return type2;
317+
}
318+
if (type2 == SqlTypeName.NULL) {
319+
return type1;
320+
}
321+
322+
if (isNumericTypeForUnion(type1) && isNumericTypeForUnion(type2)) {
323+
return getWiderNumericTypeForUnion(type1, type2);
324+
}
325+
326+
if (isStringTypeForUnion(type1) && isStringTypeForUnion(type2)) {
327+
return SqlTypeName.VARCHAR;
328+
}
329+
330+
if (isTemporalTypeForUnion(type1) && isTemporalTypeForUnion(type2)) {
331+
return getWiderTemporalTypeForUnion(type1, type2);
332+
}
333+
334+
return SqlTypeName.VARCHAR;
335+
}
336+
337+
private static boolean isNumericTypeForUnion(SqlTypeName typeName) {
338+
return typeName == SqlTypeName.TINYINT
339+
|| typeName == SqlTypeName.SMALLINT
340+
|| typeName == SqlTypeName.INTEGER
341+
|| typeName == SqlTypeName.BIGINT
342+
|| typeName == SqlTypeName.FLOAT
343+
|| typeName == SqlTypeName.REAL
344+
|| typeName == SqlTypeName.DOUBLE
345+
|| typeName == SqlTypeName.DECIMAL;
346+
}
347+
348+
private static boolean isStringTypeForUnion(SqlTypeName typeName) {
349+
return typeName == SqlTypeName.CHAR || typeName == SqlTypeName.VARCHAR;
350+
}
351+
352+
private static boolean isTemporalTypeForUnion(SqlTypeName typeName) {
353+
return typeName == SqlTypeName.DATE
354+
|| typeName == SqlTypeName.TIMESTAMP
355+
|| typeName == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
356+
}
357+
358+
private static SqlTypeName getWiderNumericTypeForUnion(SqlTypeName type1, SqlTypeName type2) {
359+
int rank1 = getNumericTypeRankForUnion(type1);
360+
int rank2 = getNumericTypeRankForUnion(type2);
361+
return rank1 >= rank2 ? type1 : type2;
362+
}
363+
364+
private static int getNumericTypeRankForUnion(SqlTypeName typeName) {
365+
return switch (typeName) {
366+
case TINYINT -> 1;
367+
case SMALLINT -> 2;
368+
case INTEGER -> 3;
369+
case BIGINT -> 4;
370+
case FLOAT -> 5;
371+
case REAL -> 6;
372+
case DOUBLE -> 7;
373+
case DECIMAL -> 8;
374+
default -> 0;
375+
};
376+
}
377+
378+
private static SqlTypeName getWiderTemporalTypeForUnion(SqlTypeName type1, SqlTypeName type2) {
379+
if (type1 == SqlTypeName.TIMESTAMP || type2 == SqlTypeName.TIMESTAMP) {
380+
return SqlTypeName.TIMESTAMP;
381+
}
382+
if (type1 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE
383+
|| type2 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE) {
384+
return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
385+
}
386+
return SqlTypeName.DATE;
387+
}
150388
}

docs/category.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
"user/ppl/cmd/top.md",
4949
"user/ppl/cmd/trendline.md",
5050
"user/ppl/cmd/transpose.md",
51+
"user/ppl/cmd/union.md",
5152
"user/ppl/cmd/where.md",
5253
"user/ppl/functions/aggregations.md",
5354
"user/ppl/functions/collection.md",

0 commit comments

Comments
 (0)