Skip to content

Commit 590ad29

Browse files
authored
fix: update PrimitiveGroupValueBuilder to match NaN correctly (#17979)
1 parent 5609447 commit 590ad29

File tree

1 file changed

+33
-23
lines changed
  • datafusion/physical-plan/src/aggregates/group_values/multi_group_by

1 file changed

+33
-23
lines changed

datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
7070
// Otherwise, we need to check their values
7171
}
7272

73-
self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
73+
self.group_values[lhs_row].is_eq(array.as_primitive::<T>().value(rhs_row))
7474
}
7575

7676
fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> {
@@ -217,22 +217,22 @@ mod tests {
217217
use std::sync::Arc;
218218

219219
use crate::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder;
220-
use arrow::array::{ArrayRef, Int64Array, NullBufferBuilder};
221-
use arrow::datatypes::{DataType, Int64Type};
220+
use arrow::array::{ArrayRef, Float32Array, Int64Array, NullBufferBuilder};
221+
use arrow::datatypes::{DataType, Float32Type, Int64Type};
222222

223223
use super::GroupColumn;
224224

225225
#[test]
226226
fn test_nullable_primitive_equal_to() {
227-
let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, true>,
227+
let append = |builder: &mut PrimitiveGroupValueBuilder<Float32Type, true>,
228228
builder_array: &ArrayRef,
229229
append_rows: &[usize]| {
230230
for &index in append_rows {
231231
builder.append_val(builder_array, index).unwrap();
232232
}
233233
};
234234

235-
let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, true>,
235+
let equal_to = |builder: &PrimitiveGroupValueBuilder<Float32Type, true>,
236236
lhs_rows: &[usize],
237237
input_array: &ArrayRef,
238238
rhs_rows: &[usize],
@@ -248,15 +248,15 @@ mod tests {
248248

249249
#[test]
250250
fn test_nullable_primitive_vectorized_equal_to() {
251-
let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, true>,
251+
let append = |builder: &mut PrimitiveGroupValueBuilder<Float32Type, true>,
252252
builder_array: &ArrayRef,
253253
append_rows: &[usize]| {
254254
builder
255255
.vectorized_append(builder_array, append_rows)
256256
.unwrap();
257257
};
258258

259-
let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, true>,
259+
let equal_to = |builder: &PrimitiveGroupValueBuilder<Float32Type, true>,
260260
lhs_rows: &[usize],
261261
input_array: &ArrayRef,
262262
rhs_rows: &[usize],
@@ -274,9 +274,9 @@ mod tests {
274274

275275
fn test_nullable_primitive_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
276276
where
277-
A: FnMut(&mut PrimitiveGroupValueBuilder<Int64Type, true>, &ArrayRef, &[usize]),
277+
A: FnMut(&mut PrimitiveGroupValueBuilder<Float32Type, true>, &ArrayRef, &[usize]),
278278
E: FnMut(
279-
&PrimitiveGroupValueBuilder<Int64Type, true>,
279+
&PrimitiveGroupValueBuilder<Float32Type, true>,
280280
&[usize],
281281
&ArrayRef,
282282
&[usize],
@@ -293,48 +293,58 @@ mod tests {
293293

294294
// Define PrimitiveGroupValueBuilder
295295
let mut builder =
296-
PrimitiveGroupValueBuilder::<Int64Type, true>::new(DataType::Int64);
297-
let builder_array = Arc::new(Int64Array::from(vec![
296+
PrimitiveGroupValueBuilder::<Float32Type, true>::new(DataType::Float32);
297+
let builder_array = Arc::new(Float32Array::from(vec![
298298
None,
299299
None,
300300
None,
301-
Some(1),
302-
Some(2),
303-
Some(3),
301+
Some(1.0),
302+
Some(2.0),
303+
Some(f32::NAN),
304+
Some(3.0),
304305
])) as ArrayRef;
305-
append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]);
306+
append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6]);
306307

307308
// Define input array
308-
let (_, values, _nulls) =
309-
Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)])
310-
.into_parts();
309+
let (_, values, _nulls) = Float32Array::from(vec![
310+
Some(1.0),
311+
Some(2.0),
312+
None,
313+
Some(1.0),
314+
None,
315+
Some(f32::NAN),
316+
None,
317+
])
318+
.into_parts();
311319

312320
// explicitly build a null buffer where one of the null values also happens to match
313321
let mut nulls = NullBufferBuilder::new(6);
314322
nulls.append_non_null();
315323
nulls.append_null(); // this sets Some(2) to null above
316324
nulls.append_null();
317-
nulls.append_null();
318325
nulls.append_non_null();
326+
nulls.append_null();
319327
nulls.append_non_null();
320-
let input_array = Arc::new(Int64Array::new(values, nulls.finish())) as ArrayRef;
328+
nulls.append_null();
329+
let input_array = Arc::new(Float32Array::new(values, nulls.finish())) as ArrayRef;
321330

322331
// Check
323332
let mut equal_to_results = vec![true; builder.len()];
324333
equal_to(
325334
&builder,
326-
&[0, 1, 2, 3, 4, 5],
335+
&[0, 1, 2, 3, 4, 5, 6],
327336
&input_array,
328-
&[0, 1, 2, 3, 4, 5],
337+
&[0, 1, 2, 3, 4, 5, 6],
329338
&mut equal_to_results,
330339
);
331340

332341
assert!(!equal_to_results[0]);
333342
assert!(equal_to_results[1]);
334343
assert!(equal_to_results[2]);
335-
assert!(!equal_to_results[3]);
344+
assert!(equal_to_results[3]);
336345
assert!(!equal_to_results[4]);
337346
assert!(equal_to_results[5]);
347+
assert!(!equal_to_results[6]);
338348
}
339349

340350
#[test]

0 commit comments

Comments
 (0)