Skip to content

Commit f1f6e5e

Browse files
jayzhan211 and alamb authored
Optimize gcd for the array and scalar cases by avoiding make_scalar_function, which performs unnecessary conversions between scalars and arrays (#14834)
* optimize gcd * fmt * add feature * Use try_binary to make gcd even faster * rm length check --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 1fedb4e commit f1f6e5e

File tree

4 files changed

+158
-78
lines changed

4 files changed

+158
-78
lines changed

datafusion/functions/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ harness = false
113113
name = "chr"
114114
required-features = ["string_expressions"]
115115

116+
[[bench]]
117+
harness = false
118+
name = "gcd"
119+
required-features = ["math_expressions"]
120+
116121
[[bench]]
117122
harness = false
118123
name = "uuid"
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
extern crate criterion;
19+
20+
use arrow::{
21+
array::{ArrayRef, Int64Array},
22+
datatypes::DataType,
23+
};
24+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
25+
use datafusion_common::ScalarValue;
26+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
27+
use datafusion_functions::math::gcd;
28+
use rand::Rng;
29+
use std::sync::Arc;
30+
31+
fn generate_i64_array(n_rows: usize) -> ArrayRef {
32+
let mut rng = rand::thread_rng();
33+
let values = (0..n_rows)
34+
.map(|_| rng.gen_range(0..1000))
35+
.collect::<Vec<_>>();
36+
Arc::new(Int64Array::from(values)) as ArrayRef
37+
}
38+
39+
fn criterion_benchmark(c: &mut Criterion) {
40+
let n_rows = 100000;
41+
let array_a = ColumnarValue::Array(generate_i64_array(n_rows));
42+
let array_b = ColumnarValue::Array(generate_i64_array(n_rows));
43+
let udf = gcd();
44+
45+
c.bench_function("gcd both array", |b| {
46+
b.iter(|| {
47+
black_box(
48+
udf.invoke_with_args(ScalarFunctionArgs {
49+
args: vec![array_a.clone(), array_b.clone()],
50+
number_rows: 0,
51+
return_type: &DataType::Int64,
52+
})
53+
.expect("date_bin should work on valid values"),
54+
)
55+
})
56+
});
57+
58+
// 10! = 3628800
59+
let scalar_b = ColumnarValue::Scalar(ScalarValue::Int64(Some(3628800)));
60+
61+
c.bench_function("gcd array and scalar", |b| {
62+
b.iter(|| {
63+
black_box(
64+
udf.invoke_with_args(ScalarFunctionArgs {
65+
args: vec![array_a.clone(), scalar_b.clone()],
66+
number_rows: 0,
67+
return_type: &DataType::Int64,
68+
})
69+
.expect("date_bin should work on valid values"),
70+
)
71+
})
72+
});
73+
74+
// scalar and scalar
75+
let scalar_a = ColumnarValue::Scalar(ScalarValue::Int64(Some(3628800)));
76+
77+
c.bench_function("gcd both scalar", |b| {
78+
b.iter(|| {
79+
black_box(
80+
udf.invoke_with_args(ScalarFunctionArgs {
81+
args: vec![scalar_a.clone(), scalar_b.clone()],
82+
number_rows: 0,
83+
return_type: &DataType::Int64,
84+
})
85+
.expect("date_bin should work on valid values"),
86+
)
87+
})
88+
});
89+
}
90+
91+
criterion_group!(benches, criterion_benchmark);
92+
criterion_main!(benches);

datafusion/functions/src/math/gcd.rs

Lines changed: 59 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,15 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{ArrayRef, Int64Array};
18+
use arrow::array::{new_null_array, ArrayRef, AsArray, Int64Array, PrimitiveArray};
19+
use arrow::compute::try_binary;
20+
use arrow::datatypes::{DataType, Int64Type};
1921
use arrow::error::ArrowError;
2022
use std::any::Any;
2123
use std::mem::swap;
2224
use std::sync::Arc;
2325

24-
use arrow::datatypes::DataType;
25-
use arrow::datatypes::DataType::Int64;
26-
27-
use crate::utils::make_scalar_function;
28-
use datafusion_common::{
29-
arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result,
30-
};
26+
use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
3127
use datafusion_expr::{
3228
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
3329
Volatility,
@@ -54,9 +50,12 @@ impl Default for GcdFunc {
5450

5551
impl GcdFunc {
5652
pub fn new() -> Self {
57-
use DataType::*;
5853
Self {
59-
signature: Signature::uniform(2, vec![Int64], Volatility::Immutable),
54+
signature: Signature::uniform(
55+
2,
56+
vec![DataType::Int64],
57+
Volatility::Immutable,
58+
),
6059
}
6160
}
6261
}
@@ -75,36 +74,69 @@ impl ScalarUDFImpl for GcdFunc {
7574
}
7675

7776
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78-
Ok(Int64)
77+
Ok(DataType::Int64)
7978
}
8079

8180
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
82-
make_scalar_function(gcd, vec![])(&args.args)
81+
let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
82+
internal_datafusion_err!("Expected 2 arguments for function gcd")
83+
})?;
84+
85+
match args {
86+
[ColumnarValue::Array(a), ColumnarValue::Array(b)] => {
87+
compute_gcd_for_arrays(&a, &b)
88+
}
89+
[ColumnarValue::Scalar(ScalarValue::Int64(a)), ColumnarValue::Scalar(ScalarValue::Int64(b))] => {
90+
match (a, b) {
91+
(Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
92+
Some(compute_gcd(a, b)?),
93+
))),
94+
_ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
95+
}
96+
}
97+
[ColumnarValue::Array(a), ColumnarValue::Scalar(ScalarValue::Int64(b))] => {
98+
compute_gcd_with_scalar(&a, b)
99+
}
100+
[ColumnarValue::Scalar(ScalarValue::Int64(a)), ColumnarValue::Array(b)] => {
101+
compute_gcd_with_scalar(&b, a)
102+
}
103+
_ => exec_err!("Unsupported argument types for function gcd"),
104+
}
83105
}
84106

85107
fn documentation(&self) -> Option<&Documentation> {
86108
self.doc()
87109
}
88110
}
89111

90-
/// Gcd SQL function
91-
fn gcd(args: &[ArrayRef]) -> Result<ArrayRef> {
92-
match args[0].data_type() {
93-
Int64 => {
94-
let arg1 = downcast_named_arg!(&args[0], "x", Int64Array);
95-
let arg2 = downcast_named_arg!(&args[1], "y", Int64Array);
112+
fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result<ColumnarValue> {
113+
let a = a.as_primitive::<Int64Type>();
114+
let b = b.as_primitive::<Int64Type>();
115+
try_binary(a, b, compute_gcd)
116+
.map(|arr: PrimitiveArray<Int64Type>| {
117+
ColumnarValue::Array(Arc::new(arr) as ArrayRef)
118+
})
119+
.map_err(Into::into) // convert ArrowError to DataFusionError
120+
}
96121

97-
Ok(arg1
122+
fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option<i64>) -> Result<ColumnarValue> {
123+
match scalar {
124+
Some(scalar_value) => {
125+
let result: Result<Int64Array> = arr
126+
.as_primitive::<Int64Type>()
98127
.iter()
99-
.zip(arg2.iter())
100-
.map(|(a1, a2)| match (a1, a2) {
101-
(Some(a1), Some(a2)) => Ok(Some(compute_gcd(a1, a2)?)),
128+
.map(|val| match val {
129+
Some(val) => Ok(Some(compute_gcd(val, scalar_value)?)),
102130
_ => Ok(None),
103131
})
104-
.collect::<Result<Int64Array>>()
105-
.map(Arc::new)? as ArrayRef)
132+
.collect();
133+
134+
result.map(|arr| ColumnarValue::Array(Arc::new(arr) as ArrayRef))
106135
}
107-
other => exec_err!("Unsupported data type {other:?} for function gcd"),
136+
None => Ok(ColumnarValue::Array(new_null_array(
137+
&DataType::Int64,
138+
arr.len(),
139+
))),
108140
}
109141
}
110142

@@ -132,61 +164,12 @@ pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
132164
}
133165

134166
/// Computes greatest common divisor using Binary GCD algorithm.
135-
pub fn compute_gcd(x: i64, y: i64) -> Result<i64> {
167+
pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
136168
let a = x.unsigned_abs();
137169
let b = y.unsigned_abs();
138170
let r = unsigned_gcd(a, b);
139171
// gcd(i64::MIN, i64::MIN) = i64::MIN.unsigned_abs() cannot fit into i64
140172
r.try_into().map_err(|_| {
141-
arrow_datafusion_err!(ArrowError::ComputeError(format!(
142-
"Signed integer overflow in GCD({x}, {y})"
143-
)))
173+
ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
144174
})
145175
}
146-
147-
#[cfg(test)]
148-
mod test {
149-
use std::sync::Arc;
150-
151-
use arrow::{
152-
array::{ArrayRef, Int64Array},
153-
error::ArrowError,
154-
};
155-
156-
use crate::math::gcd::gcd;
157-
use datafusion_common::{cast::as_int64_array, DataFusionError};
158-
159-
#[test]
160-
fn test_gcd_i64() {
161-
let args: Vec<ArrayRef> = vec![
162-
Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x
163-
Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y
164-
];
165-
166-
let result = gcd(&args).expect("failed to initialize function gcd");
167-
let ints = as_int64_array(&result).expect("failed to initialize function gcd");
168-
169-
assert_eq!(ints.len(), 4);
170-
assert_eq!(ints.value(0), 0);
171-
assert_eq!(ints.value(1), 1);
172-
assert_eq!(ints.value(2), 5);
173-
assert_eq!(ints.value(3), 8);
174-
}
175-
176-
#[test]
177-
fn overflow_on_both_param_i64_min() {
178-
let args: Vec<ArrayRef> = vec![
179-
Arc::new(Int64Array::from(vec![i64::MIN])), // x
180-
Arc::new(Int64Array::from(vec![i64::MIN])), // y
181-
];
182-
183-
match gcd(&args) {
184-
// we expect a overflow
185-
Err(DataFusionError::ArrowError(ArrowError::ComputeError(_), _)) => {}
186-
Err(_) => {
187-
panic!("failed to initialize function gcd")
188-
}
189-
Ok(_) => panic!("GCD({0}, {0}) should have overflown", i64::MIN),
190-
};
191-
}
192-
}

datafusion/sqllogictest/test_files/math.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,12 +623,12 @@ select
623623
1 1 1
624624

625625
# gcd with columns and expressions
626-
query II rowsort
626+
query II
627627
select gcd(a, b), gcd(c*d + 1, abs(e)) + f from signed_integers;
628628
----
629629
1 11
630-
1 13
631630
2 -10
631+
1 13
632632
NULL NULL
633633

634634
# gcd(i64::MIN, i64::MIN)

0 commit comments

Comments
 (0)