diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java index 637a454c213..d06fa31661a 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java @@ -22,8 +22,10 @@ import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.linq4j.tree.ParameterExpression; import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTrait; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelCollationTraitDef; import org.apache.calcite.rel.RelNode; import org.apache.calcite.util.BuiltInMethod; import org.apache.calcite.util.Pair; @@ -42,17 +44,20 @@ public class EnumerableMergeUnion extends EnumerableUnion { protected EnumerableMergeUnion(RelOptCluster cluster, RelTraitSet traitSet, List inputs, boolean all) { super(cluster, traitSet, inputs, all); - final RelCollation collation = traitSet.getCollation(); - if (collation == null || collation.getFieldCollations().isEmpty()) { + final List collations = traitSet.getCollations(); + if (collations.isEmpty() || collations.get(0).getFieldCollations().isEmpty()) { throw new IllegalArgumentException("EnumerableMergeUnion with no collation"); } for (RelNode input : inputs) { - final RelCollation inputCollation = input.getTraitSet().getCollation(); - if (inputCollation == null || !inputCollation.satisfies(collation)) { - throw new IllegalArgumentException("EnumerableMergeUnion input does " - + "not satisfy collation. EnumerableMergeUnion collation: " - + collation + ". Input collation: " + inputCollation + ". Input: " - + input); + final RelTrait inputCollationTrait = + input.getTraitSet().getTrait(RelCollationTraitDef.INSTANCE); + for (RelCollation collation : collations) { + if (inputCollationTrait == null || !inputCollationTrait.satisfies(collation)) { + throw new IllegalArgumentException("EnumerableMergeUnion input does " + + "not satisfy collation. EnumerableMergeUnion collation: " + + collation + ". Input collation: " + inputCollationTrait + ". Input: " + + input); + } } } } diff --git a/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java b/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java index b91b660752e..65747426062 100644 --- a/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java +++ b/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java @@ -392,6 +392,22 @@ public RelTraitSet getDefaultSansConvention() { return (@Nullable T) getTrait(RelCollationTraitDef.INSTANCE); } + /** + * Returns {@link RelCollation} traits defined by + * {@link RelCollationTraitDef#INSTANCE}. + */ + @SuppressWarnings("unchecked") + public List getCollations() { + RelCollation trait = getTrait(RelCollationTraitDef.INSTANCE); + if (trait == null) { + return ImmutableList.of(); + } + if (trait instanceof RelCompositeTrait) { + return ((RelCompositeTrait) trait).traitList(); + } + return ImmutableList.of(trait); + } + /** * Returns the size of the RelTraitSet. * diff --git a/core/src/test/resources/sql/sort.iq b/core/src/test/resources/sql/sort.iq index fb69970e41b..b9e0412ff98 100644 --- a/core/src/test/resources/sql/sort.iq +++ b/core/src/test/resources/sql/sort.iq @@ -465,4 +465,47 @@ order by arr desc nulls first; !ok +# [CALCITE-7374] NULLS LAST throws ClassCastException when sorting arrays +select * from +(values + (2, array[null, 3]), + (3, array[3, 4]), + (1, array[1, 2]), + (4, array[4, 5]), + (5, cast(null as integer array))) as t(id, arr) +order by arr nulls last; ++----+-----------+ +| ID | ARR | ++----+-----------+ +| 1 | [1, 2] | +| 3 | [3, 4] | +| 4 | [4, 5] | +| 2 | [null, 3] | +| 5 | | ++----+-----------+ +(5 rows) + +!ok + +select * from +(values + (2, array[null, 3]), + (3, array[3, 4]), + (1, array[1, 2]), + (4, array[4, 5]), + (5, cast(null as integer array))) as t(id, arr) +order by arr desc nulls last; ++----+-----------+ +| ID | ARR | ++----+-----------+ +| 2 | [null, 3] | +| 4 | [4, 5] | +| 3 | [3, 4] | +| 1 | [1, 2] | +| 5 | | ++----+-----------+ +(5 rows) + +!ok + # End sort.iq diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/function/Functions.java b/linq4j/src/main/java/org/apache/calcite/linq4j/function/Functions.java index 73a06f42bd8..6e9a186344c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/function/Functions.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/function/Functions.java @@ -556,7 +556,8 @@ private static class NullsFirstComparator } else if (o1 instanceof Object[] && o2 instanceof Object[]) { return compareObjectArrays((Object[]) o1, (Object[]) o2); } else { - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Item types do not match: " + + o1.getClass() + " vs " + o2.getClass()); } } } @@ -582,7 +583,8 @@ private static class NullsLastComparator } else if (o1 instanceof Object[] && o2 instanceof Object[]) { return compareObjectArrays((Object[]) o1, (Object[]) o2); } else { - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Item types do not match: " + + o1.getClass() + " vs " + o2.getClass()); } } } @@ -590,7 +592,7 @@ private static class NullsLastComparator /** Nulls first reverse comparator. */ private static class NullsFirstReverseComparator implements Comparator, Serializable { - @Override public int compare(Object o1, Object o2) { + @Override public int compare(@Nullable Object o1, @Nullable Object o2) { if (o1 == o2) { return 0; } @@ -608,7 +610,8 @@ private static class NullsFirstReverseComparator } else if (o1 instanceof Object[] && o2 instanceof Object[]) { return -compareObjectArrays((Object[]) o1, (Object[]) o2); } else { - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Item types do not match: " + + o1.getClass() + " vs " + o2.getClass()); } } } @@ -707,8 +710,8 @@ public static int compareObjectArrays(@Nullable Object @Nullable [] b0, /** Nulls last reverse comparator. */ private static class NullsLastReverseComparator - implements Comparator, Serializable { - @Override public int compare(Comparable o1, Comparable o2) { + implements Comparator, Serializable { + @Override public int compare(@Nullable Object o1, @Nullable Object o2) { if (o1 == o2) { return 0; } @@ -718,8 +721,17 @@ private static class NullsLastReverseComparator if (o2 == null) { return -1; } - //noinspection unchecked - return -o1.compareTo(o2); + if (o1 instanceof Comparable && o2 instanceof Comparable) { + //noinspection unchecked + return -((Comparable) o1).compareTo(o2); + } else if (o1 instanceof List && o2 instanceof List) { + return -compareLists((List) o1, (List) o2); + } else if (o1 instanceof Object[] && o2 instanceof Object[]) { + return -compareObjectArrays((Object[]) o1, (Object[]) o2); + } else { + throw new IllegalArgumentException("Item types do not match: " + + o1.getClass() + " vs " + o2.getClass()); + } } }