Skip to content

Commit c51f977

Browse files
Feat: support bit_get function (#1713)
1 parent d72e54c commit c51f977

File tree

7 files changed

+707
-172
lines changed

7 files changed

+707
-172
lines changed
native/spark-expr/src/bitwise_funcs/bitwise_get.rs

Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::{array::*, datatypes::DataType};
19+
use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue};
20+
use datafusion::logical_expr::ColumnarValue;
21+
use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
22+
use std::any::Any;
23+
use std::sync::Arc;
24+
25+
#[derive(Debug)]
26+
pub struct SparkBitwiseGet {
27+
signature: Signature,
28+
aliases: Vec<String>,
29+
}
30+
31+
impl Default for SparkBitwiseGet {
32+
fn default() -> Self {
33+
Self::new()
34+
}
35+
}
36+
37+
impl SparkBitwiseGet {
38+
pub fn new() -> Self {
39+
Self {
40+
signature: Signature::user_defined(Volatility::Immutable),
41+
aliases: vec![],
42+
}
43+
}
44+
}
45+
46+
impl ScalarUDFImpl for SparkBitwiseGet {
47+
fn as_any(&self) -> &dyn Any {
48+
self
49+
}
50+
51+
fn name(&self) -> &str {
52+
"bit_get"
53+
}
54+
55+
fn signature(&self) -> &Signature {
56+
&self.signature
57+
}
58+
59+
fn aliases(&self) -> &[String] {
60+
&self.aliases
61+
}
62+
63+
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
64+
Ok(DataType::Int8)
65+
}
66+
67+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
68+
let args: [ColumnarValue; 2] = args
69+
.args
70+
.try_into()
71+
.map_err(|_| internal_datafusion_err!("bit_get expects exactly two arguments"))?;
72+
spark_bit_get(&args)
73+
}
74+
}
75+
76+
/// Evaluates `bit_get` for an array of values against a single scalar position.
///
/// Arguments:
/// * `$args` - the value array (an `ArrayRef` whose concrete type is `$array_type`)
/// * `$array_type` - the concrete signed integer array type to downcast to
/// * `$pos` - a reference to the optional `i32` position (`&Option<i32>`)
/// * `$bit_size` - the number of bits of the value type, used as the upper bound
///
/// Expands to a `Result<ArrayRef>` holding an `Int8Array`. A row is NULL when
/// either the value or the position is NULL.
///
/// The position is validated only when at least one value is non-NULL, so an
/// empty or all-NULL batch yields NULLs instead of an error. This matches a
/// null-safe, row-wise evaluation where the position is never inspected for
/// NULL inputs. An invalid position makes the *enclosing function* return the
/// error via `?`.
///
/// Panics if `$args` is not actually a `$array_type`; callers dispatch on the
/// array's data type before invoking the macro.
macro_rules! bit_get_scalar_position {
    ($args:expr, $array_type:ty, $pos:expr, $bit_size:expr) => {{
        let values = $args
            .as_any()
            .downcast_ref::<$array_type>()
            .expect("bit_get_scalar_position failed to downcast array");

        let result: Int8Array = match $pos {
            Some(pos) => {
                // Only rows with a non-NULL value ever look at the position.
                if values.null_count() < values.len() {
                    check_position(*pos, $bit_size as i32)?;
                }
                values
                    .iter()
                    .map(|value| value.map(|value| bit_get(value.into(), *pos)))
                    .collect()
            }
            // NULL position: every row is NULL.
            None => values.iter().map(|_| None::<i8>).collect(),
        };

        Ok(Arc::new(result))
    }};
}
94+
95+
/// Evaluates `bit_get` row by row for an array of values and an array of
/// positions of the same length.
///
/// Arguments:
/// * `$args` - the value array (an `ArrayRef` whose concrete type is `$array_type`)
/// * `$array_type` - the concrete signed integer array type to downcast to
/// * `$positions` - the positions array; must be an `Int32Array`
/// * `$bit_size` - the number of bits of the value type, used as the upper bound
///
/// Expands to a `Result<ArrayRef>` holding an `Int8Array`. A row is NULL when
/// either its value or its position is NULL.
///
/// A position is validated only for rows where both the value and the
/// position are non-NULL. A bad position paired with a NULL value therefore
/// yields NULL rather than failing the whole batch, matching a null-safe,
/// row-wise evaluation. An invalid position makes the *enclosing function*
/// return the error via `?`.
///
/// Panics if either array is not of the expected concrete type; callers check
/// the data types before invoking the macro.
macro_rules! bit_get_array_positions {
    ($args:expr, $array_type:ty, $positions:expr, $bit_size:expr) => {{
        let values = $args
            .as_any()
            .downcast_ref::<$array_type>()
            .expect("bit_get_array_positions failed to downcast args array");

        let positions = $positions
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("bit_get_array_positions failed to downcast positions array");

        let result = values
            .iter()
            .zip(positions.iter())
            .map(|(value, pos)| -> Result<Option<i8>> {
                match (value, pos) {
                    (Some(value), Some(pos)) => {
                        check_position(pos, $bit_size as i32)?;
                        Ok(Some(bit_get(value.into(), pos)))
                    }
                    // NULL value or NULL position: NULL result, no validation.
                    _ => Ok(None),
                }
            })
            .collect::<Result<Int8Array>>()?;

        Ok(Arc::new(result))
    }};
}
120+
121+
pub fn spark_bit_get(args: &[ColumnarValue; 2]) -> Result<ColumnarValue> {
122+
match args {
123+
[ColumnarValue::Array(args), ColumnarValue::Scalar(ScalarValue::Int32(pos))] => {
124+
let result: Result<ArrayRef> = match args.data_type() {
125+
DataType::Int8 => bit_get_scalar_position!(args, Int8Array, pos, i8::BITS),
126+
DataType::Int16 => bit_get_scalar_position!(args, Int16Array, pos, i16::BITS),
127+
DataType::Int32 => bit_get_scalar_position!(args, Int32Array, pos, i32::BITS),
128+
DataType::Int64 => bit_get_scalar_position!(args, Int64Array, pos, i64::BITS),
129+
_ => exec_err!(
130+
"Can't be evaluated because the expression's type is {:?}, not signed int",
131+
args.data_type()
132+
),
133+
};
134+
result.map(ColumnarValue::Array)
135+
},
136+
[ColumnarValue::Array(args), ColumnarValue::Array(positions)] => {
137+
if args.len() != positions.len() {
138+
return exec_err!(
139+
"Input arrays must have equal length. Positions array has {} elements, but arguments array has {} elements",
140+
positions.len(), args.len()
141+
);
142+
}
143+
if !matches!(positions.data_type(), DataType::Int32) {
144+
return exec_err!(
145+
"Invalid data type for positions array: expected `Int32`, found `{}`",
146+
positions.data_type()
147+
);
148+
}
149+
let result: Result<ArrayRef> = match args.data_type() {
150+
DataType::Int8 => bit_get_array_positions!(args, Int8Array, positions, i8::BITS),
151+
DataType::Int16 => bit_get_array_positions!(args, Int16Array, positions, i16::BITS),
152+
DataType::Int32 => bit_get_array_positions!(args, Int32Array, positions, i32::BITS),
153+
DataType::Int64 => bit_get_array_positions!(args, Int64Array, positions, i64::BITS),
154+
_ => exec_err!(
155+
"Can't be evaluated because the expression's type is {:?}, not signed int",
156+
args.data_type()
157+
),
158+
};
159+
result.map(ColumnarValue::Array)
160+
}
161+
_ => exec_err!(
162+
"Invalid input to function bit_get. Expected (IntegralType array, Int32Scalar) or (IntegralType array, Int32Array)"
163+
),
164+
}
165+
}
166+
167+
/// Returns the bit of `arg` at index `pos` (0 = least significant) as `0` or `1`.
///
/// The shift is arithmetic, so for negative `arg` the sign bit is replicated.
/// Callers in this file validate `pos` with `check_position` before calling.
fn bit_get(arg: i64, pos: i32) -> i8 {
    let shifted = arg >> pos;
    let lowest_bit = shifted & 1;
    lowest_bit as i8
}
170+
171+
fn check_position(pos: i32, bit_size: i32) -> Result<()> {
172+
if pos < 0 {
173+
return exec_err!("Invalid bit position: {:?} is less than zero", pos);
174+
}
175+
if bit_size <= pos {
176+
return exec_err!(
177+
"Invalid bit position: {:?} exceeds the bit upper limit: {:?}",
178+
pos,
179+
bit_size
180+
);
181+
}
182+
Ok(())
183+
}
184+
185+
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::common::cast::as_int8_array;

    /// Value column shared by every test: `[1, NULL, 1234553454]` as `Int32`.
    fn sample_values() -> ColumnarValue {
        let values = Int32Array::from(vec![Some(1), None, Some(1234553454)]);
        ColumnarValue::Array(Arc::new(values))
    }

    /// Wraps a single position as an `Int32` scalar argument.
    fn scalar_position(pos: i32) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int32(Some(pos)))
    }

    /// Wraps per-row positions as an `Int32` array argument.
    fn array_positions(positions: Vec<Option<i32>>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(Int32Array::from(positions)))
    }

    /// Runs `spark_bit_get` and checks that it yields the expected `Int8` array.
    fn assert_bits(args: [ColumnarValue; 2], expected: Vec<Option<i8>>) -> Result<()> {
        let expected = &Int8Array::from(expected);

        let ColumnarValue::Array(result) = spark_bit_get(&args)? else {
            unreachable!()
        };

        let result = as_int8_array(&result).expect("failed to downcast to Int8Array");

        assert_eq!(result, expected);

        Ok(())
    }

    /// Runs `spark_bit_get` and checks that it fails with the expected message.
    fn assert_error_message(args: [ColumnarValue; 2], expected: &str) {
        let expected = String::from(expected);
        let result = spark_bit_get(&args).err().unwrap().to_string();

        assert_eq!(result, expected);
    }

    #[test]
    fn bitwise_get_scalar_position() -> Result<()> {
        assert_bits(
            [sample_values(), scalar_position(1)],
            vec![Some(0), None, Some(1)],
        )
    }

    #[test]
    fn bitwise_get_scalar_negative_position() -> Result<()> {
        assert_error_message(
            [sample_values(), scalar_position(-1)],
            "Execution error: Invalid bit position: -1 is less than zero",
        );

        Ok(())
    }

    #[test]
    fn bitwise_get_scalar_overflow_position() -> Result<()> {
        assert_error_message(
            [sample_values(), scalar_position(33)],
            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
        );

        Ok(())
    }

    #[test]
    fn bitwise_get_array_positions() -> Result<()> {
        assert_bits(
            [
                sample_values(),
                array_positions(vec![Some(1), None, Some(1)]),
            ],
            vec![Some(0), None, Some(1)],
        )
    }

    #[test]
    fn bitwise_get_array_positions_contains_negative() -> Result<()> {
        assert_error_message(
            [
                sample_values(),
                array_positions(vec![Some(-1), None, Some(1)]),
            ],
            "Execution error: Invalid bit position: -1 is less than zero",
        );

        Ok(())
    }

    #[test]
    fn bitwise_get_array_positions_contains_overflow() -> Result<()> {
        assert_error_message(
            [
                sample_values(),
                array_positions(vec![Some(33), None, Some(1)]),
            ],
            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
        );

        Ok(())
    }
}

native/spark-expr/src/bitwise_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
mod bitwise_count;
19+
mod bitwise_get;
1920
mod bitwise_not;
2021

2122
pub use bitwise_count::SparkBitwiseCount;
23+
pub use bitwise_get::SparkBitwiseGet;
2224
pub use bitwise_not::SparkBitwiseNot;

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::{
2020
spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
2121
spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
2222
spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
23-
SparkBitwiseCount, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
23+
SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
2424
};
2525
use arrow::datatypes::DataType;
2626
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -157,6 +157,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
157157
Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default())),
158158
Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())),
159159
Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
160+
Arc::new(ScalarUDF::new_from_impl(SparkBitwiseGet::default())),
160161
Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
161162
]
162163
}

0 commit comments

Comments
 (0)