Merged
10 changes: 10 additions & 0 deletions spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
@@ -79,4 +79,14 @@ object DataTypeSupport
case _: StructType | _: ArrayType | _: MapType => true
case _ => false
}

def hasTemporalType(t: DataType): Boolean = t match {
case DataTypes.DateType | DataTypes.TimestampType | DataTypes.TimestampNTZType =>
Contributor:

Wondering if interval can also be considered temporal? Probably not; it is more like a duration than a representation of a point in time.

true
case t: StructType => t.exists(f => hasTemporalType(f.dataType))
case t: ArrayType => hasTemporalType(t.elementType)
case t: MapType => hasTemporalType(t.keyType) || hasTemporalType(t.valueType)
case _ => false
}

}
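To the reviewer's point above: interval types are not matched by `hasTemporalType`, so they fall through to the default case. A minimal usage sketch (not part of this PR; assumes only Spark's standard type factories):

```scala
import org.apache.spark.sql.types._

import org.apache.comet.DataTypeSupport

// Plain and nested temporal types are detected, including map keys/values.
assert(DataTypeSupport.hasTemporalType(DataTypes.DateType))
assert(DataTypeSupport.hasTemporalType(ArrayType(DataTypes.TimestampType)))
assert(DataTypeSupport.hasTemporalType(
  DataTypes.createMapType(DataTypes.StringType, DataTypes.DateType)))
// Intervals describe durations rather than points in time, so they do not match.
assert(!DataTypeSupport.hasTemporalType(DayTimeIntervalType()))
```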
24 changes: 1 addition & 23 deletions spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
@@ -592,34 +592,12 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com
val partitionSchemaSupported =
typeChecker.isSchemaSupported(partitionSchema, fallbackReasons)

def hasUnsupportedType(dataType: DataType): Boolean = {
Member Author:

This check is no longer needed.

dataType match {
case s: StructType => s.exists(field => hasUnsupportedType(field.dataType))
case a: ArrayType => hasUnsupportedType(a.elementType)
case m: MapType =>
// maps containing complex types are not supported
isComplexType(m.keyType) || isComplexType(m.valueType) ||
hasUnsupportedType(m.keyType) || hasUnsupportedType(m.valueType)
case dt if isStringCollationType(dt) => true
case _ => false
}
}

val knownIssues =
scanExec.requiredSchema.exists(field => hasUnsupportedType(field.dataType)) ||
partitionSchema.exists(field => hasUnsupportedType(field.dataType))

if (knownIssues) {
fallbackReasons += "Schema contains data types that are not supported by " +
s"$SCAN_NATIVE_ICEBERG_COMPAT"
}

val cometExecEnabled = COMET_EXEC_ENABLED.get()
if (!cometExecEnabled) {
fallbackReasons += s"$SCAN_NATIVE_ICEBERG_COMPAT requires ${COMET_EXEC_ENABLED.key}=true"
}

if (cometExecEnabled && schemaSupported && partitionSchemaSupported && !knownIssues &&
if (cometExecEnabled && schemaSupported && partitionSchemaSupported &&
fallbackReasons.isEmpty) {
logInfo(s"Auto scan mode selecting $SCAN_NATIVE_ICEBERG_COMPAT")
SCAN_NATIVE_ICEBERG_COMPAT
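The removed knownIssues walk is subsumed by the earlier schema checks: isSchemaSupported is assumed to append a reason to fallbackReasons for each type it rejects, and the final condition requires fallbackReasons to be empty. A schematic sketch of the simplified decision (a hypothetical standalone helper, not the actual rule code):

```scala
import scala.collection.mutable.ListBuffer

// Schematic of the simplified selection logic. schemaSupported and
// partitionSchemaSupported are assumed to come from isSchemaSupported,
// which records a fallback reason for every unsupported type it finds.
def selectIcebergCompatScan(
    cometExecEnabled: Boolean,
    schemaSupported: Boolean,
    partitionSchemaSupported: Boolean,
    fallbackReasons: ListBuffer[String]): Option[String] = {
  if (cometExecEnabled && schemaSupported && partitionSchemaSupported &&
    fallbackReasons.isEmpty) {
    Some("native_iceberg_compat")
  } else {
    None // the caller reports fallbackReasons and falls back to Spark
  }
}
```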
44 changes: 37 additions & 7 deletions spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala
@@ -35,12 +35,15 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGenerator, SchemaGenOptions}

class CometFuzzTestBase extends CometTestBase with AdaptiveSparkPlanHelper {

var filename: String = null

/** Filename for data file with deeply nested complex types */
var complexTypesFilename: String = null
Contributor:

It would be great to note whether this filename is an input, an output, or a temp location.

Member Author:

I will improve the comments in my next PR later today.
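A sketch of what that improved comment might look like (hypothetical wording; per beforeAll below, the file is a temporary output written by the test itself):

```scala
/**
 * Path of a temporary Parquet file, written in beforeAll, that contains
 * deeply nested complex types (arrays, structs, and maps).
 */
var complexTypesFilename: String = null
```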


/**
* We use Asia/Kathmandu because it has a non-zero number of minutes as the offset, so is an
* interesting edge case. Also, this timezone tends to be different from the default system
@@ -53,18 +56,20 @@ class CometFuzzTestBase extends CometTestBase with AdaptiveSparkPlanHelper {
override def beforeAll(): Unit = {
super.beforeAll()
val tempDir = System.getProperty("java.io.tmpdir")
filename = s"$tempDir/CometFuzzTestSuite_${System.currentTimeMillis()}.parquet"
val random = new Random(42)
val dataGenOptions = DataGenOptions(
generateNegativeZero = false,
// override base date due to known issues with experimental scans
baseDate = new SimpleDateFormat("YYYY-MM-DD hh:mm:ss").parse("2024-05-25 12:34:56").getTime)

// generate Parquet file with primitives, structs, and arrays, but no maps
// and no nested complex types
filename = s"$tempDir/CometFuzzTestSuite_${System.currentTimeMillis()}.parquet"
withSQLConf(
CometConf.COMET_ENABLED.key -> "false",
SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) {
val schemaGenOptions =
SchemaGenOptions(generateArray = true, generateStruct = true)
val dataGenOptions = DataGenOptions(
generateNegativeZero = false,
// override base date due to known issues with experimental scans
baseDate =
new SimpleDateFormat("YYYY-MM-DD hh:mm:ss").parse("2024-05-25 12:34:56").getTime)
ParquetGenerator.makeParquetFile(
random,
spark,
@@ -73,6 +78,30 @@
schemaGenOptions,
dataGenOptions)
}

// generate Parquet file with complex nested types
complexTypesFilename =
s"$tempDir/CometFuzzTestSuite_nested_${System.currentTimeMillis()}.parquet"
withSQLConf(
CometConf.COMET_ENABLED.key -> "false",
SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) {
val schemaGenOptions =
SchemaGenOptions(generateArray = true, generateStruct = true, generateMap = true)
val schema = FuzzDataGenerator.generateNestedSchema(
random,
numCols = 10,
minDepth = 2,
maxDepth = 4,
options = schemaGenOptions)
ParquetGenerator.makeParquetFile(
random,
spark,
complexTypesFilename,
schema,
1000,
dataGenOptions)
}

}

protected override def afterAll(): Unit = {
@@ -84,6 +113,7 @@ class CometFuzzTestBase extends CometTestBase with AdaptiveSparkPlanHelper {
pos: Position): Unit = {
Seq("native", "jvm").foreach { shuffleMode =>
Seq(
CometConf.SCAN_AUTO,
CometConf.SCAN_NATIVE_COMET,
CometConf.SCAN_NATIVE_DATAFUSION,
CometConf.SCAN_NATIVE_ICEBERG_COMPAT).foreach { scanImpl =>
51 changes: 30 additions & 21 deletions spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.types._

import org.apache.comet.DataTypeSupport.isComplexType
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
import org.apache.comet.testing.FuzzDataGenerator.{doubleNaNLiteral, floatNaNLiteral}

class CometFuzzTestSuite extends CometFuzzTestBase {
@@ -44,6 +44,17 @@ class CometFuzzTestSuite extends CometFuzzTestBase {
}
}

test("select * with deeply nested complex types") {
val df = spark.read.parquet(complexTypesFilename)
df.createOrReplaceTempView("t1")
val sql = "SELECT * FROM t1"
if (CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_COMET) {
checkSparkAnswerAndOperator(sql)
} else {
checkSparkAnswer(sql)
}
}

test("select * with limit") {
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
@@ -179,7 +190,7 @@ class CometFuzzTestSuite extends CometFuzzTestBase {
case CometConf.SCAN_NATIVE_COMET =>
// native_comet does not support reading complex types
0
case CometConf.SCAN_NATIVE_ICEBERG_COMPAT | CometConf.SCAN_NATIVE_DATAFUSION =>
case _ =>
CometConf.COMET_SHUFFLE_MODE.get() match {
case "jvm" =>
1
@@ -202,7 +213,7 @@ class CometFuzzTestSuite extends CometFuzzTestBase {
case CometConf.SCAN_NATIVE_COMET =>
// native_comet does not support reading complex types
0
case CometConf.SCAN_NATIVE_ICEBERG_COMPAT | CometConf.SCAN_NATIVE_DATAFUSION =>
case _ =>
CometConf.COMET_SHUFFLE_MODE.get() match {
case "jvm" =>
1
@@ -272,12 +283,7 @@ class CometFuzzTestSuite extends CometFuzzTestBase {
}

private def testParquetTemporalTypes(
outputTimestampType: ParquetOutputTimestampType.Value,
generateArray: Boolean = true,
generateStruct: Boolean = true): Unit = {

val schemaGenOptions =
SchemaGenOptions(generateArray = generateArray, generateStruct = generateStruct)
outputTimestampType: ParquetOutputTimestampType.Value): Unit = {

val dataGenOptions = DataGenOptions(generateNegativeZero = false)

@@ -287,12 +293,23 @@
CometConf.COMET_ENABLED.key -> "false",
SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> outputTimestampType.toString,
SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) {

// TODO test with MapType
// https://github.com/apache/datafusion-comet/issues/2945
val schema = StructType(
Seq(
StructField("c0", DataTypes.DateType),
StructField("c1", DataTypes.createArrayType(DataTypes.DateType)),
StructField(
"c2",
DataTypes.createStructType(Array(StructField("c3", DataTypes.DateType))))))

ParquetGenerator.makeParquetFile(
random,
spark,
filename.toString,
schema,
100,
schemaGenOptions,
dataGenOptions)
}

@@ -309,18 +326,10 @@

val df = spark.read.parquet(filename.toString)
df.createOrReplaceTempView("t1")

def hasTemporalType(t: DataType): Boolean = t match {
case DataTypes.DateType | DataTypes.TimestampType |
DataTypes.TimestampNTZType =>
true
case t: StructType => t.exists(f => hasTemporalType(f.dataType))
case t: ArrayType => hasTemporalType(t.elementType)
Member Author (on lines -317 to -318):

This was missing a check for MapType, so I added that and moved the method to DataTypeSupport.

case _ => false
}

val columns =
df.schema.fields.filter(f => hasTemporalType(f.dataType)).map(_.name)
df.schema.fields
.filter(f => DataTypeSupport.hasTemporalType(f.dataType))
.map(_.name)

for (col <- columns) {
checkSparkAnswer(s"SELECT $col FROM t1 ORDER BY $col")
@@ -116,7 +116,6 @@ abstract class CometTestBase
sparkPlan = dfSpark.queryExecution.executedPlan
}
val dfComet = datasetOfRows(spark, df.logicalPlan)

if (withTol.isDefined) {
checkAnswerWithTolerance(dfComet, expected, withTol.get)
} else {
@@ -19,8 +19,6 @@

package org.apache.spark.sql

import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Alias, ToPrettyString}
@@ -29,7 +27,6 @@ import org.apache.spark.sql.types.DataTypes

import org.apache.comet.{CometConf, CometFuzzTestBase}
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.rules.CometScanTypeChecker
import org.apache.comet.serde.Compatible

class CometToPrettyStringSuite extends CometFuzzTestBase {
@@ -45,14 +42,14 @@ class CometToPrettyStringSuite extends CometFuzzTestBase {
val plan = Project(Seq(prettyExpr), table)
val analyzed = spark.sessionState.analyzer.execute(plan)
val result: DataFrame = Dataset.ofRows(spark, analyzed)
CometCast.isSupported(
val supportLevel = CometCast.isSupported(
field.dataType,
DataTypes.StringType,
Some(spark.sessionState.conf.sessionLocalTimeZone),
CometEvalMode.TRY) match {
CometEvalMode.TRY)
supportLevel match {
case _: Compatible
if CometScanTypeChecker(CometConf.COMET_NATIVE_SCAN_IMPL.get())
.isTypeSupported(field.dataType, field.name, ListBuffer.empty) =>
if CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_COMET =>
checkSparkAnswerAndOperator(result)
case _ => checkSparkAnswer(result)
}
@@ -19,8 +19,6 @@

package org.apache.spark.sql

import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Alias, ToPrettyString}
@@ -32,7 +30,6 @@ import org.apache.spark.sql.types.DataTypes

import org.apache.comet.{CometConf, CometFuzzTestBase}
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.rules.CometScanTypeChecker
import org.apache.comet.serde.Compatible

class CometToPrettyStringSuite extends CometFuzzTestBase {
@@ -56,14 +53,14 @@ class CometToPrettyStringSuite extends CometFuzzTestBase {
val plan = Project(Seq(prettyExpr), table)
val analyzed = spark.sessionState.analyzer.execute(plan)
val result: DataFrame = Dataset.ofRows(spark, analyzed)
CometCast.isSupported(
val supportLevel = CometCast.isSupported(
field.dataType,
DataTypes.StringType,
Some(spark.sessionState.conf.sessionLocalTimeZone),
CometEvalMode.TRY) match {
CometEvalMode.TRY)
supportLevel match {
case _: Compatible
if CometScanTypeChecker(CometConf.COMET_NATIVE_SCAN_IMPL.get())
.isTypeSupported(field.dataType, field.name, ListBuffer.empty) =>
if CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_COMET =>
checkSparkAnswerAndOperator(result)
case _ => checkSparkAnswer(result)
}