Skip to content

Commit 735f2cb

Browse files
authored
feat: Support round() function with two parameters (#200)
it's based on 771c20c
1 parent 00eb928 commit 735f2cb

File tree

4 files changed

+175
-7
lines changed

4 files changed

+175
-7
lines changed

datafusion/src/logical_plan/expr.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,7 +1472,13 @@ unary_scalar_expr!(Atan, atan);
14721472
unary_scalar_expr!(Floor, floor);
14731473
unary_scalar_expr!(Ceil, ceil);
14741474
unary_scalar_expr!(Now, now);
1475-
unary_scalar_expr!(Round, round);
1475+
/// Returns the nearest integer value to the expression. Digits defaults to 0 if not provided.
1476+
pub fn round(args: Vec<Expr>) -> Expr {
1477+
Expr::ScalarFunction {
1478+
fun: functions::BuiltinScalarFunction::Round,
1479+
args,
1480+
}
1481+
}
14761482
unary_scalar_expr!(Trunc, trunc);
14771483
unary_scalar_expr!(Abs, abs);
14781484
unary_scalar_expr!(Signum, signum);
@@ -2050,6 +2056,18 @@ mod tests {
20502056
}};
20512057
}
20522058

2059+
// Checks that the n-ary expression builder `$FUNC` yields an
// `Expr::ScalarFunction` tagged with `BuiltinScalarFunction::$ENUM` and keeps
// every argument it was given.
//
// `$FUNC` must accept a `Vec<Expr>`; it is invoked with two column arguments
// so that the "two arguments" assertion below can actually hold.
macro_rules! test_nary_scalar_expr {
    ($ENUM:ident, $FUNC:ident) => {{
        match $FUNC(vec![col("tableA.a"), col("tableA.b")]) {
            Expr::ScalarFunction { fun, args } => {
                let name = functions::BuiltinScalarFunction::$ENUM;
                assert_eq!(name, fun);
                assert_eq!(2, args.len());
            }
            // Any other variant means the builder produced the wrong expression.
            _ => panic!("unexpected"),
        }
    }};
}
2070+
20532071
#[test]
20542072
fn scalar_function_definitions() {
20552073
test_unary_scalar_expr!(Sqrt, sqrt);
@@ -2062,7 +2080,6 @@ mod tests {
20622080
test_unary_scalar_expr!(Floor, floor);
20632081
test_unary_scalar_expr!(Ceil, ceil);
20642082
test_unary_scalar_expr!(Now, now);
2065-
test_unary_scalar_expr!(Round, round);
20662083
test_unary_scalar_expr!(Trunc, trunc);
20672084
test_unary_scalar_expr!(Abs, abs);
20682085
test_unary_scalar_expr!(Signum, signum);
@@ -2104,4 +2121,25 @@ mod tests {
21042121
test_unary_scalar_expr!(Trim, trim);
21052122
test_unary_scalar_expr!(Upper, upper);
21062123
}
2124+
2125+
#[test]
fn test_round_definition() {
    // (expected argument count, arguments handed to `round`)
    let cases = vec![
        (1, vec![col("tableA.a")]),
        (2, vec![col("tableA.a"), lit(2)]),
    ];

    for (expected_len, input) in cases {
        match round(input) {
            Expr::ScalarFunction { fun, args } => {
                let name = functions::BuiltinScalarFunction::Round;
                assert_eq!(name, fun);
                assert_eq!(expected_len, args.len());
            }
            _ => panic!("unexpected"),
        }
    }
}
21072145
}

datafusion/src/physical_plan/functions.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,9 @@ pub fn create_physical_fun(
571571
BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10),
572572
BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2),
573573
BuiltinScalarFunction::Random => Arc::new(math_expressions::random),
574-
BuiltinScalarFunction::Round => Arc::new(math_expressions::round),
574+
BuiltinScalarFunction::Round => {
575+
Arc::new(|args| make_scalar_function(math_expressions::round)(args))
576+
}
575577
BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum),
576578
BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin),
577579
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),
@@ -1279,6 +1281,11 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
12791281
Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
12801282
Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
12811283
]),
1284+
BuiltinScalarFunction::Round => Signature::OneOf(vec![
1285+
Signature::Uniform(1, vec![DataType::Float64, DataType::Float32]),
1286+
Signature::Exact(vec![DataType::Float64, DataType::Int64]),
1287+
Signature::Exact(vec![DataType::Float32, DataType::Int64]),
1288+
]),
12821289
BuiltinScalarFunction::Random => Signature::Exact(vec![]),
12831290
// math expressions expect 1 argument of type f64 or f32
12841291
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we

datafusion/src/physical_plan/math_expressions.rs

Lines changed: 123 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
//! Math expressions
1919
use super::{ColumnarValue, ScalarValue};
2020
use crate::error::{DataFusionError, Result};
21-
use arrow::array::{Float32Array, Float64Array};
21+
use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
2222
use arrow::datatypes::DataType;
2323
use rand::{thread_rng, Rng};
24+
use std::any::type_name;
2425
use std::iter;
2526
use std::sync::Arc;
2627

@@ -84,6 +85,33 @@ macro_rules! math_unary_function {
8485
};
8586
}
8687

88+
// Downcasts `$ARG` (anything exposing `as_any()`) to a `&$ARRAY_TYPE`.
//
// On a type mismatch the enclosing function returns early, through `?`, with
// a `DataFusionError::Internal` that names the argument (`$NAME`) and the
// requested array type.
macro_rules! downcast_arg {
    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
        let typed = match $ARG.as_any().downcast_ref::<$ARRAY_TYPE>() {
            Some(array) => Ok(array),
            None => Err(DataFusionError::Internal(format!(
                "could not cast {} to {}",
                $NAME,
                type_name::<$ARRAY_TYPE>()
            ))),
        };
        typed?
    }};
}
99+
100+
// Applies the binary function `$FUNC` element-wise over two arrays.
//
// `$ARG1` / `$ARG2` are downcast with `downcast_arg!` to `$ARRAY_TYPE1` /
// `$ARRAY_TYPE2`, zipped together, and the results are collected into a new
// `$ARRAY_TYPE1`. A row becomes null when either input is null, or when the
// second value cannot be converted (`try_into`) to the type `$FUNC` expects.
macro_rules! make_function_inputs2 {
    ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{
        let lhs_array = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1);
        let rhs_array = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2);

        lhs_array
            .iter()
            .zip(rhs_array.iter())
            .map(|(lhs, rhs)| {
                // `?` short-circuits this row to `None` (null) on a missing input.
                let (lhs, rhs) = (lhs?, rhs?);
                Some($FUNC(lhs, rhs.try_into().ok()?))
            })
            .collect::<$ARRAY_TYPE1>()
    }};
}
114+
87115
math_unary_function!("sqrt", sqrt);
88116
math_unary_function!("sin", sin);
89117
math_unary_function!("cos", cos);
@@ -93,7 +121,6 @@ math_unary_function!("acos", acos);
93121
math_unary_function!("atan", atan);
94122
math_unary_function!("floor", floor);
95123
math_unary_function!("ceil", ceil);
96-
math_unary_function!("round", round);
97124
math_unary_function!("trunc", trunc);
98125
math_unary_function!("abs", abs);
99126
math_unary_function!("signum", signum);
@@ -118,11 +145,64 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
118145
Ok(ColumnarValue::Array(Arc::new(array)))
119146
}
120147

148+
/// Round SQL function
149+
pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
150+
if args.len() != 1 && args.len() != 2 {
151+
return Err(DataFusionError::Internal(format!(
152+
"round function requires one or two arguments, got {}",
153+
args.len()
154+
)));
155+
}
156+
157+
let mut decimal_places =
158+
&(Arc::new(Int64Array::from_value(0, args[0].len())) as ArrayRef);
159+
160+
if args.len() == 2 {
161+
decimal_places = &args[1];
162+
}
163+
164+
match args[0].data_type() {
165+
DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
166+
&args[0],
167+
decimal_places,
168+
"value",
169+
"decimal_places",
170+
Float64Array,
171+
Int64Array,
172+
{
173+
|value: f64, decimal_places: i64| {
174+
(value * 10.0_f64.powi(decimal_places.try_into().unwrap())).round()
175+
/ 10.0_f64.powi(decimal_places.try_into().unwrap())
176+
}
177+
}
178+
)) as ArrayRef),
179+
180+
DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
181+
&args[0],
182+
decimal_places,
183+
"value",
184+
"decimal_places",
185+
Float32Array,
186+
Int64Array,
187+
{
188+
|value: f32, decimal_places: i64| {
189+
(value * 10.0_f32.powi(decimal_places.try_into().unwrap())).round()
190+
/ 10.0_f32.powi(decimal_places.try_into().unwrap())
191+
}
192+
}
193+
)) as ArrayRef),
194+
195+
other => Err(DataFusionError::Internal(format!(
196+
"Unsupported data type {other:?} for function round"
197+
))),
198+
}
199+
}
200+
121201
#[cfg(test)]
122202
mod tests {
123203

124204
use super::*;
125-
use arrow::array::{Float64Array, NullArray};
205+
use arrow::array::{Float32Array, Float64Array, NullArray};
126206

127207
#[test]
128208
fn test_random_expression() {
@@ -133,4 +213,44 @@ mod tests {
133213
assert_eq!(floats.len(), 1);
134214
assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0);
135215
}
216+
217+
#[test]
fn test_round_f32() {
    // Same input value on every row, rounded to a different number of places.
    let values: ArrayRef = Arc::new(Float32Array::from(vec![125.2345; 10]));
    let places: ArrayRef =
        Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

    let rounded =
        round(&[values, places]).expect("failed to initialize function round");
    let actual = rounded
        .as_any()
        .downcast_ref::<Float32Array>()
        .expect("failed to initialize function round");

    let expected = Float32Array::from(vec![
        125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
    ]);

    assert_eq!(actual, &expected);
}
236+
237+
#[test]
fn test_round_f64() {
    // Same input value on every row, rounded to a different number of places.
    let values: ArrayRef = Arc::new(Float64Array::from(vec![125.2345; 10]));
    let places: ArrayRef =
        Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]));

    let rounded =
        round(&[values, places]).expect("failed to initialize function round");
    let actual = rounded
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("failed to initialize function round");

    let expected = Float64Array::from(vec![
        125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
    ]);

    assert_eq!(actual, &expected);
}
136256
}

datafusion/src/physical_plan/planner.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,10 @@ impl DefaultPhysicalPlanner {
345345
extension_planners.insert(1, Arc::new(CrossJoinPlanner {}));
346346
extension_planners.insert(2, Arc::new(CrossJoinAggPlanner {}));
347347
extension_planners.insert(3, Arc::new(crate::cube_ext::rolling::Planner {}));
348-
Self { should_evaluate_constants: true, extension_planners }
348+
Self {
349+
should_evaluate_constants: true,
350+
extension_planners,
351+
}
349352
}
350353

351354
/// Create a physical plan from a logical plan

0 commit comments

Comments
 (0)