@@ -15,94 +15,137 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#![allow(deprecated)]
-
-use crate::kernels::strings::string_space;
-use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
-use datafusion::common::DataFusionError;
-use datafusion::logical_expr::ColumnarValue;
-use datafusion::physical_expr::PhysicalExpr;
-use std::{
-    any::Any,
-    fmt::{Display, Formatter},
-    hash::Hash,
-    sync::Arc,
+use arrow::array::{
+    as_dictionary_array, make_array, Array, ArrayData, ArrayRef, DictionaryArray,
+    GenericStringArray, Int32Array, OffsetSizeTrait,
 };
+use arrow::buffer::MutableBuffer;
+use arrow::datatypes::{DataType, Int32Type};
+use datafusion::common::{exec_err, internal_datafusion_err, DataFusionError, Result};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::{any::Any, sync::Arc};
 
-#[derive(Debug, Eq)]
-pub struct StringSpaceExpr {
-    pub child: Arc<dyn PhysicalExpr>,
-}
-
-impl Hash for StringSpaceExpr {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.child.hash(state);
-    }
-}
-
-impl PartialEq for StringSpaceExpr {
-    fn eq(&self, other: &Self) -> bool {
-        self.child.eq(&other.child)
-    }
+#[derive(Debug)]
+pub struct SparkStringSpace {
+    signature: Signature,
+    aliases: Vec<String>,
 }
 
-impl StringSpaceExpr {
-    pub fn new(child: Arc<dyn PhysicalExpr>) -> Self {
-        Self { child }
+impl Default for SparkStringSpace {
+    fn default() -> Self {
+        Self::new()
     }
 }
 
-impl Display for StringSpaceExpr {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(f, "StringSpace [child: {}] ", self.child)
+impl SparkStringSpace {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![],
+        }
     }
 }
 
-impl PhysicalExpr for StringSpaceExpr {
+impl ScalarUDFImpl for SparkStringSpace {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
-        unimplemented!()
+    fn name(&self) -> &str {
+        "string_space"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
     }
 
-    fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> {
-        match self.child.data_type(input_schema)? {
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(match &arg_types[0] {
             DataType::Dictionary(key_type, _) => {
-                Ok(DataType::Dictionary(key_type, Box::new(DataType::Utf8)))
+                DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8))
             }
-            _ => Ok(DataType::Utf8),
-        }
+            _ => DataType::Utf8,
+        })
     }
 
-    fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
-        Ok(true)
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args: [ColumnarValue; 1] = args
+            .args
+            .try_into()
+            .map_err(|_| internal_datafusion_err!("string_space expects exactly one argument"))?;
+        spark_string_space(&args)
     }
 
-    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
-        let arg = self.child.evaluate(batch)?;
-        match arg {
-            ColumnarValue::Array(array) => {
-                let result = string_space(&array)?;
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
 
-                Ok(ColumnarValue::Array(result))
-            }
-            _ => Err(DataFusionError::Execution(
-                "StringSpace(scalar) should be fold in Spark JVM side.".to_string(),
-            )),
+pub fn spark_string_space(args: &[ColumnarValue; 1]) -> Result<ColumnarValue> {
+    match args {
+        [ColumnarValue::Array(array)] => {
+            let result = string_space(&array)?;
+
+            Ok(ColumnarValue::Array(result))
         }
+        _ => exec_err!("StringSpace(scalar) should be fold in Spark JVM side."),
     }
+}
 
-    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.child]
+fn string_space(length: &dyn Array) -> std::result::Result<ArrayRef, DataFusionError> {
+    match length.data_type() {
+        DataType::Int32 => {
+            let array = length.as_any().downcast_ref::<Int32Array>().unwrap();
+            Ok(generic_string_space::<i32>(array))
+        }
+        DataType::Dictionary(_, _) => {
+            let dict = as_dictionary_array::<Int32Type>(length);
+            let values = string_space(dict.values())?;
+            let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
+            Ok(Arc::new(result))
+        }
+        other => exec_err!("Unsupported input type for function 'string_space': {other:?}"),
     }
+}
 
-    fn with_new_children(
-        self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
-    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(StringSpaceExpr::new(Arc::clone(&children[0]))))
-    }
+fn generic_string_space<OffsetSize: OffsetSizeTrait>(length: &Int32Array) -> ArrayRef {
+    let array_len = length.len();
+    let mut offsets = MutableBuffer::new((array_len + 1) * std::mem::size_of::<OffsetSize>());
+    let mut length_so_far = OffsetSize::zero();
+
+    // compute null bitmap (copy)
+    let null_bit_buffer = length.to_data().nulls().map(|b| b.buffer().clone());
+
+    // Gets slice of length array to access it directly for performance.
+    let length_data = length.to_data();
+    let lengths = length_data.buffers()[0].typed_data::<i32>();
+    let total = lengths.iter().map(|l| *l as usize).sum::<usize>();
+    let mut values = MutableBuffer::new(total);
+
+    offsets.push(length_so_far);
+
+    let blank = " ".as_bytes()[0];
+    values.resize(total, blank);
+
+    (0..array_len).for_each(|i| {
+        let current_len = lengths[i] as usize;
+
+        length_so_far += OffsetSize::from_usize(current_len).unwrap();
+        offsets.push(length_so_far);
+    });
+
+    let data = unsafe {
+        ArrayData::new_unchecked(
+            GenericStringArray::<OffsetSize>::DATA_TYPE,
+            array_len,
+            None,
+            null_bit_buffer,
+            0,
+            vec![offsets.into(), values.into()],
+            vec![],
+        )
+    };
+    make_array(data)
 }
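
For reference, here is a minimal, hypothetical test sketch (not part of the patch; the test name and its placement in a `#[cfg(test)]` module next to the new code are assumptions). It exercises the public `spark_string_space` helper on an `Int32` array and checks that each non-null length `n` becomes a string of `n` spaces, while the null slot stays null because the null bitmap is copied over:

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use datafusion::logical_expr::ColumnarValue;
    use std::sync::Arc;

    #[test]
    fn string_space_int32_with_nulls() {
        // Lengths 1 and 3 plus a null slot.
        let lengths: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(3), None]));

        // Call the helper added by this patch with a single array argument.
        let result = spark_string_space(&[ColumnarValue::Array(lengths)]).unwrap();

        let ColumnarValue::Array(array) = result else {
            panic!("expected an array result");
        };
        let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(strings.value(0), " "); // one space
        assert_eq!(strings.value(1), "   "); // three spaces
        assert!(strings.is_null(2)); // null length stays null
    }
}
```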
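The diff does not show how `SparkStringSpace` is wired into query execution. Purely as an illustration of the general DataFusion `ScalarUDFImpl` pattern (the helper name below is hypothetical, `SparkStringSpace` is assumed to be in scope, and the project's actual registration path may differ), such an implementation can be wrapped in a `ScalarUDF` and registered with a `SessionContext`:

```rust
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;

// Hypothetical helper: wraps the ScalarUDFImpl in a ScalarUDF and registers it
// under the name it reports ("string_space").
fn register_string_space(ctx: &SessionContext) {
    ctx.register_udf(ScalarUDF::new_from_impl(SparkStringSpace::new()));
}
```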