Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ object SortOrder {
case class SortPrefix(child: SortOrder) extends UnaryExpression {

val nullValue = child.child.dataType match {
- case BooleanType | DateType | TimestampType | TimestampNTZType |
+ case BooleanType | DateType | TimestampType | TimestampNTZType | _: TimeType |
_: IntegralType | _: AnsiIntervalType =>
if (nullAsSmallest) Long.MinValue else Long.MaxValue
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
Expand All @@ -151,7 +151,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
private lazy val calcPrefix: Any => Long = child.child.dataType match {
case BooleanType => (raw) =>
if (raw.asInstanceOf[Boolean]) 1 else 0
- case DateType | TimestampType | TimestampNTZType |
+ case DateType | TimestampType | TimestampNTZType | _: TimeType |
_: IntegralType | _: AnsiIntervalType => (raw) =>
raw.asInstanceOf[java.lang.Number].longValue()
case FloatType | DoubleType => (raw) => {
Expand Down Expand Up @@ -202,7 +202,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
s"$input ? 1L : 0L"
case _: IntegralType =>
s"(long) $input"
- case DateType | TimestampType | TimestampNTZType | _: AnsiIntervalType =>
+ case DateType | TimestampType | TimestampNTZType | _: TimeType | _: AnsiIntervalType =>
s"(long) $input"
case FloatType | DoubleType =>
s"$DoublePrefixCmp.computePrefix((double)$input)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp
+ import java.time.LocalTime

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+ import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
Expand Down Expand Up @@ -51,6 +53,9 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val dec3 = Literal(Decimal(20132983L, 21, 2))
val list1 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
val nullVal = Literal.create(null, IntegerType)
+ val tm1LocalTime = LocalTime.of(21, 15, 1, 123456)
+ val tm1Nano = SparkDateTimeUtils.localTimeToNanos(tm1LocalTime)
+ val tm1 = Literal.create(tm1LocalTime, TimeType(6))

checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L)
checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L)
Expand Down Expand Up @@ -83,6 +88,7 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
DoublePrefixComparator.computePrefix(201329.83d))
checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L)
checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null)
+ checkEvaluation(SortPrefix(SortOrder(tm1, Ascending)), tm1Nano)
}

test("Cannot sort map type") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ object SortPrefixUtils {
case StringType => stringPrefixComparator(sortOrder)
case BinaryType => binaryPrefixComparator(sortOrder)
case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType |
- TimestampNTZType | _: AnsiIntervalType =>
+ TimestampNTZType | _: TimeType |_: AnsiIntervalType =>
longPrefixComparator(sortOrder)
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
longPrefixComparator(sortOrder)
Expand Down Expand Up @@ -123,7 +123,8 @@ object SortPrefixUtils {
def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = {
sortOrder.dataType match {
case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType |
- TimestampType | TimestampNTZType | FloatType | DoubleType | _: AnsiIntervalType =>
+ TimestampType | TimestampNTZType | _: TimeType | FloatType | DoubleType |
+ _: AnsiIntervalType =>
true
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
true
Expand Down