Skip to content

Commit 9ae434e

Browse files
committed
Merge remote-tracking branch 'apache/main' into smaller-preimage-pr-1
2 parents 59235de + d90d074 commit 9ae434e

File tree

14 files changed

+1264
-109
lines changed

14 files changed

+1264
-109
lines changed

datafusion/functions/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,8 @@ required-features = ["math_expressions"]
320320
harness = false
321321
name = "floor_ceil"
322322
required-features = ["math_expressions"]
323+
324+
[[bench]]
325+
harness = false
326+
name = "round"
327+
required-features = ["math_expressions"]
datafusion/functions/benches/round.rs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
extern crate criterion;
19+
20+
use arrow::datatypes::{DataType, Field, Float32Type, Float64Type};
21+
use arrow::util::bench_util::create_primitive_array;
22+
use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
23+
use datafusion_common::ScalarValue;
24+
use datafusion_common::config::ConfigOptions;
25+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
26+
use datafusion_functions::math::round;
27+
use std::hint::black_box;
28+
use std::sync::Arc;
29+
use std::time::Duration;
30+
31+
/// Registers the `round` benchmarks with Criterion.
///
/// For each batch size in `[1024, 4096, 8192]` a group named `round size={size}`
/// is created (flat sampling, 10 samples, 10 s measurement time) holding four
/// benchmarks: `round_f64_array`, `round_f32_array`, `round_f64_scalar` and
/// `round_f32_scalar`. Each one calls `invoke_with_args` on the function returned
/// by `round()` with `Int32(Some(2))` as the second argument and unwraps the result.
///
/// NOTE(review): the argument clone and the `Field` construction sit inside
/// `b.iter`, so they are part of every measured iteration.
fn criterion_benchmark(c: &mut Criterion) {
32+
// Built once and shared by every benchmark below.
let round_fn = round();
33+
let config_options = Arc::new(ConfigOptions::default());
34+
35+
for size in [1024, 4096, 8192] {
36+
let mut group = c.benchmark_group(format!("round size={size}"));
37+
group.sampling_mode(SamplingMode::Flat);
38+
group.sample_size(10);
39+
group.measurement_time(Duration::from_secs(10));
40+
41+
// Float64 array benchmark: array first argument, scalar `Int32(2)` second argument.
42+
// NOTE(review): `0.1` is presumably the null density of the generated array —
// confirm against arrow's `bench_util::create_primitive_array`.
let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.1));
43+
// Also reused as `number_rows` by the Float32 array benchmark; both arrays are
// requested with the same `size`.
let batch_len = f64_array.len();
44+
let f64_args = vec![
45+
ColumnarValue::Array(f64_array),
46+
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
47+
];
48+
49+
group.bench_function("round_f64_array", |b| {
50+
b.iter(|| {
51+
// `ScalarFunctionArgs` takes `args` by value, so a fresh clone is made per iteration.
let args_cloned = f64_args.clone();
52+
black_box(
53+
round_fn
54+
.invoke_with_args(ScalarFunctionArgs {
55+
args: args_cloned,
56+
arg_fields: vec![
57+
Field::new("a", DataType::Float64, true).into(),
58+
Field::new("b", DataType::Int32, false).into(),
59+
],
60+
number_rows: batch_len,
61+
return_field: Field::new("f", DataType::Float64, true).into(),
62+
config_options: Arc::clone(&config_options),
63+
})
64+
.unwrap(),
65+
)
66+
})
67+
});
68+
69+
// Float32 array benchmark: same shape as the Float64 one, with Float32 data and fields.
70+
let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.1));
71+
let f32_args = vec![
72+
ColumnarValue::Array(f32_array),
73+
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
74+
];
75+
76+
group.bench_function("round_f32_array", |b| {
77+
b.iter(|| {
78+
let args_cloned = f32_args.clone();
79+
black_box(
80+
round_fn
81+
.invoke_with_args(ScalarFunctionArgs {
82+
args: args_cloned,
83+
arg_fields: vec![
84+
Field::new("a", DataType::Float32, true).into(),
85+
Field::new("b", DataType::Int32, false).into(),
86+
],
87+
number_rows: batch_len,
88+
return_field: Field::new("f", DataType::Float32, true).into(),
89+
config_options: Arc::clone(&config_options),
90+
})
91+
.unwrap(),
92+
)
93+
})
94+
});
95+
96+
// Float64 scalar benchmark (the optimization we added): both arguments are
// `ColumnarValue::Scalar` and `number_rows` is 1.
// NOTE(review): the scalar inputs do not depend on `size`, yet they are
// benchmarked once in every size group.
97+
let scalar_f64_args = vec![
98+
ColumnarValue::Scalar(ScalarValue::Float64(Some(std::f64::consts::PI))),
99+
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
100+
];
101+
102+
group.bench_function("round_f64_scalar", |b| {
103+
b.iter(|| {
104+
let args_cloned = scalar_f64_args.clone();
105+
black_box(
106+
round_fn
107+
.invoke_with_args(ScalarFunctionArgs {
108+
args: args_cloned,
109+
arg_fields: vec![
110+
Field::new("a", DataType::Float64, false).into(),
111+
Field::new("b", DataType::Int32, false).into(),
112+
],
113+
number_rows: 1,
114+
return_field: Field::new("f", DataType::Float64, false)
115+
.into(),
116+
config_options: Arc::clone(&config_options),
117+
})
118+
.unwrap(),
119+
)
120+
})
121+
});
122+
123+
// Float32 scalar benchmark: same as the Float64 scalar case with Float32 input and fields.
let scalar_f32_args = vec![
124+
ColumnarValue::Scalar(ScalarValue::Float32(Some(std::f32::consts::PI))),
125+
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
126+
];
127+
128+
group.bench_function("round_f32_scalar", |b| {
129+
b.iter(|| {
130+
let args_cloned = scalar_f32_args.clone();
131+
black_box(
132+
round_fn
133+
.invoke_with_args(ScalarFunctionArgs {
134+
args: args_cloned,
135+
arg_fields: vec![
136+
Field::new("a", DataType::Float32, false).into(),
137+
Field::new("b", DataType::Int32, false).into(),
138+
],
139+
number_rows: 1,
140+
return_field: Field::new("f", DataType::Float32, false)
141+
.into(),
142+
config_options: Arc::clone(&config_options),
143+
})
144+
.unwrap(),
145+
)
146+
})
147+
});
148+
149+
group.finish();
150+
}
151+
}
152+
153+
criterion_group!(benches, criterion_benchmark);
154+
criterion_main!(benches);

datafusion/functions/src/datetime/date_trunc.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@ use arrow::array::types::{
3434
use arrow::array::{Array, ArrayRef, PrimitiveArray};
3535
use arrow::datatypes::DataType::{self, Time32, Time64, Timestamp};
3636
use arrow::datatypes::TimeUnit::{self, Microsecond, Millisecond, Nanosecond, Second};
37+
use arrow::datatypes::{Field, FieldRef};
3738
use datafusion_common::cast::as_primitive_array;
3839
use datafusion_common::types::{NativeType, logical_date, logical_string};
3940
use datafusion_common::{
4041
DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err,
4142
};
4243
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
4344
use datafusion_expr::{
44-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility,
45+
ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature,
46+
TypeSignature, Volatility,
4547
};
4648
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
4749
use datafusion_macros::user_doc;
@@ -221,6 +223,7 @@ impl ScalarUDFImpl for DateTruncFunc {
221223
&self.signature
222224
}
223225

226+
// keep return_type implementation for information schema generation
224227
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
225228
if arg_types[1].is_null() {
226229
Ok(Timestamp(Nanosecond, None))
@@ -229,6 +232,21 @@ impl ScalarUDFImpl for DateTruncFunc {
229232
}
230233
}
231234

235+
/// Builds the output field for `DateTruncFunc` from the argument fields.
///
/// The data type is obtained by passing the argument fields' data types to
/// `return_type`. The field is named `self.name()` and is nullable exactly when
/// the second argument field (`arg_fields[1]`) is nullable.
///
/// # Errors
///
/// Propagates any error returned by `return_type`.
///
/// NOTE(review): `arg_fields[1]` is indexed directly — presumably the function
/// signature guarantees at least two arguments; confirm.
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
236+
// `return_type` takes `&[DataType]`, so collect owned data types from the fields.
let data_types = args
237+
.arg_fields
238+
.iter()
239+
.map(|f| f.data_type())
240+
.cloned()
241+
.collect::<Vec<_>>();
242+
let return_type = self.return_type(&data_types)?;
243+
Ok(Arc::new(Field::new(
244+
self.name(),
245+
return_type,
246+
// Output nullability mirrors the second argument field.
args.arg_fields[1].is_nullable(),
247+
)))
248+
}
249+
232250
fn invoke_with_args(
233251
&self,
234252
args: datafusion_expr::ScalarFunctionArgs,

datafusion/functions/src/math/round.rs

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use arrow::error::ArrowError;
3131
use datafusion_common::types::{
3232
NativeType, logical_float32, logical_float64, logical_int32,
3333
};
34-
use datafusion_common::{Result, ScalarValue, exec_err};
34+
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
3535
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
3636
use datafusion_expr::{
3737
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
@@ -141,7 +141,67 @@ impl ScalarUDFImpl for RoundFunc {
141141
&default_decimal_places
142142
};
143143

144-
round_columnar(&args.args[0], decimal_places, args.number_rows)
144+
// Scalar fast path for float and decimal types - avoid array conversion overhead
145+
if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
146+
(&args.args[0], decimal_places)
147+
{
148+
if value_scalar.is_null() || dp_scalar.is_null() {
149+
return ColumnarValue::Scalar(ScalarValue::Null)
150+
.cast_to(args.return_type(), None);
151+
}
152+
153+
let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar {
154+
*dp
155+
} else {
156+
return internal_err!(
157+
"Unexpected datatype for decimal_places: {}",
158+
dp_scalar.data_type()
159+
);
160+
};
161+
162+
match value_scalar {
163+
ScalarValue::Float32(Some(v)) => {
164+
let rounded = round_float(*v, dp)?;
165+
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
166+
}
167+
ScalarValue::Float64(Some(v)) => {
168+
let rounded = round_float(*v, dp)?;
169+
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
170+
}
171+
ScalarValue::Decimal128(Some(v), precision, scale) => {
172+
let rounded = round_decimal(*v, *scale, dp)?;
173+
let scalar =
174+
ScalarValue::Decimal128(Some(rounded), *precision, *scale);
175+
Ok(ColumnarValue::Scalar(scalar))
176+
}
177+
ScalarValue::Decimal256(Some(v), precision, scale) => {
178+
let rounded = round_decimal(*v, *scale, dp)?;
179+
let scalar =
180+
ScalarValue::Decimal256(Some(rounded), *precision, *scale);
181+
Ok(ColumnarValue::Scalar(scalar))
182+
}
183+
ScalarValue::Decimal64(Some(v), precision, scale) => {
184+
let rounded = round_decimal(*v, *scale, dp)?;
185+
let scalar =
186+
ScalarValue::Decimal64(Some(rounded), *precision, *scale);
187+
Ok(ColumnarValue::Scalar(scalar))
188+
}
189+
ScalarValue::Decimal32(Some(v), precision, scale) => {
190+
let rounded = round_decimal(*v, *scale, dp)?;
191+
let scalar =
192+
ScalarValue::Decimal32(Some(rounded), *precision, *scale);
193+
Ok(ColumnarValue::Scalar(scalar))
194+
}
195+
_ => {
196+
internal_err!(
197+
"Unexpected datatype for value: {}",
198+
value_scalar.data_type()
199+
)
200+
}
201+
}
202+
} else {
203+
round_columnar(&args.args[0], decimal_places, args.number_rows)
204+
}
145205
}
146206

147207
fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {

0 commit comments

Comments
 (0)