diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index beb5f9dcf7..17f66b6630 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -167,6 +167,7 @@ jobs:
             org.apache.comet.CometStringExpressionSuite
             org.apache.comet.CometBitwiseExpressionSuite
             org.apache.comet.CometMapExpressionSuite
+            org.apache.comet.CometCsvExpressionSuite
             org.apache.comet.CometJsonExpressionSuite
             org.apache.comet.expressions.conditional.CometIfSuite
             org.apache.comet.expressions.conditional.CometCoalesceSuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 9a45fe022d..80e8854ef6 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -131,6 +131,7 @@ jobs:
             org.apache.comet.CometBitwiseExpressionSuite
             org.apache.comet.CometMapExpressionSuite
             org.apache.comet.CometJsonExpressionSuite
+            org.apache.comet.CometCsvExpressionSuite
             org.apache.comet.expressions.conditional.CometIfSuite
             org.apache.comet.expressions.conditional.CometCoalesceSuite
             org.apache.comet.expressions.conditional.CometCaseWhenSuite
diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md
index 1a273ad033..bd062ec587 100644
--- a/docs/source/user-guide/latest/configs.md
+++ b/docs/source/user-guide/latest/configs.md
@@ -324,6 +324,7 @@ These settings can be used to determine which parts of the plan are accelerated
 | `spark.comet.expression.StringTrimBoth.enabled` | Enable Comet acceleration for `StringTrimBoth` | true |
 | `spark.comet.expression.StringTrimLeft.enabled` | Enable Comet acceleration for `StringTrimLeft` | true |
 | `spark.comet.expression.StringTrimRight.enabled` | Enable Comet acceleration for `StringTrimRight` | true |
+| `spark.comet.expression.StructsToCsv.enabled` | Enable Comet acceleration for `StructsToCsv` | true |
 | `spark.comet.expression.StructsToJson.enabled` | Enable Comet acceleration for `StructsToJson` | true |
 | `spark.comet.expression.Substring.enabled` | Enable Comet acceleration for `Substring` | true |
 | `spark.comet.expression.Subtract.enabled` | Enable Comet acceleration for `Subtract` | true |
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 93fbb59c11..8250981f94 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -70,8 +70,8 @@ use datafusion::{
 };
 use datafusion_comet_spark_expr::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
-    BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond,
-    SumInteger,
+    BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkHour, SparkMinute,
+    SparkSecond, SumInteger, ToCsv,
 };
 use iceberg::expr::Bind;
@@ -644,6 +644,25 @@ impl PhysicalPlanner {
             ExprStruct::MonotonicallyIncreasingId(_) => Ok(Arc::new(
                 MonotonicallyIncreasingId::from_partition_id(self.partition),
             )),
+            ExprStruct::ToCsv(expr) => {
+                let csv_struct_expr =
+                    self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let options = expr.options.clone().unwrap();
+                let csv_write_options = CsvWriteOptions::new(
+                    options.delimiter,
+                    options.quote,
+                    options.escape,
+                    options.null_value,
+                    options.quote_all,
+                    options.ignore_leading_white_space,
+                    options.ignore_trailing_white_space,
+                );
+                Ok(Arc::new(ToCsv::new(
+                    csv_struct_expr,
+                    &options.timezone,
+                    csv_write_options,
+                )))
+            }
             expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
         }
     }
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 5f258fd677..5f4b3157d2 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -86,6 +86,7 @@ message Expr {
     EmptyExpr spark_partition_id = 63;
     EmptyExpr monotonically_increasing_id = 64;
     FromJson from_json = 89;
+    ToCsv to_csv = 90;
   }
 }
@@ -275,6 +276,22 @@ message FromJson {
   string timezone = 3;
 }
 
+message ToCsv {
+  Expr child = 1;
+  CsvWriteOptions options = 2;
+}
+
+message CsvWriteOptions {
+  string delimiter = 1;
+  string quote = 2;
+  string escape = 3;
+  string null_value = 4;
+  bool quote_all = 5;
+  bool ignore_leading_white_space = 6;
+  bool ignore_trailing_white_space = 7;
+  string timezone = 8;
+}
+
 enum BinaryOutputStyle {
   UTF8 = 0;
   BASIC = 1;
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 94653d8864..fd0a211b29 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -88,6 +88,10 @@ harness = false
 name = "normalize_nan"
 harness = false
 
+[[bench]]
+name = "to_csv"
+harness = false
+
 [[test]]
 name = "test_udf_registration"
 path = "tests/spark_expr_reg.rs"
diff --git a/native/spark-expr/benches/to_csv.rs b/native/spark-expr/benches/to_csv.rs
new file mode 100644
index 0000000000..8620dd0f16
--- /dev/null
+++ b/native/spark-expr/benches/to_csv.rs
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
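+
+//! Criterion micro-benchmark for the native `to_csv` kernel: builds an
+//! 8192-row struct array with six nullable columns (every tenth row is null)
+//! and times `to_csv_inner` with Spark's default CSV write options.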
+
+use arrow::array::{
+    BooleanBuilder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, StringBuilder,
+    StructArray, StructBuilder,
+};
+use arrow::datatypes::{DataType, Field};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion_comet_spark_expr::{to_csv_inner, CsvWriteOptions, EvalMode, SparkCastOptions};
+use std::hint::black_box;
+
+fn create_struct_array(array_size: usize) -> StructArray {
+    let fields = vec![
+        Field::new("f1", DataType::Boolean, true),
+        Field::new("f2", DataType::Int8, true),
+        Field::new("f3", DataType::Int16, true),
+        Field::new("f4", DataType::Int32, true),
+        Field::new("f5", DataType::Int64, true),
+        Field::new("f6", DataType::Utf8, true),
+    ];
+    let mut struct_builder = StructBuilder::from_fields(fields, array_size);
+    for i in 0..array_size {
+        struct_builder
+            .field_builder::<BooleanBuilder>(0)
+            .unwrap()
+            .append_option(if i % 10 == 0 { None } else { Some(i % 2 == 0) });
+
+        struct_builder
+            .field_builder::<Int8Builder>(1)
+            .unwrap()
+            .append_option(if i % 10 == 0 {
+                None
+            } else {
+                Some((i % 128) as i8)
+            });
+
+        struct_builder
+            .field_builder::<Int16Builder>(2)
+            .unwrap()
+            .append_option(if i % 10 == 0 { None } else { Some(i as i16) });
+
+        struct_builder
+            .field_builder::<Int32Builder>(3)
+            .unwrap()
+            .append_option(if i % 10 == 0 { None } else { Some(i as i32) });
+
+        struct_builder
+            .field_builder::<Int64Builder>(4)
+            .unwrap()
+            .append_option(if i % 10 == 0 { None } else { Some(i as i64) });
+
+        struct_builder
+            .field_builder::<StringBuilder>(5)
+            .unwrap()
+            .append_option(if i % 10 == 0 {
+                None
+            } else {
+                Some(format!("string_{}", i))
+            });
+
+        struct_builder.append(true);
+    }
+    struct_builder.finish()
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let array_size = 8192;
+    let timezone = "UTC";
+    let struct_array = create_struct_array(array_size);
+    let default_delimiter = ",";
+    let default_null_value = "";
+    let default_quote = "\"";
+    let default_escape = "\\";
+    let mut cast_options = SparkCastOptions::new(EvalMode::Legacy, timezone, false);
+    cast_options.null_string = default_null_value.to_string();
+    let csv_write_options = CsvWriteOptions::new(
+        default_delimiter.to_string(),
+        default_quote.to_string(),
+        default_escape.to_string(),
+        default_null_value.to_string(),
+        false,
+        true,
+        true,
+    );
+    c.bench_function("to_csv", |b| {
+        b.iter(|| {
+            black_box(to_csv_inner(&struct_array, &cast_options, &csv_write_options).unwrap())
+        })
+    });
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/native/spark-expr/src/csv_funcs/csv_write_options.rs b/native/spark-expr/src/csv_funcs/csv_write_options.rs
new file mode 100644
index 0000000000..4d221745ba
--- /dev/null
+++ b/native/spark-expr/src/csv_funcs/csv_write_options.rs
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
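+
+//! Options controlling how the native `to_csv` kernel formats values,
+//! mirroring the subset of Spark's CSV data source options that it honors.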
+
+use std::fmt::{Display, Formatter};
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+pub struct CsvWriteOptions {
+    pub delimiter: String,
+    pub quote: String,
+    pub escape: String,
+    pub null_value: String,
+    pub quote_all: bool,
+    pub ignore_leading_white_space: bool,
+    pub ignore_trailing_white_space: bool,
+}
+
+impl Display for CsvWriteOptions {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "csv_write_options(delimiter={}, quote={}, escape={}, null_value={}, quote_all={}, ignore_leading_white_space={}, ignore_trailing_white_space={})",
+            self.delimiter, self.quote, self.escape, self.null_value, self.quote_all, self.ignore_leading_white_space, self.ignore_trailing_white_space
+        )
+    }
+}
+
+impl CsvWriteOptions {
+    pub fn new(
+        delimiter: String,
+        quote: String,
+        escape: String,
+        null_value: String,
+        quote_all: bool,
+        ignore_leading_white_space: bool,
+        ignore_trailing_white_space: bool,
+    ) -> Self {
+        Self {
+            delimiter,
+            quote,
+            escape,
+            null_value,
+            quote_all,
+            ignore_leading_white_space,
+            ignore_trailing_white_space,
+        }
+    }
+}
diff --git a/native/spark-expr/src/csv_funcs/mod.rs b/native/spark-expr/src/csv_funcs/mod.rs
new file mode 100644
index 0000000000..9c417f9ebb
--- /dev/null
+++ b/native/spark-expr/src/csv_funcs/mod.rs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod csv_write_options;
+mod to_csv;
+
+pub use csv_write_options::CsvWriteOptions;
+pub use to_csv::{to_csv_inner, ToCsv};
diff --git a/native/spark-expr/src/csv_funcs/to_csv.rs b/native/spark-expr/src/csv_funcs/to_csv.rs
new file mode 100644
index 0000000000..324530e3fe
--- /dev/null
+++ b/native/spark-expr/src/csv_funcs/to_csv.rs
@@ -0,0 +1,208 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
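+
+//! Native implementation of Spark's `to_csv` expression: each field of the
+//! input struct is cast to UTF-8 with Spark cast semantics, then every row is
+//! joined with the configured delimiter. A string value is quoted when it
+//! contains the delimiter or quote character, is empty, or when `quoteAll` is
+//! set; quote and escape characters inside quoted values are escaped.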
+
+use crate::csv_funcs::csv_write_options::CsvWriteOptions;
+use crate::{spark_cast, EvalMode, SparkCastOptions};
+use arrow::array::{as_string_array, as_struct_array, Array, ArrayRef, StringArray, StringBuilder};
+use arrow::array::{RecordBatch, StructArray};
+use arrow::datatypes::{DataType, Schema};
+use datafusion::common::Result;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::fmt::{Display, Formatter};
+use std::hash::Hash;
+use std::sync::Arc;
+
+/// Spark-compatible `to_csv` function
+#[derive(Debug, Eq)]
+pub struct ToCsv {
+    expr: Arc<dyn PhysicalExpr>,
+    timezone: String,
+    csv_write_options: CsvWriteOptions,
+}
+
+impl Hash for ToCsv {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.expr.hash(state);
+        self.timezone.hash(state);
+        self.csv_write_options.hash(state);
+    }
+}
+
+impl PartialEq for ToCsv {
+    fn eq(&self, other: &Self) -> bool {
+        self.expr.eq(&other.expr)
+            && self.timezone.eq(&other.timezone)
+            && self.csv_write_options.eq(&other.csv_write_options)
+    }
+}
+
+impl ToCsv {
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        timezone: &str,
+        csv_write_options: CsvWriteOptions,
+    ) -> Self {
+        Self {
+            expr,
+            timezone: timezone.to_owned(),
+            csv_write_options,
+        }
+    }
+}
+
+impl Display for ToCsv {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "to_csv({}, timezone={}, csv_write_options={})",
+            self.expr, self.timezone, self.csv_write_options
+        )
+    }
+}
+
+impl PhysicalExpr for ToCsv {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _: &Schema) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        self.expr.nullable(input_schema)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let input_array = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
+        let mut cast_options = SparkCastOptions::new(EvalMode::Legacy, &self.timezone, false);
+        cast_options.null_string = self.csv_write_options.null_value.clone();
+        let struct_array = as_struct_array(&input_array);
+
+        let csv_array = to_csv_inner(struct_array, &cast_options, &self.csv_write_options)?;
+
+        Ok(ColumnarValue::Array(csv_array))
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.expr]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(Self::new(
+            Arc::clone(&children[0]),
+            &self.timezone,
+            self.csv_write_options.clone(),
+        )))
+    }
+
+    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
+        unimplemented!()
+    }
+}
+
+pub fn to_csv_inner(
+    array: &StructArray,
+    cast_options: &SparkCastOptions,
+    write_options: &CsvWriteOptions,
+) -> Result<ArrayRef> {
+    // Cast every struct field to UTF-8 using Spark cast semantics
+    let string_arrays: Vec<ArrayRef> = array
+        .columns()
+        .iter()
+        .map(|array| {
+            spark_cast(
+                ColumnarValue::Array(Arc::clone(array)),
+                &DataType::Utf8,
+                cast_options,
+            )?
+            .into_array(array.len())
+        })
+        .collect::<Result<Vec<_>>>()?;
+    let string_arrays: Vec<&StringArray> = string_arrays
+        .iter()
+        .map(|array| as_string_array(array))
+        .collect();
+    let is_string: Vec<bool> = array
+        .fields()
+        .iter()
+        .map(|f| matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8))
+        .collect();
+
+    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
+    let mut csv_string = String::with_capacity(array.len() * 16);
+
+    let quote = write_options.quote.as_ref();
+    for row_idx in 0..array.len() {
+        if array.is_null(row_idx) {
+            builder.append_null();
+        } else {
+            csv_string.clear();
+            for (col_idx, column) in string_arrays.iter().enumerate() {
+                if col_idx > 0 {
+                    csv_string.push_str(&write_options.delimiter);
+                }
+                let mut value = column.value(row_idx);
+                let is_string_field = is_string[col_idx];
+                if is_string_field {
+                    if write_options.ignore_leading_white_space {
+                        value = value.trim_start();
+                    }
+                    if write_options.ignore_trailing_white_space {
+                        value = value.trim_end();
+                    }
+                }
+                let needs_quoting = write_options.quote_all
+                    || (is_string_field
+                        && !string_arrays[col_idx].is_null(row_idx)
+                        && (value.contains(&write_options.delimiter)
+                            || value.contains(quote)
+                            || value.is_empty()));
+
+                let needs_escaping = is_string_field && needs_quoting;
+                if needs_quoting {
+                    csv_string.push_str(quote);
+                }
+                if needs_escaping {
+                    escape_value(value, quote, &write_options.escape, &mut csv_string);
+                } else {
+                    csv_string.push_str(value);
+                }
+                if needs_quoting {
+                    csv_string.push_str(quote);
+                }
+            }
+            builder.append_value(&csv_string);
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+#[inline]
+fn escape_value(value: &str, quote: &str, escape: &str, output: &mut String) {
+    // Spark's quote and escape options are single characters, so comparing
+    // against the first char avoids allocating a String per input character.
+    let quote_ch = quote.chars().next();
+    let escape_ch = escape.chars().next();
+    for ch in value.chars() {
+        if Some(ch) == quote_ch || Some(ch) == escape_ch {
+            output.push_str(escape);
+        }
+        output.push(ch);
+    }
+}
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index f26fd911d8..d770338eaa 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -56,6 +56,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain};
 
 mod conditional_funcs;
 mod conversion_funcs;
+mod csv_funcs;
 mod math_funcs;
 mod nondetermenistic_funcs;
@@ -69,6 +70,7 @@ pub use comet_scalar_funcs::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
     register_all_comet_functions,
 };
+pub use csv_funcs::*;
 pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
 pub use error::{SparkError, SparkResult};
 pub use hash_funcs::*;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e50b1d80e6..47c96d10cf 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -133,7 +133,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[GetArrayStructFields] -> CometGetArrayStructFields,
     classOf[GetStructField] -> CometGetStructField,
     classOf[JsonToStructs] -> CometJsonToStructs,
-    classOf[StructsToJson] -> CometStructsToJson)
+    classOf[StructsToJson] -> CometStructsToJson,
+    classOf[StructsToCsv] -> CometStructsToCsv)
 
   private val hashExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[Md5] -> CometScalarFunction("md5"),
diff --git a/spark/src/main/scala/org/apache/comet/serde/structs.scala b/spark/src/main/scala/org/apache/comet/serde/structs.scala
index b76c64bac9..753194988d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/structs.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/structs.scala
@@ -20,11 +20,13 @@ package org.apache.comet.serde
 
 import scala.jdk.CollectionConverters._
+import scala.util.Try
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetArrayStructFields, GetStructField, JsonToStructs, StructsToJson}
-import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetArrayStructFields, GetStructField, JsonToStructs, StructsToCsv, StructsToJson}
+import org.apache.spark.sql.types._
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.DataTypeSupport
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType}
 
 object CometCreateNamedStruct extends CometExpressionSerde[CreateNamedStruct] {
@@ -230,3 +232,68 @@ object CometJsonToStructs extends CometExpressionSerde[JsonToStructs] {
     }
   }
 }
+
+object CometStructsToCsv extends CometExpressionSerde[StructsToCsv] {
+
+  private val incompatibleDataTypes = Seq(DateType, TimestampType, TimestampNTZType, BinaryType)
+
+  override def getSupportLevel(expr: StructsToCsv): SupportLevel = {
+    val dataTypes = expr.inputSchema.fields.map(_.dataType)
+    val containsComplexType = dataTypes.exists(DataTypeSupport.isComplexType)
+    if (containsComplexType) {
+      return Unsupported(
+        Some(
+          s"The schema ${expr.inputSchema} is not supported because it includes a complex type"))
+    }
+    val containsIncompatibleDataTypes = dataTypes.exists(incompatibleDataTypes.contains)
+    if (containsIncompatibleDataTypes) {
+      return Incompatible(
+        Some(
+          s"The schema ${expr.inputSchema} is not supported because " +
+            s"it includes incompatible data types: $incompatibleDataTypes"))
+    }
+    Compatible()
+  }
+
+  override def convert(
+      expr: StructsToCsv,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    for {
+      childProto <- exprToProtoInternal(expr.child, inputs, binding)
+    } yield {
+      val optionsProto = options2Proto(expr.options, expr.timeZoneId)
+      val toCsv = ExprOuterClass.ToCsv
+        .newBuilder()
+        .setChild(childProto)
+        .setOptions(optionsProto)
+        .build()
+      ExprOuterClass.Expr.newBuilder().setToCsv(toCsv).build()
+    }
+  }
+
+  private def options2Proto(
+      options: Map[String, String],
+      timeZoneId: Option[String]): ExprOuterClass.CsvWriteOptions = {
+    // Defaults below mirror Spark's CSV write options
+    ExprOuterClass.CsvWriteOptions
+      .newBuilder()
+      .setDelimiter(options.getOrElse("delimiter", ","))
+      .setQuote(options.getOrElse("quote", "\""))
+      .setEscape(options.getOrElse("escape", "\\"))
+      .setNullValue(options.getOrElse("nullValue", ""))
+      .setTimezone(timeZoneId.getOrElse("UTC"))
+      .setIgnoreLeadingWhiteSpace(options
+        .get("ignoreLeadingWhiteSpace")
+        .flatMap(ignoreLeadingWhiteSpace => Try(ignoreLeadingWhiteSpace.toBoolean).toOption)
+        .getOrElse(true))
+      .setIgnoreTrailingWhiteSpace(options
+        .get("ignoreTrailingWhiteSpace")
+        .flatMap(ignoreTrailingWhiteSpace => Try(ignoreTrailingWhiteSpace.toBoolean).toOption)
+        .getOrElse(true))
+      .setQuoteAll(options
+        .get("quoteAll")
+        .flatMap(quoteAll => Try(quoteAll.toBoolean).toOption)
+        .getOrElse(false))
+      .build()
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometCsvExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometCsvExpressionSuite.scala
new file mode 100644
index 0000000000..dcf1a05953
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometCsvExpressionSuite.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import scala.jdk.CollectionConverters._
+import scala.util.Random
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.catalyst.expressions.StructsToCsv
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions._
+
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
+
+class CometCsvExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
+  test("to_csv - default options") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false),
+          DataGenOptions(allowNull = true, generateNegativeZero = true))
+      }
+      withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[StructsToCsv]) -> "true") {
+        val df = spark.read
+          .parquet(filename)
+          .select(
+            to_csv(
+              struct(
+                col("c0"),
+                col("c1"),
+                col("c2"),
+                col("c3"),
+                col("c4"),
+                col("c5"),
+                col("c7"),
+                col("c8"),
+                col("c9"),
+                col("c12"))))
+        checkSparkAnswerAndOperator(df)
+      }
+    }
+  }
+
+  test("to_csv - with configurable formatting options") {
+    val table = "t1"
+    withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_ICEBERG_COMPAT) {
+      withTable(table) {
+        sql(s"create table $table(col string) using parquet")
+        sql(s"insert into $table values('')")
+        sql(s"insert into $table values(cast(null as string))")
+        sql(s"insert into $table values(' abc')")
+        sql(s"insert into $table values('abc ')")
+        sql(s"insert into $table values(' abc ')")
+        sql(s"""insert into $table values('abc \"abc\"')""")
+        val df = sql(s"select * from $table")
+        checkSparkAnswerAndOperator(df.select(to_csv(struct(col("col"), lit(1)))))
+        checkSparkAnswerAndOperator(
+          df.select(
+            to_csv(
+              struct(col("col"), lit(1)),
+              Map(
+                "delimiter" -> ";",
+                "ignoreLeadingWhiteSpace" -> "false",
+                "ignoreTrailingWhiteSpace" -> "false").asJava)))
+        checkSparkAnswerAndOperator(
+          df.select(to_csv(struct(col("col"), lit(1)), Map("quoteAll" -> "true").asJava)))
+      }
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCsvExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCsvExpressionBenchmark.scala
new file mode 100644
index 0000000000..94288eb9cb
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCsvExpressionBenchmark.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.benchmark
+
+import org.apache.spark.sql.catalyst.expressions.StructsToCsv
+
+import org.apache.comet.CometConf
+
+/**
+ * Configuration for a CSV expression benchmark.
+ *
+ * @param name
+ *   Name for the benchmark
+ * @param query
+ *   SQL query to benchmark
+ * @param extraCometConfigs
+ *   Additional Comet configurations for the scan+exec case
+ */
+case class CsvExprConfig(
+    name: String,
+    query: String,
+    extraCometConfigs: Map[String, String] = Map.empty)
+
+// spotless:off
+/**
+ * Benchmark to measure performance of Comet CSV expressions. To run this benchmark:
+ * `SPARK_GENERATE_BENCHMARK_FILES=1 make
+ * benchmark-org.apache.spark.sql.benchmark.CometCsvExpressionBenchmark` Results will be written
+ * to "spark/benchmarks/CometCsvExpressionBenchmark-**results.txt".
+ */
+// spotless:on
+object CometCsvExpressionBenchmark extends CometBenchmarkBase {
+
+  /**
+   * Generic method to run a CSV expression benchmark with the given configuration.
+   */
+  def runCsvExprBenchmark(config: CsvExprConfig, values: Int): Unit = {
+    withTempPath { dir =>
+      withTempTable("parquetV1Table") {
+        prepareTable(
+          dir,
+          spark.sql(
+            s"SELECT CAST(value AS STRING) AS c1, CAST(value AS INT) AS c2, CAST(value AS LONG) AS c3 FROM $tbl"))
+
+        val extraConfigs = Map(
+          CometConf.getExprAllowIncompatConfigKey(
+            classOf[StructsToCsv]) -> "true") ++ config.extraCometConfigs
+
+        runExpressionBenchmark(config.name, values, config.query, extraConfigs)
+      }
+    }
+  }
+
+  // Configuration for all CSV expression benchmarks
+  private val csvExpressions = List(
+    CsvExprConfig("to_csv", "SELECT to_csv(struct(c1, c2, c3)) FROM parquetV1Table"))
+
+  override def runCometBenchmark(args: Array[String]): Unit = {
+    val values = 1024 * 1024
+
+    csvExpressions.foreach { config =>
+      runBenchmarkWithTable(config.name, values) { value =>
+        runCsvExprBenchmark(config, value)
+      }
+    }
+  }
+}
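+
+// Example of the expression under test; in Spark SQL, `to_csv` also accepts an
+// options map (a sketch for illustration, not part of the benchmark above):
+//   SELECT to_csv(struct(c1, c2, c3), map('delimiter', ';', 'quoteAll', 'true'))
+//   FROM parquetV1Table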