Skip to content

Commit fc5e9e0

Browse files
mcheshkovMazterQyou
authored andcommitted
Support round() function with two parameters (apache#5807)
Can drop this after rebase on commit 771c20c "Support round() function with two parameters (apache#5807)", first released in 22.0.0
1 parent 20001fe commit fc5e9e0

File tree

5 files changed

+128
-13
lines changed

5 files changed

+128
-13
lines changed

datafusion/core/src/optimizer/projection_drop_out.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ mod tests {
631631
)?
632632
.project_with_alias(
633633
vec![
634-
round(col("id")).alias("first"),
634+
round(vec![col("id")]).alias("first"),
635635
col("n").alias("second"),
636636
lit(2).alias("third"),
637637
],
@@ -649,7 +649,7 @@ mod tests {
649649
// select * from (select id first, a second, 2 third from (select round(a) id, 1 num from table) a) x;
650650
let plan = LogicalPlanBuilder::from(table_scan)
651651
.project_with_alias(
652-
vec![round(col("a")).alias("id"), lit(1).alias("n")],
652+
vec![round(vec![col("a")]).alias("id"), lit(1).alias("n")],
653653
Some("a".to_string()),
654654
)?
655655
.project_with_alias(
@@ -748,7 +748,7 @@ mod tests {
748748
)?
749749
.project_with_alias(
750750
vec![
751-
round(col("id")).alias("first"),
751+
round(vec![col("id")]).alias("first"),
752752
col("n").alias("second"),
753753
lit(2).alias("third"),
754754
],
@@ -826,7 +826,10 @@ mod tests {
826826
let plan = LogicalPlanBuilder::from(table_scan)
827827
.project_with_alias(vec![col("a").alias("id")], Some("a".to_string()))?
828828
.project_with_alias(
829-
vec![round(col("id")).alias("first"), lit(2).alias("second")],
829+
vec![
830+
round(vec![col("id")]).alias("first"),
831+
lit(2).alias("second"),
832+
],
830833
Some("b".to_string()),
831834
)?
832835
.sort(vec![col("first")])?
@@ -1019,7 +1022,7 @@ mod tests {
10191022
.project_with_alias(vec![col("a").alias("num")], Some("a".to_string()))?
10201023
.project_with_alias(vec![col("num")], Some("b".to_string()))?
10211024
.filter(col("num").gt(lit(0)))?
1022-
.aggregate(vec![round(col("num"))], Vec::<Expr>::new())?
1025+
.aggregate(vec![round(vec![col("num")])], Vec::<Expr>::new())?
10231026
.project(vec![col("Round(b.num)")])?
10241027
.sort(vec![col("Round(b.num)")])?
10251028
.build()?;
@@ -1044,7 +1047,7 @@ mod tests {
10441047
let plan = LogicalPlanBuilder::from(table_scan.clone())
10451048
.project_with_alias(vec![col("a").alias("num")], Some("a".to_string()))?
10461049
.project_with_alias(vec![col("num")], Some("b".to_string()))?
1047-
.aggregate(vec![round(col("num"))], Vec::<Expr>::new())?
1050+
.aggregate(vec![round(vec![col("num")])], Vec::<Expr>::new())?
10481051
.project(vec![col("Round(b.num)")])?
10491052
.sort(vec![col("Round(b.num)")])?
10501053
.build()?;
@@ -1061,7 +1064,7 @@ mod tests {
10611064
let plan = LogicalPlanBuilder::from(table_scan)
10621065
.project_with_alias(vec![col("a").alias("num")], Some("a".to_string()))?
10631066
.project_with_alias(vec![col("num")], Some("b".to_string()))?
1064-
.aggregate(vec![round(col("num"))], Vec::<Expr>::new())?
1067+
.aggregate(vec![round(vec![col("num")])], Vec::<Expr>::new())?
10651068
.project(vec![col("Round(b.num)")])?
10661069
.sort(vec![col("Round(b.num)")])?
10671070
.project_with_alias(vec![col("Round(b.num)")], Some("x".to_string()))?
@@ -1099,7 +1102,7 @@ mod tests {
10991102
.project_with_alias(vec![col("num")], Some("x".to_string()))?
11001103
.join(&right, JoinType::Left, (vec!["num"], vec!["a"]))?
11011104
.project_with_alias(
1102-
vec![col("num"), col("a"), round(col("num"))],
1105+
vec![col("num"), col("a"), round(vec![col("num")])],
11031106
Some("b".to_string()),
11041107
)?
11051108
.build()?;

datafusion/core/src/physical_plan/functions.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,9 @@ pub fn create_physical_fun(
306306
BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10),
307307
BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2),
308308
BuiltinScalarFunction::Random => Arc::new(math_expressions::random),
309-
BuiltinScalarFunction::Round => Arc::new(math_expressions::round),
309+
BuiltinScalarFunction::Round => {
310+
Arc::new(|args| make_scalar_function(math_expressions::round)(args))
311+
}
310312
BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum),
311313
BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin),
312314
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),

datafusion/expr/src/expr_fn.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ unary_scalar_expr!(Atan, atan);
257257
unary_scalar_expr!(Floor, floor);
258258
unary_scalar_expr!(Ceil, ceil);
259259
unary_scalar_expr!(Now, now);
260-
unary_scalar_expr!(Round, round);
260+
nary_scalar_expr!(Round, round);
261261
unary_scalar_expr!(Trunc, trunc);
262262
unary_scalar_expr!(Abs, abs);
263263
unary_scalar_expr!(Signum, signum);
@@ -418,7 +418,8 @@ mod test {
418418
test_unary_scalar_expr!(Floor, floor);
419419
test_unary_scalar_expr!(Ceil, ceil);
420420
test_unary_scalar_expr!(Now, now);
421-
test_unary_scalar_expr!(Round, round);
421+
test_nary_scalar_expr!(Round, round, input);
422+
test_nary_scalar_expr!(Round, round, input, decimal_places);
422423
test_unary_scalar_expr!(Trunc, trunc);
423424
test_unary_scalar_expr!(Abs, abs);
424425
test_unary_scalar_expr!(Signum, signum);

datafusion/physical-expr/src/math_expressions.rs

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,18 @@ macro_rules! make_function_inputs2 {
113113
})
114114
.collect::<$ARRAY_TYPE>()
115115
}};
116+
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{
117+
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1);
118+
let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2);
119+
120+
arg1.iter()
121+
.zip(arg2.iter())
122+
.map(|(a1, a2)| match (a1, a2) {
123+
(Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)),
124+
_ => None,
125+
})
126+
.collect::<$ARRAY_TYPE1>()
127+
}};
116128
}
117129

118130
math_unary_function!("sqrt", sqrt);
@@ -124,7 +136,6 @@ math_unary_function!("acos", acos);
124136
math_unary_function!("atan", atan);
125137
math_unary_function!("floor", floor);
126138
math_unary_function!("ceil", ceil);
127-
math_unary_function!("round", round);
128139
math_unary_function!("trunc", trunc);
129140
math_unary_function!("abs", abs);
130141
math_unary_function!("signum", signum);
@@ -160,6 +171,59 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
160171
Ok(ColumnarValue::Array(Arc::new(array)))
161172
}
162173

174+
/// Round SQL function
175+
pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
176+
if args.len() != 1 && args.len() != 2 {
177+
return Err(DataFusionError::Internal(format!(
178+
"round function requires one or two arguments, got {}",
179+
args.len()
180+
)));
181+
}
182+
183+
let mut decimal_places =
184+
&(Arc::new(Int64Array::from_value(0, args[0].len())) as ArrayRef);
185+
186+
if args.len() == 2 {
187+
decimal_places = &args[1];
188+
}
189+
190+
match args[0].data_type() {
191+
DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
192+
&args[0],
193+
decimal_places,
194+
"value",
195+
"decimal_places",
196+
Float64Array,
197+
Int64Array,
198+
{
199+
|value: f64, decimal_places: i64| {
200+
(value * 10.0_f64.powi(decimal_places.try_into().unwrap())).round()
201+
/ 10.0_f64.powi(decimal_places.try_into().unwrap())
202+
}
203+
}
204+
)) as ArrayRef),
205+
206+
DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
207+
&args[0],
208+
decimal_places,
209+
"value",
210+
"decimal_places",
211+
Float32Array,
212+
Int64Array,
213+
{
214+
|value: f32, decimal_places: i64| {
215+
(value * 10.0_f32.powi(decimal_places.try_into().unwrap())).round()
216+
/ 10.0_f32.powi(decimal_places.try_into().unwrap())
217+
}
218+
}
219+
)) as ArrayRef),
220+
221+
other => Err(DataFusionError::Internal(format!(
222+
"Unsupported data type {other:?} for function round"
223+
))),
224+
}
225+
}
226+
163227
pub fn power(args: &[ArrayRef]) -> Result<ArrayRef> {
164228
match args[0].data_type() {
165229
DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
@@ -202,4 +266,44 @@ mod tests {
202266
assert_eq!(floats.len(), 1);
203267
assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0);
204268
}
269+
270+
#[test]
271+
fn test_round_f32() {
272+
let args: Vec<ArrayRef> = vec![
273+
Arc::new(Float32Array::from(vec![125.2345; 10])), // input
274+
Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
275+
];
276+
277+
let result = round(&args).expect("failed to initialize function round");
278+
let floats = result
279+
.as_any()
280+
.downcast_ref::<Float32Array>()
281+
.expect("failed to initialize function round");
282+
283+
let expected = Float32Array::from(vec![
284+
125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
285+
]);
286+
287+
assert_eq!(floats, &expected);
288+
}
289+
290+
#[test]
291+
fn test_round_f64() {
292+
let args: Vec<ArrayRef> = vec![
293+
Arc::new(Float64Array::from(vec![125.2345; 10])), // input
294+
Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
295+
];
296+
297+
let result = round(&args).expect("failed to initialize function round");
298+
let floats = result
299+
.as_any()
300+
.downcast_ref::<Float64Array>()
301+
.expect("failed to initialize function round");
302+
303+
let expected = Float64Array::from(vec![
304+
125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
305+
]);
306+
307+
assert_eq!(floats, &expected);
308+
}
205309
}

datafusion/proto/src/from_proto.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1089,7 +1089,12 @@ pub fn parse_expr(
10891089
ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], registry)?)),
10901090
ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], registry)?)),
10911091
ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry)?)),
1092-
ScalarFunction::Round => Ok(round(parse_expr(&args[0], registry)?)),
1092+
ScalarFunction::Round => Ok(round(
1093+
args.to_owned()
1094+
.iter()
1095+
.map(|expr| parse_expr(expr, registry))
1096+
.collect::<Result<Vec<_>, _>>()?,
1097+
)),
10931098
ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], registry)?)),
10941099
ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], registry)?)),
10951100
ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], registry)?)),

0 commit comments

Comments
 (0)