Skip to content

Commit 1a6df66

Browse files
fix: bit_count function to report nullability correctly (#19197)
## Which issue does this PR close? Closes #19147 Part of #19144 (EPIC: fix nullability report for spark expression) ## Rationale for this change The `bit_count` UDF was using the default `return_type` implementation which does not preserve nullability information. This causes: 1. **Incorrect schema inference** - non-nullable integer inputs are incorrectly marked as producing nullable Int32 outputs 2. **Missed optimization opportunities** - the query optimizer cannot apply nullability-based optimizations when metadata is incorrect 3. **Inconsistent behavior** - other similar functions preserve nullability, leading to unexpected differences The `bit_count` function counts the number of set bits (ones) in the binary representation of a number and returns an Int32. The operation itself doesn't introduce nullability - if the input is non-nullable, the output will always be non-nullable. Therefore, the output nullability should match the input. ## What changes are included in this PR? 1. **Implemented `return_field_from_args`**: Creates a field with Int32 type and the same nullability as the input field 2. **Updated `return_type`**: Now returns an error directing users to use `return_field_from_args` instead (following DataFusion best practices) 3. **Added `FieldRef` and `internal_err` imports** to support the new implementation 4. **Added nullability test**: Verifies that both nullable and non-nullable inputs are handled correctly ## Are these changes tested? Yes, this PR includes a new test `test_bit_count_nullability` that verifies: - Non-nullable Int32 input produces non-nullable Int32 output - Nullable Int32 input produces nullable Int32 output - Data type is correctly set to Int32 in both cases Test results: ``` running 1 test test function::bitwise::bit_count::tests::test_bit_count_nullability ... ok test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured ``` Additionally, all existing `bit_count` tests continue to pass, ensuring backward compatibility. ## Are there any user-facing changes? **Yes - Schema metadata improvement:** Users will now see correct nullability information in the schema: **Before (Bug):** - Non-nullable Int32/Int64/Boolean input → nullable Int32 output (incorrect) **After (Fixed):** - Non-nullable Int32/Int64/Boolean input → non-nullable Int32 output (correct) - Nullable Int32/Int64/Boolean input → nullable Int32 output (correct) This is a **bug fix** that corrects schema metadata only - it does not change the actual computation or introduce any breaking changes to the API. **Impact:** - Query optimizers can now make better decisions based on accurate nullability information - Schema validation will be more accurate - No changes to function behavior or output values --- ## Code Changes Summary ### Modified File: `datafusion/spark/src/function/bitwise/bit_count.rs` #### 1. Added Imports ```rust use arrow::datatypes::{ DataType, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::{internal_err, plan_err, Result}; ``` #### 2. Updated return_type Method ```rust fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { internal_err!("return_field_from_args should be used instead") } ``` #### 3. Added return_field_from_args Implementation ```rust fn return_field_from_args( &self, args: datafusion_expr::ReturnFieldArgs, ) -> Result<FieldRef> { use arrow::datatypes::Field; // bit_count returns Int32 with the same nullability as the input Ok(Arc::new(Field::new( args.arg_fields[0].name(), DataType::Int32, args.arg_fields[0].is_nullable(), ))) } ``` #### 4. Added Test ```rust #[test] fn test_bit_count_nullability() -> Result<()> { use datafusion_expr::ReturnFieldArgs; let bit_count = SparkBitCount::new(); // Test with non-nullable Int32 field let non_nullable_field = Arc::new(Field::new("num", DataType::Int32, false)); let result = bit_count.return_field_from_args(ReturnFieldArgs { arg_fields: &[Arc::clone(&non_nullable_field)], scalar_arguments: &[None], })?; // The result should not be nullable (same as input) assert!(!result.is_nullable()); assert_eq!(result.data_type(), &DataType::Int32); // Test with nullable Int32 field let nullable_field = Arc::new(Field::new("num", DataType::Int32, true)); let result = bit_count.return_field_from_args(ReturnFieldArgs { arg_fields: &[Arc::clone(&nullable_field)], scalar_arguments: &[None], })?; // The result should be nullable (same as input) assert!(result.is_nullable()); assert_eq!(result.data_type(), &DataType::Int32); Ok(()) } ``` --- ## Verification Steps 1. **Run the new test:** ```bash cargo test -p datafusion-spark test_bit_count_nullability --lib ``` 2. **Run all bit_count tests:** ```bash cargo test -p datafusion-spark bit_count --lib ``` 3. **Run clippy checks:** ```bash cargo clippy -p datafusion-spark --all-targets -- -D warnings ``` All checks pass successfully! --- ## Related Issues - Closes: #19147 - Part of EPIC: #19144 (fix nullability report for spark expression) - Similar fixes: - #19145 (shuffle function nullability) - #19146 (bitmap_count function nullability) --- Co-authored-by: Oleks V <[email protected]>
1 parent 4fb36b2 commit 1a6df66

File tree

1 file changed

+51
-5
lines changed

1 file changed

+51
-5
lines changed

datafusion/spark/src/function/bitwise/bit_count.rs

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ use std::sync::Arc;
2020

2121
use arrow::array::{ArrayRef, AsArray, Int32Array};
2222
use arrow::datatypes::{
23-
DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
24-
UInt64Type, UInt8Type,
23+
DataType, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
24+
UInt32Type, UInt64Type, UInt8Type,
2525
};
2626
use datafusion_common::cast::as_boolean_array;
27-
use datafusion_common::{plan_err, Result};
27+
use datafusion_common::{internal_err, plan_err, Result};
2828
use datafusion_expr::{
2929
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
3030
Volatility,
@@ -77,7 +77,20 @@ impl ScalarUDFImpl for SparkBitCount {
7777
}
7878

7979
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80-
Ok(DataType::Int32) // Spark returns int (Int32)
80+
internal_err!("return_field_from_args should be used instead")
81+
}
82+
83+
fn return_field_from_args(
84+
&self,
85+
args: datafusion_expr::ReturnFieldArgs,
86+
) -> Result<FieldRef> {
87+
use arrow::datatypes::Field;
88+
// bit_count returns Int32 with the same nullability as the input
89+
Ok(Arc::new(Field::new(
90+
args.arg_fields[0].name(),
91+
DataType::Int32,
92+
args.arg_fields[0].is_nullable(),
93+
)))
8194
}
8295

8396
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -163,7 +176,7 @@ mod tests {
163176
Array, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array,
164177
UInt32Array, UInt64Array, UInt8Array,
165178
};
166-
use arrow::datatypes::Int32Type;
179+
use arrow::datatypes::{Field, Int32Type};
167180

168181
#[test]
169182
fn test_bit_count_basic() {
@@ -336,4 +349,37 @@ mod tests {
336349
assert!(arr.is_null(1));
337350
assert_eq!(arr.value(2), 3); // 0b111
338351
}
352+
353+
#[test]
354+
fn test_bit_count_nullability() -> Result<()> {
355+
use datafusion_expr::ReturnFieldArgs;
356+
357+
let bit_count = SparkBitCount::new();
358+
359+
// Test with non-nullable Int32 field
360+
let non_nullable_field = Arc::new(Field::new("num", DataType::Int32, false));
361+
362+
let result = bit_count.return_field_from_args(ReturnFieldArgs {
363+
arg_fields: &[Arc::clone(&non_nullable_field)],
364+
scalar_arguments: &[None],
365+
})?;
366+
367+
// The result should not be nullable (same as input)
368+
assert!(!result.is_nullable());
369+
assert_eq!(result.data_type(), &DataType::Int32);
370+
371+
// Test with nullable Int32 field
372+
let nullable_field = Arc::new(Field::new("num", DataType::Int32, true));
373+
374+
let result = bit_count.return_field_from_args(ReturnFieldArgs {
375+
arg_fields: &[Arc::clone(&nullable_field)],
376+
scalar_arguments: &[None],
377+
})?;
378+
379+
// The result should be nullable (same as input)
380+
assert!(result.is_nullable());
381+
assert_eq!(result.data_type(), &DataType::Int32);
382+
383+
Ok(())
384+
}
339385
}

0 commit comments

Comments
 (0)