Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -624,14 +624,6 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(false)

val COMET_CAST_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] =
conf("spark.comet.cast.allowIncompatible")
.doc(
"Comet is not currently fully compatible with Spark for all cast operations. " +
s"Set this config to true to allow them anyway. $COMPAT_GUIDE.")
.booleanConf
.createWithDefault(false)

val COMET_REGEXP_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] =
conf("spark.comet.regexp.allowIncompatible")
.doc(
Expand Down
87 changes: 69 additions & 18 deletions spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@

package org.apache.comet.expressions

import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType}

import org.apache.comet.serde.{Compatible, Incompatible, SupportLevel, Unsupported}
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.{CometExpressionSerde, Compatible, ExprOuterClass, Incompatible, SupportLevel, Unsupported}
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProtoInternal, serializeDataType}
import org.apache.comet.shims.CometExprShim

object CometCast {
object CometCast extends CometExpressionSerde[Cast] with CometExprShim {

def supportedTypes: Seq[DataType] =
Seq(
Expand All @@ -42,6 +48,51 @@ object CometCast {
// TODO add DataTypes.TimestampNTZType for Spark 3.4 and later
// https://github.com/apache/datafusion-comet/issues/378

override def getSupportLevel(cast: Cast): SupportLevel = {
isSupported(cast.child.dataType, cast.dataType, cast.timeZoneId, evalMode(cast))
}

override def convert(
cast: Cast,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val childExpr = exprToProtoInternal(cast.child, inputs, binding)
if (childExpr.isDefined) {
castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get, evalMode(cast))
} else {
withInfo(cast, cast.child)
None
}
}

/**
* Wrap an already serialized expression in a cast.
*/
def castToProto(
expr: Expression,
timeZoneId: Option[String],
dt: DataType,
childExpr: Expr,
evalMode: CometEvalMode.Value): Option[Expr] = {
serializeDataType(dt) match {
case Some(dataType) =>
val castBuilder = ExprOuterClass.Cast.newBuilder()
castBuilder.setChild(childExpr)
castBuilder.setDatatype(dataType)
castBuilder.setEvalMode(evalModeToProto(evalMode))
castBuilder.setAllowIncompat(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get())
castBuilder.setTimezone(timeZoneId.getOrElse("UTC"))
Some(
ExprOuterClass.Expr
.newBuilder()
.setCast(castBuilder)
.build())
case _ =>
withInfo(expr, s"Unsupported datatype in castToProto: $dt")
None
}
}

def isSupported(
fromType: DataType,
toType: DataType,
Expand All @@ -62,7 +113,7 @@ object CometCast {
case DataTypes.TimestampType | DataTypes.DateType | DataTypes.StringType =>
Incompatible()
case _ =>
Unsupported
Unsupported(Some(s"Cast from $fromType to $toType is not supported"))
}
case (_: DecimalType, _: DecimalType) =>
Compatible()
Expand Down Expand Up @@ -98,7 +149,7 @@ object CometCast {
}
}
Compatible()
case _ => Unsupported
case _ => Unsupported(Some(s"Cast from $fromType to $toType is not supported"))
}
}

Expand Down Expand Up @@ -136,7 +187,7 @@ object CometCast {
// https://github.com/apache/datafusion-comet/issues/328
Incompatible(Some("Not all valid formats are supported"))
case _ =>
Unsupported
Unsupported(Some(s"Cast from String to $toType is not supported"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can see lots of Unsupported(Some(s"Cast from String to $toType is not supported")), canthis be a helper function?

def unsupportedCast(from, to) {
   Unsupported(Some(s"Cast from $from to $to is not supported"))
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Thanks.

}
}

Expand Down Expand Up @@ -171,13 +222,13 @@ object CometCast {
isSupported(field.dataType, DataTypes.StringType, timeZoneId, evalMode) match {
case s: Incompatible =>
return s
case Unsupported =>
return Unsupported
case u: Unsupported =>
return u
case _ =>
}
}
Compatible()
case _ => Unsupported
case _ => Unsupported(Some(s"Cast from $fromType to String is not supported"))
}
}

Expand All @@ -187,21 +238,21 @@ object CometCast {
DataTypes.IntegerType =>
// https://github.com/apache/datafusion-comet/issues/352
// this seems like an edge case that isn't important for us to support
Unsupported
Unsupported(Some(s"Cast from Timestamp to $toType is not supported"))
case DataTypes.LongType =>
// https://github.com/apache/datafusion-comet/issues/352
Compatible()
case DataTypes.StringType => Compatible()
case DataTypes.DateType => Compatible()
case _ => Unsupported
case _ => Unsupported(Some(s"Cast from Timestamp to $toType is not supported"))
}
}

private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
DataTypes.FloatType | DataTypes.DoubleType =>
Compatible()
case _ => Unsupported
case _ => Unsupported(Some(s"Cast from Boolean to $toType is not supported"))
}

private def canCastFromByte(toType: DataType): SupportLevel = toType match {
Expand All @@ -212,7 +263,7 @@ object CometCast {
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case _ =>
Unsupported
Unsupported(Some(s"Cast from Byte to $toType is not supported"))
}

private def canCastFromShort(toType: DataType): SupportLevel = toType match {
Expand All @@ -223,7 +274,7 @@ object CometCast {
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case _ =>
Unsupported
Unsupported(Some(s"Cast from Short to $toType is not supported"))
}

private def canCastFromInt(toType: DataType): SupportLevel = toType match {
Expand All @@ -236,7 +287,7 @@ object CometCast {
case _: DecimalType =>
Incompatible(Some("No overflow check"))
case _ =>
Unsupported
Unsupported(Some(s"Cast from Int to $toType is not supported"))
}

private def canCastFromLong(toType: DataType): SupportLevel = toType match {
Expand All @@ -249,7 +300,7 @@ object CometCast {
case _: DecimalType =>
Incompatible(Some("No overflow check"))
case _ =>
Unsupported
Unsupported(Some(s"Cast from Long to $toType is not supported"))
}

private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
Expand All @@ -259,7 +310,7 @@ object CometCast {
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/1371
Incompatible(Some("There can be rounding differences"))
case _ => Unsupported
case _ => Unsupported(Some(s"Cast from Float to $toType is not supported"))
}

private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
Expand All @@ -269,14 +320,14 @@ object CometCast {
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/1371
Incompatible(Some("There can be rounding differences"))
case _ => Unsupported
case _ => Unsupported(Some(s"Cast from Double to $toType is not supported"))
}

private def canCastFromDecimal(toType: DataType): SupportLevel = toType match {
case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case _ => Unsupported
case _ => Unsupported(Some(s"Cast from Decimal to $toType is not supported"))
}

}
Loading
Loading