
Commit d7a6036

Author: Kazantsev Maksim
Message: WIP
Parent: 0f98a3c

File tree

5 files changed: +38, -27 lines

docs/source/user-guide/latest/configs.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -324,7 +324,7 @@ These settings can be used to determine which parts of the plan are accelerated
 | `spark.comet.expression.StringTrimBoth.enabled` | Enable Comet acceleration for `StringTrimBoth` | true |
 | `spark.comet.expression.StringTrimLeft.enabled` | Enable Comet acceleration for `StringTrimLeft` | true |
 | `spark.comet.expression.StringTrimRight.enabled` | Enable Comet acceleration for `StringTrimRight` | true |
-Add| `spark.comet.expression.StructsToCsv.enabled` | Enable Comet acceleration for `StructsToCsv` | true |
+| `spark.comet.expression.StructsToCsv.enabled` | Enable Comet acceleration for `StructsToCsv` | true |
 | `spark.comet.expression.StructsToJson.enabled` | Enable Comet acceleration for `StructsToJson` | true |
 | `spark.comet.expression.Substring.enabled` | Enable Comet acceleration for `Substring` | true |
 | `spark.comet.expression.Subtract.enabled` | Enable Comet acceleration for `Subtract` | true |
```

native/spark-expr/src/csv_funcs/to_csv.rs

Lines changed: 26 additions & 20 deletions
```diff
@@ -16,11 +16,11 @@
 // under the License.
 
 use arrow::array::{
-    Array, ArrayRef, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray,
-    StringArray, StringBuilder,
+    as_boolean_array, as_largestring_array, as_string_array, Array, ArrayRef, StringBuilder,
 };
 use arrow::array::{RecordBatch, StructArray};
 use arrow::datatypes::{DataType, Schema};
+use datafusion::common::cast::{as_int16_array, as_int32_array, as_int64_array, as_int8_array};
 use datafusion::common::{exec_err, Result};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr::PhysicalExpr;
@@ -96,6 +96,10 @@ impl PhysicalExpr for ToCsv {
         Ok(DataType::Utf8)
     }
 
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        self.expr.nullable(input_schema)
+    }
+
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
         let input_value = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
 
@@ -134,6 +138,7 @@ impl PhysicalExpr for ToCsv {
 fn struct_to_csv(array: &StructArray, delimiter: &str, null_value: &str) -> Result<ArrayRef> {
     let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
     let mut csv_string = String::with_capacity(array.len() * 16);
+
     for row_idx in 0..array.len() {
         if array.is_null(row_idx) {
             builder.append_null();
@@ -146,8 +151,7 @@ fn struct_to_csv(array: &StructArray, delimiter: &str, null_value: &str) -> Result<ArrayRef> {
             if column.is_null(row_idx) {
                 csv_string.push_str(null_value);
             } else {
-                let value = convert_to_string(column, row_idx)?;
-                csv_string.push_str(&value);
+                convert_to_string(column, &mut csv_string, row_idx)?;
             }
         }
     }
@@ -156,38 +160,40 @@ fn struct_to_csv(array: &StructArray, delimiter: &str, null_value: &str) -> Result<ArrayRef> {
     Ok(Arc::new(builder.finish()))
 }
 
-fn convert_to_string(array: &ArrayRef, row_idx: usize) -> Result<String> {
+#[inline]
+fn convert_to_string(array: &ArrayRef, csv_string: &mut String, row_idx: usize) -> Result<()> {
     match array.data_type() {
         DataType::Boolean => {
-            let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
-            Ok(array.value(row_idx).to_string())
+            let array = as_boolean_array(array);
+            csv_string.push_str(&array.value(row_idx).to_string())
         }
         DataType::Int8 => {
-            let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
-            Ok(array.value(row_idx).to_string())
+            let array = as_int8_array(array)?;
+            csv_string.push_str(&array.value(row_idx).to_string())
         }
         DataType::Int16 => {
-            let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
-            Ok(array.value(row_idx).to_string())
+            let array = as_int16_array(array)?;
+            csv_string.push_str(&array.value(row_idx).to_string())
        }
         DataType::Int32 => {
-            let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
-            Ok(array.value(row_idx).to_string())
+            let array = as_int32_array(array)?;
+            csv_string.push_str(&array.value(row_idx).to_string())
        }
         DataType::Int64 => {
-            let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
-            Ok(array.value(row_idx).to_string())
+            let array = as_int64_array(array)?;
+            csv_string.push_str(&array.value(row_idx).to_string())
        }
         DataType::Utf8 => {
-            let array = array.as_any().downcast_ref::<StringArray>().unwrap();
-            Ok(array.value(row_idx).to_string())
+            let array = as_string_array(array);
+            csv_string.push_str(&array.value(row_idx).to_string())
        }
         DataType::LargeUtf8 => {
-            let array = array.as_any().downcast_ref::<LargeStringArray>().unwrap();
-            Ok(array.value(row_idx).to_string())
+            let array = as_largestring_array(array);
+            csv_string.push_str(&array.value(row_idx).to_string())
        }
-        _ => exec_err!("to_csv not implemented for type: {:?}", array.data_type()),
+        _ => return exec_err!("to_csv not implemented for type: {:?}", array.data_type()),
     }
+    Ok(())
 }
 
 #[cfg(test)]
```
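
The native kernel above formats only Boolean, Int8 through Int64, Utf8, and LargeUtf8 children, and writes the configured null value for null fields. A Spark-side sketch that exercises exactly those types through `to_csv`; the column names and the local session are illustrative assumptions, not part of the change:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, struct, to_csv}

object ToCsvTypesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    // Boolean, byte/short/int/long, and string children, with one null string,
    // mirroring the DataType arms handled by convert_to_string above.
    val df = Seq(
      (true, 1.toByte, 2.toShort, 3, 4L, "x"),
      (false, 5.toByte, 6.toShort, 7, 8L, null.asInstanceOf[String]))
      .toDF("b", "i8", "i16", "i32", "i64", "s")
      .select(to_csv(struct(col("b"), col("i8"), col("i16"), col("i32"), col("i64"), col("s"))).as("csv"))

    // Null children are emitted using the CSV nullValue option.
    df.show(false)

    spark.stop()
  }
}
```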

spark/src/test/scala/org/apache/comet/CometCsvExpressionSuite.scala

Lines changed: 3 additions & 2 deletions
```diff
@@ -49,8 +49,9 @@ class CometCsvExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
       val df = spark.read
         .parquet(filename)
         .select(to_csv(struct(col("c0"), col("c1"), col("c2"))))
-      checkSparkAnswerAndOperator(df)
-
+      df.explain(true)
+      df.printSchema()
+      checkSparkAnswer(df)
     }
   }
 }
```

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 6 additions & 1 deletion
```diff
@@ -22,6 +22,7 @@ package org.apache.spark.sql
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.concurrent.duration._
+import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.{Success, Try}
@@ -43,7 +44,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal._
 import org.apache.spark.sql.test._
-import org.apache.spark.sql.types.{DecimalType, StructType}
+import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 
 import org.apache.comet._
 import org.apache.comet.shims.ShimCometSparkSessionExtensions
@@ -119,6 +120,10 @@ abstract class CometTestBase
     if (withTol.isDefined) {
       checkAnswerWithTolerance(dfComet, expected, withTol.get)
     } else {
+      val df =
+        spark.createDataFrame(expected.toList.asJava, new StructType().add("value", StringType))
+      df.show(false)
+      df.printSchema()
       checkAnswer(dfComet, expected)
     }
 
```
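
The debugging lines added here materialize the expected rows into their own DataFrame so they can be printed next to the Comet result; note they assume every expected row carries a single string column. A standalone sketch of that pattern, with illustrative values:

```scala
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructType}

object ExpectedRowsDebugSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()

    // Rows collected from a reference run (illustrative values).
    val expected: Seq[Row] = Seq(Row("1,true,foo"), Row("2,false,bar"))

    // createDataFrame takes a java.util.List[Row] plus an explicit schema,
    // hence the scala.jdk.CollectionConverters._ import added in the diff.
    val df = spark.createDataFrame(
      expected.toList.asJava,
      new StructType().add("value", StringType))

    df.show(false)
    df.printSchema()

    spark.stop()
  }
}
```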

spark/src/test/scala/org/apache/spark/sql/benchmark/CometCsvExpressionBenchmark.scala

Lines changed: 2 additions & 3 deletions
```diff
@@ -43,9 +43,8 @@ case class CsvExprConfig(
 // spotless:off
 /**
  * Benchmark to measure performance of Comet CSV expressions. To run this benchmark:
- * `SPARK_GENERATE_BENCHMARK_FILES=1 make
- * benchmark-org.apache.spark.sql.benchmark.CometCsvExpressionBenchmark` Results will be written
- * to "spark/benchmarks/CometCsvExpressionBenchmark-**results.txt".
+ * `SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometCsvExpressionBenchmark`
+ * Results will be written to "spark/benchmarks/CometCsvExpressionBenchmark-**results.txt".
  */
 // spotless:on
 object CometCsvExpressionBenchmark extends CometBenchmarkBase {
```
