Skip to content

Commit 5c389d1

Browse files
rluvaton and andygrove
authored
chore: extract static invoke expressions to folders based on spark grouping (#1217)
* extract static invoke expressions to folders based on spark grouping * Update native/spark-expr/src/static_invoke/mod.rs Co-authored-by: Andy Grove <[email protected]> --------- Co-authored-by: Andy Grove <[email protected]>
1 parent e39ffa6 commit 5c389d1

File tree

6 files changed

+124
-58
lines changed

6 files changed

+124
-58
lines changed

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ use crate::scalar_funcs::hash_expressions::{
2020
};
2121
use crate::scalar_funcs::{
2222
spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
23-
spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_read_side_padding, spark_round,
24-
spark_unhex, spark_unscaled_value, spark_xxhash64, SparkChrFunc,
23+
spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_round, spark_unhex,
24+
spark_unscaled_value, spark_xxhash64, SparkChrFunc,
2525
};
26+
use crate::spark_read_side_padding;
2627
use arrow_schema::DataType;
2728
use datafusion_common::{DataFusionError, Result as DataFusionResult};
2829
use datafusion_expr::registry::FunctionRegistry;

native/spark-expr/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ mod list;
4242
mod regexp;
4343
pub mod scalar_funcs;
4444
mod schema_adapter;
45+
mod static_invoke;
4546
pub use schema_adapter::SparkSchemaAdapterFactory;
47+
pub use static_invoke::*;
4648

4749
pub mod spark_hash;
4850
mod stddev;

native/spark-expr/src/scalar_funcs.rs

Lines changed: 3 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,23 @@ use arrow::datatypes::IntervalDayTime;
2020
use arrow::{
2121
array::{
2222
ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
23-
Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
23+
Int64Array, Int64Builder, Int8Array,
2424
},
2525
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
2626
};
27-
use arrow_array::builder::{GenericStringBuilder, IntervalDayTimeBuilder};
27+
use arrow_array::builder::IntervalDayTimeBuilder;
2828
use arrow_array::types::{Int16Type, Int32Type, Int8Type};
2929
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array};
3030
use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
3131
use datafusion::physical_expr_common::datum;
3232
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
3333
use datafusion_common::{
34-
cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
35-
Result as DataFusionResult, ScalarValue,
34+
exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
3635
};
3736
use num::{
3837
integer::{div_ceil, div_floor},
3938
BigInt, Signed, ToPrimitive,
4039
};
41-
use std::fmt::Write;
4240
use std::{cmp::min, sync::Arc};
4341

4442
mod unhex;
@@ -390,57 +388,6 @@ pub fn spark_round(
390388
}
391389
}
392390

393-
/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
394-
pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
395-
match args {
396-
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
397-
match array.data_type() {
398-
DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
399-
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
400-
// TODO: handle Dictionary types
401-
other => Err(DataFusionError::Internal(format!(
402-
"Unsupported data type {other:?} for function read_side_padding",
403-
))),
404-
}
405-
}
406-
other => Err(DataFusionError::Internal(format!(
407-
"Unsupported arguments {other:?} for function read_side_padding",
408-
))),
409-
}
410-
}
411-
412-
fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
413-
array: &ArrayRef,
414-
length: i32,
415-
) -> Result<ColumnarValue, DataFusionError> {
416-
let string_array = as_generic_string_array::<T>(array)?;
417-
let length = 0.max(length) as usize;
418-
let space_string = " ".repeat(length);
419-
420-
let mut builder =
421-
GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
422-
423-
for string in string_array.iter() {
424-
match string {
425-
Some(string) => {
426-
// It looks Spark's UTF8String is closer to chars rather than graphemes
427-
// https://stackoverflow.com/a/46290728
428-
let char_len = string.chars().count();
429-
if length <= char_len {
430-
builder.append_value(string);
431-
} else {
432-
// write_str updates only the value buffer, not null nor offset buffer
433-
// This is convenient for concatenating str(s)
434-
builder.write_str(string)?;
435-
builder.append_value(&space_string[char_len..]);
436-
}
437-
}
438-
_ => builder.append_null(),
439-
}
440-
}
441-
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
442-
}
443-
444391
// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
445392
// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
446393
// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
mod read_side_padding;
19+
20+
pub use read_side_padding::spark_read_side_padding;
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{ArrayRef, OffsetSizeTrait};
19+
use arrow_array::builder::GenericStringBuilder;
20+
use arrow_array::Array;
21+
use arrow_schema::DataType;
22+
use datafusion::physical_plan::ColumnarValue;
23+
use datafusion_common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
24+
use std::fmt::Write;
25+
use std::sync::Arc;
26+
27+
/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
28+
pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
29+
match args {
30+
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
31+
match array.data_type() {
32+
DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
33+
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
34+
// TODO: handle Dictionary types
35+
other => Err(DataFusionError::Internal(format!(
36+
"Unsupported data type {other:?} for function read_side_padding",
37+
))),
38+
}
39+
}
40+
other => Err(DataFusionError::Internal(format!(
41+
"Unsupported arguments {other:?} for function read_side_padding",
42+
))),
43+
}
44+
}
45+
46+
fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
47+
array: &ArrayRef,
48+
length: i32,
49+
) -> Result<ColumnarValue, DataFusionError> {
50+
let string_array = as_generic_string_array::<T>(array)?;
51+
let length = 0.max(length) as usize;
52+
let space_string = " ".repeat(length);
53+
54+
let mut builder =
55+
GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
56+
57+
for string in string_array.iter() {
58+
match string {
59+
Some(string) => {
60+
// It looks Spark's UTF8String is closer to chars rather than graphemes
61+
// https://stackoverflow.com/a/46290728
62+
let char_len = string.chars().count();
63+
if length <= char_len {
64+
builder.append_value(string);
65+
} else {
66+
// write_str updates only the value buffer, not null nor offset buffer
67+
// This is convenient for concatenating str(s)
68+
builder.write_str(string)?;
69+
builder.append_value(&space_string[char_len..]);
70+
}
71+
}
72+
_ => builder.append_null(),
73+
}
74+
}
75+
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
76+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
mod char_varchar_utils;
19+
20+
pub use char_varchar_utils::spark_read_side_padding;

0 commit comments

Comments (0)