Skip to content
40 changes: 38 additions & 2 deletions datafusion/functions/src/string/octet_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use arrow::compute::kernels::length::length;
use arrow::datatypes::DataType;
use std::any::Any;

use crate::utils::utf8_to_int_type;
use arrow::array::{
Array, Int32Array, Int32Builder, Int64Builder, LargeStringArray, StringArray,
StringViewArray,
};
use datafusion_common::types::logical_string;
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue};
Expand All @@ -28,6 +31,7 @@ use datafusion_expr::{
TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;

#[user_doc(
doc_section(label = "String Functions"),
Expand Down Expand Up @@ -90,7 +94,39 @@ impl ScalarUDFImpl for OctetLengthFunc {
let [array] = take_function_args(self.name(), &args.args)?;

match array {
ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)),
ColumnarValue::Array(v) => {
if let Some(arr) = v.as_any().downcast_ref::<StringArray>() {
let mut builder = Int32Builder::with_capacity(arr.len());
for i in 0..arr.len() {
if arr.is_null(i) {
builder.append_null();
} else {
builder.append_value(arr.value_length(i));
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
} else if let Some(arr) = v.as_any().downcast_ref::<LargeStringArray>() {
let mut builder = Int64Builder::with_capacity(arr.len());
for i in 0..arr.len() {
if arr.is_null(i) {
builder.append_null();
} else {
builder.append_value(arr.value_length(i));
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
} else if let Some(arr) = v.as_any().downcast_ref::<StringViewArray>() {
let result = arr
.iter()
.map(|s| s.map(|s| s.len() as i32))
.collect::<Int32Array>();

Ok(ColumnarValue::Array(Arc::new(result)))
} else {
unreachable!("octet_length expects string arrays")
}
}

ColumnarValue::Scalar(v) => match v {
ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(
v.as_ref().map(|x| x.len() as i32),
Expand Down