diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 824024a84cbad..166866c90b877 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -128,7 +128,7 @@ object SortOrder { case class SortPrefix(child: SortOrder) extends UnaryExpression { val nullValue = child.child.dataType match { - case BooleanType | DateType | TimestampType | TimestampNTZType | + case BooleanType | DateType | TimestampType | TimestampNTZType | _: TimeType | _: IntegralType | _: AnsiIntervalType => if (nullAsSmallest) Long.MinValue else Long.MaxValue case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => @@ -151,7 +151,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { private lazy val calcPrefix: Any => Long = child.child.dataType match { case BooleanType => (raw) => if (raw.asInstanceOf[Boolean]) 1 else 0 - case DateType | TimestampType | TimestampNTZType | + case DateType | TimestampType | TimestampNTZType | _: TimeType | _: IntegralType | _: AnsiIntervalType => (raw) => raw.asInstanceOf[java.lang.Number].longValue() case FloatType | DoubleType => (raw) => { @@ -202,7 +202,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { s"$input ? 1L : 0L" case _: IntegralType => s"(long) $input" - case DateType | TimestampType | TimestampNTZType | _: AnsiIntervalType => + case DateType | TimestampType | TimestampNTZType | _: TimeType | _: AnsiIntervalType => s"(long) $input" case FloatType | DoubleType => s"$DoublePrefixCmp.computePrefix((double)$input)" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala index 9332ef5595325..80bb16d72f6f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp +import java.time.LocalTime import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -51,6 +53,9 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val dec3 = Literal(Decimal(20132983L, 21, 2)) val list1 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) val nullVal = Literal.create(null, IntegerType) + val tm1LocalTime = LocalTime.of(21, 15, 1, 123456) + val tm1Nano = SparkDateTimeUtils.localTimeToNanos(tm1LocalTime) + val tm1 = Literal.create(tm1LocalTime, TimeType(6)) checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L) checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L) @@ -83,6 +88,7 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper DoublePrefixComparator.computePrefix(201329.83d)) checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L) checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null) + checkEvaluation(SortPrefix(SortOrder(tm1, Ascending)), tm1Nano) } test("Cannot sort map type") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 4b561b813067e..7332bbcb18454 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -43,7 +43,7 @@ object SortPrefixUtils { case StringType => stringPrefixComparator(sortOrder) case BinaryType => binaryPrefixComparator(sortOrder) case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType | - TimestampNTZType | _: AnsiIntervalType => + TimestampNTZType | _: TimeType |_: AnsiIntervalType => longPrefixComparator(sortOrder) case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => longPrefixComparator(sortOrder) @@ -123,7 +123,8 @@ object SortPrefixUtils { def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = { sortOrder.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | - TimestampType | TimestampNTZType | FloatType | DoubleType | _: AnsiIntervalType => + TimestampType | TimestampNTZType | _: TimeType | FloatType | DoubleType | + _: AnsiIntervalType => true case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => true