Skip to content

Commit 62b2269

Browse files
compheadMazterQyou
authored andcommitted
Implementing math power function for SQL (apache#2324)
* Implementing POWER function * Delete pv.yaml * Delete build-ballista-docker.sh * Delete ballista.dockerfile * aligining with latest upstream changes * Readding docker files * Formatting * Leaving only 64bit types * Adding tests, remove type conversion * fix for cast * Update functions.rs (cherry picked from commit c3c02cf) Can drop this after rebase on commit c3c02cf "Implementing math power function for SQL (apache#2324)", first released in 8.0.0 # Conflicts: # datafusion/core/src/logical_plan/mod.rs # datafusion/core/src/physical_plan/functions.rs # datafusion/core/tests/sql/functions.rs # datafusion/cube_ext/Cargo.toml # datafusion/expr/src/built_in_function.rs # datafusion/expr/src/function.rs # datafusion/proto/proto/datafusion.proto # datafusion/proto/src/from_proto.rs # datafusion/proto/src/to_proto.rs # dev/docker/ballista.dockerfile
1 parent 1a612fc commit 62b2269

File tree

10 files changed

+220
-10
lines changed

10 files changed

+220
-10
lines changed

datafusion/core/src/logical_plan/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@ pub use expr::{
4848
count, count_distinct, create_udaf, create_udf, create_udtf, date_part, date_trunc,
4949
digest, exp, exprlist_to_fields, exprlist_to_fields_from_schema, floor, in_list,
5050
initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim,
51-
max, md5, min, now, now_expr, nullif, octet_length, or, pi, random, regexp_match,
52-
regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256,
53-
sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan,
54-
to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate,
55-
trim, trunc, unalias, upper, when, Column, Expr, ExprSchema, GroupingSet, Like,
56-
Literal,
51+
max, md5, min, now, now_expr, nullif, octet_length, or, pi, power, random,
52+
regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim,
53+
sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos,
54+
substr, sum, tan, to_hex, to_timestamp_micros, to_timestamp_millis,
55+
to_timestamp_seconds, translate, trim, trunc, unalias, upper, when, Column, Expr,
56+
ExprSchema, GroupingSet, Like, Literal,
5757
};
5858
pub use expr_rewriter::{
5959
normalize_col, normalize_cols, replace_col, replace_col_to_expr,

datafusion/core/src/physical_plan/functions.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,10 @@ pub fn create_physical_fun(
312312
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),
313313
BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan),
314314
BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc),
315+
BuiltinScalarFunction::Power => {
316+
Arc::new(|args| make_scalar_function(math_expressions::power)(args))
317+
}
318+
315319
BuiltinScalarFunction::Pi => Arc::new(math_expressions::pi),
316320
// string functions
317321
BuiltinScalarFunction::MakeArray => Arc::new(array_expressions::array),

datafusion/core/tests/sql/functions.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,132 @@ async fn case_builtin_math_expression() {
555555
}
556556
}
557557

558+
#[tokio::test]
559+
async fn test_power() -> Result<()> {
560+
let schema = Arc::new(Schema::new(vec![
561+
Field::new("i32", DataType::Int16, true),
562+
Field::new("i64", DataType::Int64, true),
563+
Field::new("f32", DataType::Float32, true),
564+
Field::new("f64", DataType::Float64, true),
565+
]));
566+
567+
let data = RecordBatch::try_new(
568+
schema.clone(),
569+
vec![
570+
Arc::new(Int16Array::from(vec![
571+
Some(2),
572+
Some(5),
573+
Some(0),
574+
Some(-14),
575+
None,
576+
])),
577+
Arc::new(Int64Array::from(vec![
578+
Some(2),
579+
Some(5),
580+
Some(0),
581+
Some(-14),
582+
None,
583+
])),
584+
Arc::new(Float32Array::from(vec![
585+
Some(1.0),
586+
Some(2.5),
587+
Some(0.0),
588+
Some(-14.5),
589+
None,
590+
])),
591+
Arc::new(Float64Array::from(vec![
592+
Some(1.0),
593+
Some(2.5),
594+
Some(0.0),
595+
Some(-14.5),
596+
None,
597+
])),
598+
],
599+
)?;
600+
601+
let table = MemTable::try_new(schema, vec![vec![data]])?;
602+
603+
let ctx = SessionContext::new();
604+
ctx.register_table("test", Arc::new(table))?;
605+
let sql = r"SELECT power(i32, exp_i) as power_i32,
606+
power(i64, exp_f) as power_i64,
607+
power(f32, exp_i) as power_f32,
608+
power(f64, exp_f) as power_f64,
609+
power(2, 3) as power_int_scalar,
610+
power(2.5, 3.0) as power_float_scalar
611+
FROM (select test.*, 3 as exp_i, 3.0 as exp_f from test) a";
612+
let actual = execute_to_batches(&ctx, sql).await;
613+
let expected = vec![
614+
"+-----------+-----------+-----------+-----------+------------------+--------------------+",
615+
"| power_i32 | power_i64 | power_f32 | power_f64 | power_int_scalar | power_float_scalar |",
616+
"+-----------+-----------+-----------+-----------+------------------+--------------------+",
617+
"| 8 | 8 | 1 | 1 | 8 | 15.625 |",
618+
"| 125 | 125 | 15.625 | 15.625 | 8 | 15.625 |",
619+
"| 0 | 0 | 0 | 0 | 8 | 15.625 |",
620+
"| -2744 | -2744 | -3048.625 | -3048.625 | 8 | 15.625 |",
621+
"| | | | | 8 | 15.625 |",
622+
"+-----------+-----------+-----------+-----------+------------------+--------------------+",
623+
];
624+
assert_batches_eq!(expected, &actual);
625+
//dbg!(actual[0].schema().fields());
626+
assert_eq!(
627+
actual[0]
628+
.schema()
629+
.field_with_name("power_i32")
630+
.unwrap()
631+
.data_type()
632+
.to_owned(),
633+
DataType::Int64
634+
);
635+
assert_eq!(
636+
actual[0]
637+
.schema()
638+
.field_with_name("power_i64")
639+
.unwrap()
640+
.data_type()
641+
.to_owned(),
642+
DataType::Float64
643+
);
644+
assert_eq!(
645+
actual[0]
646+
.schema()
647+
.field_with_name("power_f32")
648+
.unwrap()
649+
.data_type()
650+
.to_owned(),
651+
DataType::Float64
652+
);
653+
assert_eq!(
654+
actual[0]
655+
.schema()
656+
.field_with_name("power_f64")
657+
.unwrap()
658+
.data_type()
659+
.to_owned(),
660+
DataType::Float64
661+
);
662+
assert_eq!(
663+
actual[0]
664+
.schema()
665+
.field_with_name("power_int_scalar")
666+
.unwrap()
667+
.data_type()
668+
.to_owned(),
669+
DataType::Int64
670+
);
671+
assert_eq!(
672+
actual[0]
673+
.schema()
674+
.field_with_name("power_float_scalar")
675+
.unwrap()
676+
.data_type()
677+
.to_owned(),
678+
DataType::Float64
679+
);
680+
681+
Ok(())
682+
}
683+
558684
// #[tokio::test]
559685
// async fn case_sensitive_identifiers_aggregates() {
560686
// let ctx = SessionContext::new();

datafusion/expr/src/built_in_function.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ pub enum BuiltinScalarFunction {
5454
Log10,
5555
/// log2
5656
Log2,
57+
/// power
58+
Power,
5759
/// pi
5860
Pi,
5961
/// round
@@ -196,6 +198,7 @@ impl BuiltinScalarFunction {
196198
BuiltinScalarFunction::Log => Volatility::Immutable,
197199
BuiltinScalarFunction::Log10 => Volatility::Immutable,
198200
BuiltinScalarFunction::Log2 => Volatility::Immutable,
201+
BuiltinScalarFunction::Power => Volatility::Immutable,
199202
BuiltinScalarFunction::Pi => Volatility::Immutable,
200203
BuiltinScalarFunction::Round => Volatility::Immutable,
201204
BuiltinScalarFunction::Signum => Volatility::Immutable,
@@ -284,6 +287,7 @@ impl FromStr for BuiltinScalarFunction {
284287
"log" => BuiltinScalarFunction::Log,
285288
"log10" => BuiltinScalarFunction::Log10,
286289
"log2" => BuiltinScalarFunction::Log2,
290+
"power" => BuiltinScalarFunction::Power,
287291
"pi" => BuiltinScalarFunction::Pi,
288292
"round" => BuiltinScalarFunction::Round,
289293
"signum" => BuiltinScalarFunction::Signum,

datafusion/expr/src/expr_fn.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ unary_scalar_expr!(Log2, log2);
266266
unary_scalar_expr!(Log10, log10);
267267
unary_scalar_expr!(Ln, ln);
268268
unary_scalar_expr!(NullIf, nullif);
269+
scalar_expr!(Power, power, base, exponent);
269270

270271
// string functions
271272
scalar_expr!(Ascii, ascii, string);

datafusion/expr/src/function.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,11 @@ pub fn return_type(
244244
}
245245
}),
246246

247+
BuiltinScalarFunction::Power => match &input_expr_types[0] {
248+
DataType::Int64 => Ok(DataType::Int64),
249+
_ => Ok(DataType::Float64),
250+
},
251+
247252
BuiltinScalarFunction::Abs
248253
| BuiltinScalarFunction::Acos
249254
| BuiltinScalarFunction::Asin
@@ -550,6 +555,13 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
550555
),
551556
BuiltinScalarFunction::Pi => Signature::exact(vec![], fun.volatility()),
552557
BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()),
558+
BuiltinScalarFunction::Power => Signature::one_of(
559+
vec![
560+
TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),
561+
TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]),
562+
],
563+
fun.volatility(),
564+
),
553565
BuiltinScalarFunction::Log => Signature::one_of(
554566
vec![
555567
TypeSignature::Exact(vec![DataType::Float64]),

datafusion/physical-expr/src/math_expressions.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717

1818
//! Math expressions
1919
20-
use arrow::array::{Float32Array, Float64Array};
20+
use arrow::array::ArrayRef;
21+
use arrow::array::{Float32Array, Float64Array, Int64Array};
2122
use arrow::datatypes::DataType;
2223
use datafusion_common::ScalarValue;
2324
use datafusion_common::{DataFusionError, Result};
2425
use datafusion_expr::ColumnarValue;
2526
use rand::{thread_rng, Rng};
27+
use std::any::type_name;
2628
use std::iter;
2729
use std::sync::Arc;
2830

@@ -86,6 +88,33 @@ macro_rules! math_unary_function {
8688
};
8789
}
8890

91+
macro_rules! downcast_arg {
92+
($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
93+
$ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
94+
DataFusionError::Internal(format!(
95+
"could not cast {} to {}",
96+
$NAME,
97+
type_name::<$ARRAY_TYPE>()
98+
))
99+
})?
100+
}};
101+
}
102+
103+
macro_rules! make_function_inputs2 {
104+
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
105+
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE);
106+
let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE);
107+
108+
arg1.iter()
109+
.zip(arg2.iter())
110+
.map(|(a1, a2)| match (a1, a2) {
111+
(Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)),
112+
_ => None,
113+
})
114+
.collect::<$ARRAY_TYPE>()
115+
}};
116+
}
117+
89118
math_unary_function!("sqrt", sqrt);
90119
math_unary_function!("sin", sin);
91120
math_unary_function!("cos", cos);
@@ -131,6 +160,33 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
131160
Ok(ColumnarValue::Array(Arc::new(array)))
132161
}
133162

163+
pub fn power(args: &[ArrayRef]) -> Result<ArrayRef> {
164+
match args[0].data_type() {
165+
DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
166+
&args[0],
167+
&args[1],
168+
"base",
169+
"exponent",
170+
Float64Array,
171+
{ f64::powf }
172+
)) as ArrayRef),
173+
174+
DataType::Int64 => Ok(Arc::new(make_function_inputs2!(
175+
&args[0],
176+
&args[1],
177+
"base",
178+
"exponent",
179+
Int64Array,
180+
{ i64::pow }
181+
)) as ArrayRef),
182+
183+
other => Err(DataFusionError::Internal(format!(
184+
"Unsupported data type {:?} for function power",
185+
other
186+
))),
187+
}
188+
}
189+
134190
#[cfg(test)]
135191
mod tests {
136192

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ enum ScalarFunction {
190190
Upper=62;
191191
Coalesce=63;
192192
// Upstream
193+
Power=64;
193194
CurrentDate=70;
194195
Pi=80;
195196
// Cubesql

datafusion/proto/src/from_proto.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ use datafusion::{
2525
logical_plan::{
2626
abs, acos, ascii, asin, atan, ceil, character_length, chr, concat_expr,
2727
concat_ws_expr, cos, digest, exp, floor, left, ln, log10, log2, now_expr, nullif,
28-
pi, random, regexp_replace, repeat, replace, reverse, right, round, signum, sin,
29-
split_part, sqrt, starts_with, strpos, substr, tan, to_hex, to_timestamp_micros,
30-
to_timestamp_millis, to_timestamp_seconds, translate, trunc,
28+
pi, power, random, regexp_replace, repeat, replace, reverse, right, round,
29+
signum, sin, split_part, sqrt, starts_with, strpos, substr, tan, to_hex,
30+
to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, trunc,
3131
window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits},
3232
Column, DFField, DFSchema, DFSchemaRef, Expr, Like, Operator,
3333
},
@@ -430,6 +430,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
430430
ScalarFunction::Translate => Self::Translate,
431431
ScalarFunction::RegexpMatch => Self::RegexpMatch,
432432
ScalarFunction::Coalesce => Self::Coalesce,
433+
ScalarFunction::Power => Self::Power,
433434
ScalarFunction::Pi => Self::Pi,
434435
// Cube SQL
435436
ScalarFunction::UtcTimestamp => Self::UtcTimestamp,
@@ -1232,6 +1233,10 @@ pub fn parse_expr(
12321233
.map(|expr| parse_expr(expr, registry))
12331234
.collect::<Result<Vec<_>, _>>()?,
12341235
)),
1236+
ScalarFunction::Power => Ok(power(
1237+
parse_expr(&args[0], registry)?,
1238+
parse_expr(&args[1], registry)?,
1239+
)),
12351240
ScalarFunction::Pi => Ok(pi()),
12361241
_ => Err(proto_error(
12371242
"Protobuf deserialization error: Unsupported scalar function",

datafusion/proto/src/to_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
10791079
BuiltinScalarFunction::Translate => Self::Translate,
10801080
BuiltinScalarFunction::RegexpMatch => Self::RegexpMatch,
10811081
BuiltinScalarFunction::Coalesce => Self::Coalesce,
1082+
BuiltinScalarFunction::Power => Self::Power,
10821083
BuiltinScalarFunction::Pi => Self::Pi,
10831084
// Cube SQL
10841085
BuiltinScalarFunction::UtcTimestamp => Self::UtcTimestamp,

0 commit comments

Comments
 (0)