Skip to content

Commit 6ec70ae

Browse files
authored
perf: Improve performance of normalize_nan (#2999)
1 parent 937619f commit 6ec70ae

File tree

3 files changed

+114
-58
lines changed

3 files changed

+114
-58
lines changed

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ harness = false
8080
name = "padding"
8181
harness = false
8282

83+
# Criterion benchmark for NormalizeNaNAndZero.
# harness = false: criterion supplies its own main() instead of libtest.
[[bench]]
name = "normalize_nan"
harness = false
86+
8387
[[test]]
8488
name = "test_udf_registration"
8589
path = "tests/spark_expr_reg.rs"
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Benchmarks for NormalizeNaNAndZero expression
19+
20+
use arrow::array::Float64Array;
21+
use arrow::datatypes::{DataType, Field, Schema};
22+
use arrow::record_batch::RecordBatch;
23+
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
24+
use datafusion::physical_expr::expressions::Column;
25+
use datafusion::physical_expr::PhysicalExpr;
26+
use datafusion_comet_spark_expr::NormalizeNaNAndZero;
27+
use std::hint::black_box;
28+
use std::sync::Arc;
29+
30+
const BATCH_SIZE: usize = 8192;
31+
32+
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
33+
Arc::new(Column::new(name, index))
34+
}
35+
36+
/// Create a batch with float64 column containing various values including NaN and -0.0
37+
fn create_float_batch(nan_pct: usize, neg_zero_pct: usize, null_pct: usize) -> RecordBatch {
38+
let mut values: Vec<Option<f64>> = Vec::with_capacity(BATCH_SIZE);
39+
40+
for i in 0..BATCH_SIZE {
41+
if null_pct > 0 && i % (100 / null_pct.max(1)) == 0 {
42+
values.push(None);
43+
} else if nan_pct > 0 && i % (100 / nan_pct.max(1)) == 1 {
44+
values.push(Some(f64::NAN));
45+
} else if neg_zero_pct > 0 && i % (100 / neg_zero_pct.max(1)) == 2 {
46+
values.push(Some(-0.0));
47+
} else {
48+
values.push(Some(i as f64 * 1.5));
49+
}
50+
}
51+
52+
let array = Float64Array::from(values);
53+
let schema = Schema::new(vec![Field::new("c1", DataType::Float64, true)]);
54+
55+
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
56+
}
57+
58+
fn bench_normalize_nan_and_zero(c: &mut Criterion) {
59+
let mut group = c.benchmark_group("normalize_nan_and_zero");
60+
61+
// Test with different percentages of special values
62+
let test_cases = [
63+
("no_special", 0, 0, 0),
64+
("10pct_nan", 10, 0, 0),
65+
("10pct_neg_zero", 0, 10, 0),
66+
("10pct_null", 0, 0, 10),
67+
("mixed_10pct", 5, 5, 5),
68+
("all_normal", 0, 0, 0),
69+
];
70+
71+
for (name, nan_pct, neg_zero_pct, null_pct) in test_cases {
72+
let batch = create_float_batch(nan_pct, neg_zero_pct, null_pct);
73+
74+
let normalize_expr = Arc::new(NormalizeNaNAndZero::new(
75+
DataType::Float64,
76+
make_col("c1", 0),
77+
));
78+
79+
group.bench_with_input(BenchmarkId::new("float64", name), &batch, |b, batch| {
80+
b.iter(|| black_box(normalize_expr.evaluate(black_box(batch)).unwrap()));
81+
});
82+
}
83+
84+
group.finish();
85+
}
86+
87+
criterion_group!(benches, bench_normalize_nan_and_zero);
88+
criterion_main!(benches);

native/spark-expr/src/math_funcs/internal/normalize_nan.rs

Lines changed: 22 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::compute::unary;
1819
use arrow::datatypes::{DataType, Schema};
1920
use arrow::{
20-
array::{as_primitive_array, ArrayAccessor, ArrayIter, Float32Array, Float64Array},
21-
datatypes::{ArrowNativeType, Float32Type, Float64Type},
21+
array::{as_primitive_array, Float32Array, Float64Array},
22+
datatypes::{Float32Type, Float64Type},
2223
record_batch::RecordBatch,
2324
};
2425
use datafusion::logical_expr::ColumnarValue;
@@ -78,14 +79,16 @@ impl PhysicalExpr for NormalizeNaNAndZero {
7879

7980
match &self.data_type {
8081
DataType::Float32 => {
81-
let v = eval_typed(as_primitive_array::<Float32Type>(&array));
82-
let new_array = Float32Array::from(v);
83-
Ok(ColumnarValue::Array(Arc::new(new_array)))
82+
let input = as_primitive_array::<Float32Type>(&array);
83+
// Use unary which operates directly on values buffer without intermediate allocation
84+
let result: Float32Array = unary(input, normalize_float);
85+
Ok(ColumnarValue::Array(Arc::new(result)))
8486
}
8587
DataType::Float64 => {
86-
let v = eval_typed(as_primitive_array::<Float64Type>(&array));
87-
let new_array = Float64Array::from(v);
88-
Ok(ColumnarValue::Array(Arc::new(new_array)))
88+
let input = as_primitive_array::<Float64Type>(&array);
89+
// Use unary which operates directly on values buffer without intermediate allocation
90+
let result: Float64Array = unary(input, normalize_float);
91+
Ok(ColumnarValue::Array(Arc::new(result)))
8992
}
9093
dt => panic!("Unexpected data type {dt:?}"),
9194
}
@@ -106,60 +109,21 @@ impl PhysicalExpr for NormalizeNaNAndZero {
106109
}
107110
}
108111

109-
fn eval_typed<V: FloatDouble, T: ArrayAccessor<Item = V>>(input: T) -> Vec<Option<V>> {
110-
let iter = ArrayIter::new(input);
111-
iter.map(|o| {
112-
o.map(|v| {
113-
if v.is_nan() {
114-
v.nan()
115-
} else if v.is_neg_zero() {
116-
v.zero()
117-
} else {
118-
v
119-
}
120-
})
121-
})
122-
.collect()
112+
/// Normalize a floating point value by converting all NaN representations to a canonical NaN
113+
/// and negative zero to positive zero. This is used for Spark's comparison semantics.
114+
#[inline]
115+
fn normalize_float<T: num::Float>(v: T) -> T {
116+
if v.is_nan() {
117+
T::nan()
118+
} else if v == T::neg_zero() {
119+
T::zero()
120+
} else {
121+
v
122+
}
123123
}
124124

125125
impl Display for NormalizeNaNAndZero {
126126
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
127127
write!(f, "FloatNormalize [child: {}]", self.child)
128128
}
129129
}
130-
131-
trait FloatDouble: ArrowNativeType {
132-
fn is_nan(&self) -> bool;
133-
fn nan(&self) -> Self;
134-
fn is_neg_zero(&self) -> bool;
135-
fn zero(&self) -> Self;
136-
}
137-
138-
impl FloatDouble for f32 {
139-
fn is_nan(&self) -> bool {
140-
f32::is_nan(*self)
141-
}
142-
fn nan(&self) -> Self {
143-
f32::NAN
144-
}
145-
fn is_neg_zero(&self) -> bool {
146-
*self == -0.0
147-
}
148-
fn zero(&self) -> Self {
149-
0.0
150-
}
151-
}
152-
impl FloatDouble for f64 {
153-
fn is_nan(&self) -> bool {
154-
f64::is_nan(*self)
155-
}
156-
fn nan(&self) -> Self {
157-
f64::NAN
158-
}
159-
fn is_neg_zero(&self) -> bool {
160-
*self == -0.0
161-
}
162-
fn zero(&self) -> Self {
163-
0.0
164-
}
165-
}

0 commit comments

Comments
 (0)