Skip to content

Commit c5cb6f4

Browse files
authored
feat(spark): implement Spark math function bit_get/bit_count (#16942)
* bit tmp * fmt * clippy * clippy * update
1 parent f0e9334 commit c5cb6f4

File tree

6 files changed

+921
-12
lines changed

6 files changed

+921
-12
lines changed
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use std::sync::Arc;
20+
21+
use arrow::array::{ArrayRef, AsArray, Int32Array};
22+
use arrow::datatypes::{
23+
DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
24+
UInt64Type, UInt8Type,
25+
};
26+
use datafusion_common::{plan_err, Result};
27+
use datafusion_expr::{
28+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
29+
Volatility,
30+
};
31+
use datafusion_functions::utils::make_scalar_function;
32+
33+
#[derive(Debug)]
34+
pub struct SparkBitCount {
35+
signature: Signature,
36+
}
37+
38+
impl Default for SparkBitCount {
39+
fn default() -> Self {
40+
Self::new()
41+
}
42+
}
43+
44+
impl SparkBitCount {
45+
pub fn new() -> Self {
46+
Self {
47+
signature: Signature::one_of(
48+
vec![
49+
TypeSignature::Exact(vec![DataType::Int8]),
50+
TypeSignature::Exact(vec![DataType::Int16]),
51+
TypeSignature::Exact(vec![DataType::Int32]),
52+
TypeSignature::Exact(vec![DataType::Int64]),
53+
TypeSignature::Exact(vec![DataType::UInt8]),
54+
TypeSignature::Exact(vec![DataType::UInt16]),
55+
TypeSignature::Exact(vec![DataType::UInt32]),
56+
TypeSignature::Exact(vec![DataType::UInt64]),
57+
],
58+
Volatility::Immutable,
59+
),
60+
}
61+
}
62+
}
63+
64+
impl ScalarUDFImpl for SparkBitCount {
65+
fn as_any(&self) -> &dyn Any {
66+
self
67+
}
68+
69+
fn name(&self) -> &str {
70+
"bit_count"
71+
}
72+
73+
fn signature(&self) -> &Signature {
74+
&self.signature
75+
}
76+
77+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78+
Ok(DataType::Int32) // Spark returns int (Int32)
79+
}
80+
81+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
82+
if args.args.len() != 1 {
83+
return plan_err!("bit_count expects exactly 1 argument");
84+
}
85+
86+
make_scalar_function(spark_bit_count, vec![])(&args.args)
87+
}
88+
}
89+
90+
fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
91+
let value_array = value_array[0].as_ref();
92+
match value_array.data_type() {
93+
DataType::Int8 => {
94+
let result: Int32Array = value_array
95+
.as_primitive::<Int8Type>()
96+
.unary(|v| v.count_ones() as i32);
97+
Ok(Arc::new(result))
98+
}
99+
DataType::Int16 => {
100+
let result: Int32Array = value_array
101+
.as_primitive::<Int16Type>()
102+
.unary(|v| v.count_ones() as i32);
103+
Ok(Arc::new(result))
104+
}
105+
DataType::Int32 => {
106+
let result: Int32Array = value_array
107+
.as_primitive::<Int32Type>()
108+
.unary(|v| v.count_ones() as i32);
109+
Ok(Arc::new(result))
110+
}
111+
DataType::Int64 => {
112+
let result: Int32Array = value_array
113+
.as_primitive::<Int64Type>()
114+
.unary(|v| v.count_ones() as i32);
115+
Ok(Arc::new(result))
116+
}
117+
DataType::UInt8 => {
118+
let result: Int32Array = value_array
119+
.as_primitive::<UInt8Type>()
120+
.unary(|v| v.count_ones() as i32);
121+
Ok(Arc::new(result))
122+
}
123+
DataType::UInt16 => {
124+
let result: Int32Array = value_array
125+
.as_primitive::<UInt16Type>()
126+
.unary(|v| v.count_ones() as i32);
127+
Ok(Arc::new(result))
128+
}
129+
DataType::UInt32 => {
130+
let result: Int32Array = value_array
131+
.as_primitive::<UInt32Type>()
132+
.unary(|v| v.count_ones() as i32);
133+
Ok(Arc::new(result))
134+
}
135+
DataType::UInt64 => {
136+
let result: Int32Array = value_array
137+
.as_primitive::<UInt64Type>()
138+
.unary(|v| v.count_ones() as i32);
139+
Ok(Arc::new(result))
140+
}
141+
_ => {
142+
plan_err!(
143+
"bit_count function does not support data type: {:?}",
144+
value_array.data_type()
145+
)
146+
}
147+
}
148+
}
149+
150+
#[cfg(test)]
151+
mod tests {
152+
use super::*;
153+
use arrow::array::{
154+
Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array,
155+
UInt64Array, UInt8Array,
156+
};
157+
use arrow::datatypes::Int32Type;
158+
159+
#[test]
160+
fn test_bit_count_basic() {
161+
// Test bit_count(0) - no bits set
162+
let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![0]))]).unwrap();
163+
164+
assert_eq!(result.as_primitive::<Int32Type>().value(0), 0);
165+
166+
// Test bit_count(1) - 1 bit set
167+
let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![1]))]).unwrap();
168+
169+
assert_eq!(result.as_primitive::<Int32Type>().value(0), 1);
170+
171+
// Test bit_count(7) - 7 = 111 in binary, 3 bits set
172+
let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![7]))]).unwrap();
173+
174+
assert_eq!(result.as_primitive::<Int32Type>().value(0), 3);
175+
176+
// Test bit_count(15) - 15 = 1111 in binary, 4 bits set
177+
let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![15]))]).unwrap();
178+
179+
assert_eq!(result.as_primitive::<Int32Type>().value(0), 4);
180+
}
181+
182+
#[test]
183+
fn test_bit_count_int8() {
184+
// Test bit_count on Int8Array
185+
let result =
186+
spark_bit_count(&[Arc::new(Int8Array::from(vec![0i8, 1, 3, 7, 15, -1]))])
187+
.unwrap();
188+
189+
let arr = result.as_primitive::<Int32Type>();
190+
assert_eq!(arr.value(0), 0);
191+
assert_eq!(arr.value(1), 1);
192+
assert_eq!(arr.value(2), 2);
193+
assert_eq!(arr.value(3), 3);
194+
assert_eq!(arr.value(4), 4);
195+
assert_eq!(arr.value(5), 8);
196+
}
197+
198+
#[test]
199+
fn test_bit_count_int16() {
200+
// Test bit_count on Int16Array
201+
let result =
202+
spark_bit_count(&[Arc::new(Int16Array::from(vec![0i16, 1, 255, 1023, -1]))])
203+
.unwrap();
204+
205+
let arr = result.as_primitive::<Int32Type>();
206+
assert_eq!(arr.value(0), 0);
207+
assert_eq!(arr.value(1), 1);
208+
assert_eq!(arr.value(2), 8);
209+
assert_eq!(arr.value(3), 10);
210+
assert_eq!(arr.value(4), 16);
211+
}
212+
213+
#[test]
214+
fn test_bit_count_int32() {
215+
// Test bit_count on Int32Array
216+
let result =
217+
spark_bit_count(&[Arc::new(Int32Array::from(vec![0i32, 1, 255, 1023, -1]))])
218+
.unwrap();
219+
220+
let arr = result.as_primitive::<Int32Type>();
221+
assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0
222+
assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
223+
assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
224+
assert_eq!(arr.value(3), 10); // 0b00000000000000000000001111111111 = 10
225+
assert_eq!(arr.value(4), 32); // -1 in two's complement = all 32 bits set
226+
}
227+
228+
#[test]
229+
fn test_bit_count_int64() {
230+
// Test bit_count on Int64Array
231+
let result =
232+
spark_bit_count(&[Arc::new(Int64Array::from(vec![0i64, 1, 255, 1023, -1]))])
233+
.unwrap();
234+
235+
let arr = result.as_primitive::<Int32Type>();
236+
assert_eq!(arr.value(0), 0); // 0b0000000000000000000000000000000000000000000000000000000000000000 = 0
237+
assert_eq!(arr.value(1), 1); // 0b0000000000000000000000000000000000000000000000000000000000000001 = 1
238+
assert_eq!(arr.value(2), 8); // 0b0000000000000000000000000000000000000000000000000000000011111111 = 8
239+
assert_eq!(arr.value(3), 10); // 0b0000000000000000000000000000000000000000000000000000001111111111 = 10
240+
assert_eq!(arr.value(4), 64); // -1 in two's complement = all 64 bits set
241+
}
242+
243+
#[test]
244+
fn test_bit_count_uint8() {
245+
// Test bit_count on UInt8Array
246+
let result =
247+
spark_bit_count(&[Arc::new(UInt8Array::from(vec![0u8, 1, 255]))]).unwrap();
248+
249+
let arr = result.as_primitive::<Int32Type>();
250+
assert_eq!(arr.value(0), 0); // 0b00000000 = 0
251+
assert_eq!(arr.value(1), 1); // 0b00000001 = 1
252+
assert_eq!(arr.value(2), 8); // 0b11111111 = 8
253+
}
254+
255+
#[test]
256+
fn test_bit_count_uint16() {
257+
// Test bit_count on UInt16Array
258+
let result =
259+
spark_bit_count(&[Arc::new(UInt16Array::from(vec![0u16, 1, 255, 65535]))])
260+
.unwrap();
261+
262+
let arr = result.as_primitive::<Int32Type>();
263+
assert_eq!(arr.value(0), 0); // 0b0000000000000000 = 0
264+
assert_eq!(arr.value(1), 1); // 0b0000000000000001 = 1
265+
assert_eq!(arr.value(2), 8); // 0b0000000011111111 = 8
266+
assert_eq!(arr.value(3), 16); // 0b1111111111111111 = 16
267+
}
268+
269+
#[test]
270+
fn test_bit_count_uint32() {
271+
// Test bit_count on UInt32Array
272+
let result = spark_bit_count(&[Arc::new(UInt32Array::from(vec![
273+
0u32, 1, 255, 4294967295,
274+
]))])
275+
.unwrap();
276+
277+
let arr = result.as_primitive::<Int32Type>();
278+
assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0
279+
assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
280+
assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
281+
assert_eq!(arr.value(3), 32); // 0b11111111111111111111111111111111 = 32
282+
}
283+
284+
#[test]
285+
fn test_bit_count_uint64() {
286+
// Test bit_count on UInt64Array
287+
let result = spark_bit_count(&[Arc::new(UInt64Array::from(vec![
288+
0u64,
289+
1,
290+
255,
291+
256,
292+
u64::MAX,
293+
]))])
294+
.unwrap();
295+
296+
let arr = result.as_primitive::<Int32Type>();
297+
// 0b0 = 0
298+
assert_eq!(arr.value(0), 0);
299+
// 0b1 = 1
300+
assert_eq!(arr.value(1), 1);
301+
// 0b11111111 = 8
302+
assert_eq!(arr.value(2), 8);
303+
// 0b100000000 = 1
304+
assert_eq!(arr.value(3), 1);
305+
// u64::MAX = all 64 bits set
306+
assert_eq!(arr.value(4), 64);
307+
}
308+
309+
#[test]
310+
fn test_bit_count_nulls() {
311+
// Test bit_count with nulls
312+
let arr = Int32Array::from(vec![Some(3), None, Some(7)]);
313+
let result = spark_bit_count(&[Arc::new(arr)]).unwrap();
314+
let arr = result.as_primitive::<Int32Type>();
315+
assert_eq!(arr.value(0), 2); // 0b11
316+
assert!(arr.is_null(1));
317+
assert_eq!(arr.value(2), 3); // 0b111
318+
}
319+
}

0 commit comments

Comments
 (0)