spark/src/main/scala/org/apache/comet/serde/structs.scala (19 additions, 20 deletions)
@@ -111,26 +111,6 @@ object CometStructsToJson extends CometExpressionSerde[StructsToJson] {
withInfo(expr, "StructsToJson with options is not supported")
None
} else {
-
-      def isSupportedType(dt: DataType): Boolean = {
-        dt match {
-          case StructType(fields) =>
-            fields.forall(f => isSupportedType(f.dataType))
-          case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
-              DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
-              DataTypes.DoubleType | DataTypes.StringType =>
-            true
-          case DataTypes.DateType | DataTypes.TimestampType =>
-            // TODO implement these types with tests for formatting options and timezone
-            false
-          case _: MapType | _: ArrayType =>
-            // Spark supports map and array in StructsToJson but this is not yet
-            // implemented in Comet
-            false
-          case _ => false
-        }
-      }
-
val isSupported = expr.child.dataType match {
case s: StructType =>
s.fields.forall(f => isSupportedType(f.dataType))
@@ -166,6 +146,25 @@ object CometStructsToJson extends CometExpressionSerde[StructsToJson] {
}
}
}
+
+  def isSupportedType(dt: DataType): Boolean = {
+    dt match {
+      case StructType(fields) =>
+        fields.forall(f => isSupportedType(f.dataType))
+      case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
+          DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
+          DataTypes.DoubleType | DataTypes.StringType =>
+        true
+      case DataTypes.DateType | DataTypes.TimestampType =>
+        // TODO implement these types with tests for formatting options and timezone
+        false
+      case _: MapType | _: ArrayType =>
+        // Spark supports map and array in StructsToJson but this is not yet
+        // implemented in Comet
+        false
+      case _ => false
+    }
+  }
}

object CometJsonToStructs extends CometExpressionSerde[JsonToStructs] {
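
Note: hoisting `isSupportedType` from a local `def` inside the conversion logic to a public member of `CometStructsToJson` is what lets the test suite further down reuse it to filter columns. A minimal sketch of that call pattern, with a hypothetical schema (the field names and types here are illustrative only):

```scala
import org.apache.spark.sql.types._

import org.apache.comet.serde.CometStructsToJson

// Hypothetical schema mixing supported and not-yet-supported types.
val schema = StructType(
  Seq(
    StructField("id", DataTypes.IntegerType),
    StructField("name", DataTypes.StringType),
    StructField("tags", ArrayType(DataTypes.StringType)), // arrays: not yet implemented
    StructField("ts", DataTypes.TimestampType))) // timestamps: TODO per the comment above

// Keep only the fields Comet's to_json can serialize.
val supported = schema.fields.filter(f => CometStructsToJson.isSupportedType(f.dataType))
// supported now holds "id" and "name" only.
```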
@@ -229,8 +229,8 @@ object FuzzDataGenerator {
Range(0, numRows).map(_ => {
r.nextInt(20) match {
case 0 if options.allowNull => null
-          case 1 => Float.NegativeInfinity
-          case 2 => Float.PositiveInfinity
+          case 1 if options.generateInfinity => Float.NegativeInfinity
+          case 2 if options.generateInfinity => Float.PositiveInfinity
case 3 => Float.MinValue
case 4 => Float.MaxValue
case 5 => 0.0f
@@ -243,8 +243,8 @@ object FuzzDataGenerator {
Range(0, numRows).map(_ => {
r.nextInt(20) match {
case 0 if options.allowNull => null
-          case 1 => Double.NegativeInfinity
-          case 2 => Double.PositiveInfinity
+          case 1 if options.generateInfinity => Double.NegativeInfinity
+          case 2 if options.generateInfinity => Double.PositiveInfinity
case 3 => Double.MinValue
case 4 => Double.MaxValue
case 5 => 0.0
@@ -329,4 +329,5 @@ case class DataGenOptions(
generateNaN: Boolean = true,
baseDate: Long = FuzzDataGenerator.defaultBaseDate,
customStrings: Seq[String] = Seq.empty,
-    maxStringLength: Int = 8)
+    maxStringLength: Int = 8,
+    generateInfinity: Boolean = true)
Review comment from @kazantsev-maksim (Contributor Author), Dec 29, 2025:

    Comet to_json does not support NaN, +Infinity, -Infinity values for numeric types.

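The new `generateInfinity` guards lean on Scala's pattern-match semantics: a `case` whose guard evaluates to false is simply skipped and matching proceeds to later alternatives, so with the flag off, draws of 1 or 2 fall through to whatever default arm the generator has (elided in this diff). A self-contained sketch of the mechanism — `Options`, `genFloat`, and the default arm are illustrative stand-ins, not the generator's actual code:

```scala
import scala.util.Random

// Illustrative stand-in for DataGenOptions, reduced to the one flag used here.
case class Options(generateInfinity: Boolean = true)

def genFloat(r: Random, options: Options): Float =
  r.nextInt(20) match {
    case 1 if options.generateInfinity => Float.NegativeInfinity
    case 2 if options.generateInfinity => Float.PositiveInfinity
    // A failed guard is just a non-match: draws of 1 or 2 land here instead.
    case _ => r.nextFloat()
  }

val r = new Random(42)
val sample = Range(0, 1000).map(_ => genFloat(r, Options(generateInfinity = false)))
assert(sample.forall(f => !f.isInfinity))
```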
@@ -19,24 +19,59 @@

package org.apache.comet

+import scala.util.Random
+
import org.scalactic.source.Position
import org.scalatest.Tag

+import org.apache.hadoop.fs.Path
import org.apache.spark.sql.CometTestBase
-import org.apache.spark.sql.catalyst.expressions.JsonToStructs
+import org.apache.spark.sql.catalyst.expressions.{JsonToStructs, StructsToJson}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions._

+import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
+import org.apache.comet.serde.CometStructsToJson
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}

class CometJsonExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {

override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
pos: Position): Unit = {
super.test(testName, testTags: _*) {
-      withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[JsonToStructs]) -> "true") {
+      withSQLConf(
+        CometConf.getExprAllowIncompatConfigKey(classOf[JsonToStructs]) -> "true",
+        CometConf.getExprAllowIncompatConfigKey(classOf[StructsToJson]) -> "true") {
testFun
}
}
}
+
+  test("to_json - all supported types") {
+    assume(!isSpark40Plus)
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false),
+          DataGenOptions(generateNaN = false, generateInfinity = false))
+      }
+      val table = spark.read.parquet(filename)
+      val fieldsNames = table.schema.fields
+        .filter(sf => CometStructsToJson.isSupportedType(sf.dataType))
+        .map(sf => col(sf.name))
+        .toSeq
+      val df = table.select(to_json(struct(fieldsNames: _*)))
+      checkSparkAnswerAndOperator(df)
+    }
+  }
+
test("from_json - basic primitives") {
Seq(true, false).foreach { dictionaryEnabled =>
withParquetTable(
@@ -19,7 +19,7 @@

package org.apache.spark.sql.benchmark

-import org.apache.spark.sql.catalyst.expressions.JsonToStructs
+import org.apache.spark.sql.catalyst.expressions.{JsonToStructs, StructsToJson}

import org.apache.comet.CometConf

@@ -106,6 +106,44 @@ object CometJsonExpressionBenchmark extends CometBenchmarkBase {
FROM $tbl
""")

+      case "to_json - simple primitives" =>
+        spark.sql(
+          s"""SELECT named_struct("a", CAST(value AS INT), "b", concat("str_", CAST(value AS STRING))) AS json_struct FROM $tbl""")
+
+      case "to_json - all primitive types" =>
+        spark.sql(s"""
+          SELECT named_struct(
+            "i32", CAST(value % 1000 AS INT),
+            "i64", CAST(value * 1000000000L AS LONG),
+            "f32", CAST(value * 1.5 AS FLOAT),
+            "f64", CAST(value * 2.5 AS DOUBLE),
+            "bool", CASE WHEN value % 2 = 0 THEN true ELSE false END,
+            "str", concat("value_", CAST(value AS STRING))
+          ) AS json_struct FROM $tbl
+        """)
+
+      case "to_json - with nulls" =>
+        spark.sql(s"""
+          SELECT
+            CASE
+              WHEN value % 10 = 0 THEN CAST(NULL AS STRUCT<a: INT, b: STRING>)
+              WHEN value % 5 = 0 THEN named_struct("a", CAST(NULL AS INT), "b", "test")
+              WHEN value % 3 = 0 THEN named_struct("a", CAST(123 AS INT), "b", CAST(NULL AS STRING))
+              ELSE named_struct("a", CAST(value AS INT), "b", concat("str_", CAST(value AS STRING)))
+            END AS json_struct
+          FROM $tbl
+        """)
+
+      case "to_json - nested struct" =>
+        spark.sql(s"""
+          SELECT named_struct(
+            "outer", named_struct(
+              "inner_a", CAST(value AS INT),
+              "inner_b", concat("nested_", CAST(value AS STRING))
+            )
+          ) AS json_struct FROM $tbl
+        """)
+
case _ =>
spark.sql(s"""
SELECT
@@ -117,8 +155,9 @@ object CometJsonExpressionBenchmark extends CometBenchmarkBase {
prepareTable(dir, jsonData)

val extraConfigs = Map(
+      CometConf.getExprAllowIncompatConfigKey(classOf[JsonToStructs]) -> "true",
       CometConf.getExprAllowIncompatConfigKey(
-        classOf[JsonToStructs]) -> "true") ++ config.extraCometConfigs
+        classOf[StructsToJson]) -> "true") ++ config.extraCometConfigs

runExpressionBenchmark(config.name, values, config.query, extraConfigs)
}
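
Both the suite and the benchmark opt in via `CometConf.getExprAllowIncompatConfigKey(classOf[...])`, which derives the per-expression "allow incompatible" config key from the expression class. For reference, a sketch of setting such keys directly when building a session — the literal `spark.comet.expression.<Name>.allowIncompatible` strings below are an assumed rendering of what that helper returns, so prefer the helper in real code:

```scala
import org.apache.spark.sql.SparkSession

// Assumption: keys follow spark.comet.expression.<Name>.allowIncompatible;
// in real code, derive them with CometConf.getExprAllowIncompatConfigKey.
val spark = SparkSession
  .builder()
  .master("local[*]")
  .appName("to_json benchmark sketch")
  .config("spark.comet.expression.JsonToStructs.allowIncompatible", "true")
  .config("spark.comet.expression.StructsToJson.allowIncompatible", "true")
  .getOrCreate()
```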
@@ -127,6 +166,7 @@

// Configuration for all JSON expression benchmarks
private val jsonExpressions = List(
+    // from_json tests
JsonExprConfig(
"from_json - simple primitives",
"a INT, b STRING",
@@ -146,7 +186,25 @@
    JsonExprConfig(
      "from_json - field access",
      "a INT, b STRING",
-      "SELECT from_json(json_str, 'a INT, b STRING').a FROM parquetV1Table"))
+      "SELECT from_json(json_str, 'a INT, b STRING').a FROM parquetV1Table"),
+
+    // to_json tests
+    JsonExprConfig(
+      "to_json - simple primitives",
+      "a INT, b STRING",
+      "SELECT to_json(json_struct) FROM parquetV1Table"),
+    JsonExprConfig(
+      "to_json - all primitive types",
+      "i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE, bool BOOLEAN, str STRING",
+      "SELECT to_json(json_struct) FROM parquetV1Table"),
+    JsonExprConfig(
+      "to_json - with nulls",
+      "a INT, b STRING",
+      "SELECT to_json(json_struct) FROM parquetV1Table"),
+    JsonExprConfig(
+      "to_json - nested struct",
+      "outer STRUCT<inner_a: INT, inner_b: STRING>",
+      "SELECT to_json(json_struct) FROM parquetV1Table"))

override def runCometBenchmark(mainArgs: Array[String]): Unit = {
val values = 1024 * 1024