
Commit 7247d9c

Chore: implement string_space as ScalarUDFImpl (#2041)
1 parent a91498d commit 7247d9c

9 files changed: +128 -158 lines

native/core/src/execution/planner.rs

Lines changed: 2 additions & 7 deletions

@@ -104,8 +104,8 @@ use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncre
 use datafusion_comet_spark_expr::{
     ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
     GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike,
-    RandExpr, RandnExpr, SparkCastOptions, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
-    TimestampTruncExpr, ToJson, UnboundColumn, Variance,
+    RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr,
+    ToJson, UnboundColumn, Variance,
 };
 use itertools::Itertools;
 use jni::objects::GlobalRef;
@@ -546,11 +546,6 @@ impl PhysicalPlanner {
                 len as u64,
             )))
         }
-        ExprStruct::StringSpace(expr) => {
-            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
-
-            Ok(Arc::new(StringSpaceExpr::new(child)))
-        }
         ExprStruct::Like(expr) => {
             let left =
                 self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;

native/proto/src/proto/expr.proto

Lines changed: 0 additions & 1 deletion

@@ -45,7 +45,6 @@ message Expr {
     BinaryExpr or = 18;
     SortOrder sort_order = 19;
     Substring substring = 20;
-    UnaryExpr string_space = 21;
     Hour hour = 22;
     Minute minute = 23;
     Second second = 24;

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,7 @@ use crate::{
     spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
     spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
     SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
+    SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -175,6 +176,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
     Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
     Arc::new(ScalarUDF::new_from_impl(SparkBitwiseGet::default())),
     Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
+    Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
 ]
 }
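A quick sketch (not part of this commit) of what the registration above provides: the new impl is wrapped in a ScalarUDF exactly as all_scalar_functions() does, and the Scala side refers to it by that UDF's name. This assumes SparkStringSpace is re-exported at the crate root, as the `use crate::{... SparkStringSpace}` import suggests.

use datafusion::logical_expr::ScalarUDF;
use datafusion_comet_spark_expr::SparkStringSpace;

fn main() {
    // Wrap the ScalarUDFImpl the same way all_scalar_functions() does above.
    let udf = ScalarUDF::new_from_impl(SparkStringSpace::default());
    // This is the name the Scala serde uses when it serializes StringSpace
    // as a generic scalar function call.
    assert_eq!(udf.name(), "string_space");
}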
native/spark-expr/src/kernels/strings.rs

Lines changed: 0 additions & 62 deletions

@@ -21,33 +21,11 @@ use std::sync::Arc;
 
 use arrow::{
     array::*,
-    buffer::MutableBuffer,
     compute::kernels::substring::{substring as arrow_substring, substring_by_char},
     datatypes::{DataType, Int32Type},
 };
 use datafusion::common::DataFusionError;
 
-/// Returns an ArrayRef with a string consisting of `length` spaces.
-///
-/// # Preconditions
-///
-/// - elements in `length` must not be negative
-pub fn string_space(length: &dyn Array) -> Result<ArrayRef, DataFusionError> {
-    match length.data_type() {
-        DataType::Int32 => {
-            let array = length.as_any().downcast_ref::<Int32Array>().unwrap();
-            Ok(generic_string_space::<i32>(array))
-        }
-        DataType::Dictionary(_, _) => {
-            let dict = as_dictionary_array::<Int32Type>(length);
-            let values = string_space(dict.values())?;
-            let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
-            Ok(Arc::new(result))
-        }
-        dt => panic!("Unsupported input type for function 'string_space': {dt:?}"),
-    }
-}
-
 pub fn substring(array: &dyn Array, start: i64, length: u64) -> Result<ArrayRef, DataFusionError> {
     match array.data_type() {
         DataType::LargeUtf8 => substring_by_char(
@@ -82,43 +60,3 @@ pub fn substring(array: &dyn Array, start: i64, length: u64) -> Result<ArrayRef,
         dt => panic!("Unsupported input type for function 'substring': {dt:?}"),
     }
 }
-
-fn generic_string_space<OffsetSize: OffsetSizeTrait>(length: &Int32Array) -> ArrayRef {
-    let array_len = length.len();
-    let mut offsets = MutableBuffer::new((array_len + 1) * std::mem::size_of::<OffsetSize>());
-    let mut length_so_far = OffsetSize::zero();
-
-    // compute null bitmap (copy)
-    let null_bit_buffer = length.to_data().nulls().map(|b| b.buffer().clone());
-
-    // Gets slice of length array to access it directly for performance.
-    let length_data = length.to_data();
-    let lengths = length_data.buffers()[0].typed_data::<i32>();
-    let total = lengths.iter().map(|l| *l as usize).sum::<usize>();
-    let mut values = MutableBuffer::new(total);
-
-    offsets.push(length_so_far);
-
-    let blank = " ".as_bytes()[0];
-    values.resize(total, blank);
-
-    (0..array_len).for_each(|i| {
-        let current_len = lengths[i] as usize;
-
-        length_so_far += OffsetSize::from_usize(current_len).unwrap();
-        offsets.push(length_so_far);
-    });
-
-    let data = unsafe {
-        ArrayData::new_unchecked(
-            GenericStringArray::<OffsetSize>::DATA_TYPE,
-            array_len,
-            None,
-            null_bit_buffer,
-            0,
-            vec![offsets.into(), values.into()],
-            vec![],
-        )
-    };
-    make_array(data)
-}

native/spark-expr/src/string_funcs/mod.rs

Lines changed: 1 addition & 1 deletion

@@ -20,5 +20,5 @@ mod string_space;
 mod substring;
 
 pub use chr::SparkChrFunc;
-pub use string_space::StringSpaceExpr;
+pub use string_space::SparkStringSpace;
 pub use substring::SubstringExpr;

native/spark-expr/src/string_funcs/string_space.rs

Lines changed: 105 additions & 62 deletions

@@ -15,94 +15,137 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#![allow(deprecated)]
-
-use crate::kernels::strings::string_space;
-use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
-use datafusion::common::DataFusionError;
-use datafusion::logical_expr::ColumnarValue;
-use datafusion::physical_expr::PhysicalExpr;
-use std::{
-    any::Any,
-    fmt::{Display, Formatter},
-    hash::Hash,
-    sync::Arc,
+use arrow::array::{
+    as_dictionary_array, make_array, Array, ArrayData, ArrayRef, DictionaryArray,
+    GenericStringArray, Int32Array, OffsetSizeTrait,
 };
+use arrow::buffer::MutableBuffer;
+use arrow::datatypes::{DataType, Int32Type};
+use datafusion::common::{exec_err, internal_datafusion_err, DataFusionError, Result};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::{any::Any, sync::Arc};
 
-#[derive(Debug, Eq)]
-pub struct StringSpaceExpr {
-    pub child: Arc<dyn PhysicalExpr>,
-}
-
-impl Hash for StringSpaceExpr {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.child.hash(state);
-    }
-}
-
-impl PartialEq for StringSpaceExpr {
-    fn eq(&self, other: &Self) -> bool {
-        self.child.eq(&other.child)
-    }
+#[derive(Debug)]
+pub struct SparkStringSpace {
+    signature: Signature,
+    aliases: Vec<String>,
 }
 
-impl StringSpaceExpr {
-    pub fn new(child: Arc<dyn PhysicalExpr>) -> Self {
-        Self { child }
+impl Default for SparkStringSpace {
+    fn default() -> Self {
+        Self::new()
     }
 }
 
-impl Display for StringSpaceExpr {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(f, "StringSpace [child: {}] ", self.child)
+impl SparkStringSpace {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![],
+        }
     }
 }
 
-impl PhysicalExpr for StringSpaceExpr {
+impl ScalarUDFImpl for SparkStringSpace {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
-        unimplemented!()
+    fn name(&self) -> &str {
+        "string_space"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
     }
 
-    fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> {
-        match self.child.data_type(input_schema)? {
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(match &arg_types[0] {
             DataType::Dictionary(key_type, _) => {
-                Ok(DataType::Dictionary(key_type, Box::new(DataType::Utf8)))
+                DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8))
             }
-            _ => Ok(DataType::Utf8),
-        }
+            _ => DataType::Utf8,
+        })
     }
 
-    fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
-        Ok(true)
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args: [ColumnarValue; 1] = args
+            .args
+            .try_into()
+            .map_err(|_| internal_datafusion_err!("string_space expects exactly one argument"))?;
+        spark_string_space(&args)
     }
 
-    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
-        let arg = self.child.evaluate(batch)?;
-        match arg {
-            ColumnarValue::Array(array) => {
-                let result = string_space(&array)?;
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
 
-                Ok(ColumnarValue::Array(result))
-            }
-            _ => Err(DataFusionError::Execution(
-                "StringSpace(scalar) should be fold in Spark JVM side.".to_string(),
-            )),
+pub fn spark_string_space(args: &[ColumnarValue; 1]) -> Result<ColumnarValue> {
+    match args {
+        [ColumnarValue::Array(array)] => {
+            let result = string_space(&array)?;
+
+            Ok(ColumnarValue::Array(result))
         }
+        _ => exec_err!("StringSpace(scalar) should be fold in Spark JVM side."),
     }
+}
 
-    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.child]
+fn string_space(length: &dyn Array) -> std::result::Result<ArrayRef, DataFusionError> {
+    match length.data_type() {
+        DataType::Int32 => {
+            let array = length.as_any().downcast_ref::<Int32Array>().unwrap();
+            Ok(generic_string_space::<i32>(array))
+        }
+        DataType::Dictionary(_, _) => {
+            let dict = as_dictionary_array::<Int32Type>(length);
+            let values = string_space(dict.values())?;
+            let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
+            Ok(Arc::new(result))
+        }
+        other => exec_err!("Unsupported input type for function 'string_space': {other:?}"),
     }
+}
 
-    fn with_new_children(
-        self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
-    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(StringSpaceExpr::new(Arc::clone(&children[0]))))
-    }
+fn generic_string_space<OffsetSize: OffsetSizeTrait>(length: &Int32Array) -> ArrayRef {
+    let array_len = length.len();
+    let mut offsets = MutableBuffer::new((array_len + 1) * std::mem::size_of::<OffsetSize>());
+    let mut length_so_far = OffsetSize::zero();
+
+    // compute null bitmap (copy)
+    let null_bit_buffer = length.to_data().nulls().map(|b| b.buffer().clone());
+
+    // Gets slice of length array to access it directly for performance.
+    let length_data = length.to_data();
+    let lengths = length_data.buffers()[0].typed_data::<i32>();
+    let total = lengths.iter().map(|l| *l as usize).sum::<usize>();
+    let mut values = MutableBuffer::new(total);
+
+    offsets.push(length_so_far);
+
+    let blank = " ".as_bytes()[0];
+    values.resize(total, blank);
+
+    (0..array_len).for_each(|i| {
+        let current_len = lengths[i] as usize;
+
+        length_so_far += OffsetSize::from_usize(current_len).unwrap();
+        offsets.push(length_so_far);
+    });
+
+    let data = unsafe {
+        ArrayData::new_unchecked(
+            GenericStringArray::<OffsetSize>::DATA_TYPE,
+            array_len,
+            None,
+            null_bit_buffer,
+            0,
+            vec![offsets.into(), values.into()],
+            vec![],
+        )
+    };
+    make_array(data)
 }
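For readers skimming the kernel above: the output semantics are simply that each non-null Int32 length n becomes a string of n spaces, with nulls preserved. Below is a safe, simplified sketch of those semantics (illustration only, not code from this commit; the helper name string_space_reference is hypothetical) using arrow's StringBuilder. The production kernel instead writes the offsets and values buffers directly and copies the input null bitmap, avoiding per-row string allocation, and assumes lengths are non-negative.

use arrow::array::{Array, Int32Array, StringArray, StringBuilder};

// Illustrative reference for generic_string_space: each non-null length n
// becomes a string of n spaces; null slots stay null.
fn string_space_reference(lengths: &Int32Array) -> StringArray {
    let mut builder = StringBuilder::new();
    for i in 0..lengths.len() {
        if lengths.is_null(i) {
            builder.append_null();
        } else {
            builder.append_value(" ".repeat(lengths.value(i) as usize));
        }
    }
    builder.finish()
}

fn main() {
    let out = string_space_reference(&Int32Array::from(vec![Some(0), Some(3), None]));
    assert_eq!(out.value(1), "   ");
    assert!(out.is_null(2));
}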

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 9 deletions

@@ -130,7 +130,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Rand] -> CometRand,
     classOf[Randn] -> CometRandn,
     classOf[SparkPartitionID] -> CometSparkPartitionId,
-    classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId)
+    classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId,
+    classOf[StringSpace] -> UnaryScalarFuncSerde("string_space"))
 
   /**
    * Mapping of Spark aggregate expression class to Comet expression handler.
@@ -943,14 +944,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
         val valueExpr = exprToProtoInternal(value, inputs, binding)
         scalarFunctionExprToProto("contains", attributeExpr, valueExpr)
 
-      case StringSpace(child) =>
-        createUnaryExpr(
-          expr,
-          child,
-          inputs,
-          binding,
-          (builder, unaryExpr) => builder.setStringSpace(unaryExpr))
-
       case Hour(child, timeZoneId) =>
         val childExpr = exprToProtoInternal(child, inputs, binding)
 
spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -381,22 +381,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
381381
}
382382
}
383383

384-
test("string_space") {
385-
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
386-
checkSparkAnswerAndOperator("SELECT space(_1), space(_2) FROM tbl")
387-
}
388-
}
389-
390-
test("string_space with dictionary") {
391-
val data = (0 until 1000).map(i => Tuple1(i % 5))
392-
393-
withSQLConf("parquet.enable.dictionary" -> "true") {
394-
withParquetTable(data, "tbl") {
395-
checkSparkAnswerAndOperator("SELECT space(_1) FROM tbl")
396-
}
397-
}
398-
}
399-
400384
test("hour, minute, second") {
401385
Seq(true, false).foreach { dictionaryEnabled =>
402386
withTempDir { dir =>
