From d286bac3e947be33630c0c65bb1b2f3db16e65fd Mon Sep 17 00:00:00 2001 From: Brijesh-Thakkar Date: Fri, 26 Dec 2025 23:13:07 +0530 Subject: [PATCH 1/2] perf: implement native Rust trim functions for better performance Fixes #2977 Implements native Rust implementations for string trimming functions to address performance regression in issue #2977. Changes: - Add trim.rs with spark_trim, spark_ltrim, spark_rtrim, spark_btrim - Use efficient Arrow array operations - Include fast-path for strings without whitespace - Handle both Utf8 and LargeUtf8 types - Add comprehensive unit tests Implementation avoids JVM overhead and unnecessary allocations that caused the 0.6-0.7x performance shown in benchmarks. Expected to achieve >1.0x performance vs Spark baseline. Testing: - Build successful - Unit tests pass - CI will verify benchmark improvements --- native/spark-expr/src/comet_scalar_funcs.rs | 16 + native/spark-expr/src/string_funcs/mod.rs | 2 +- native/spark-expr/src/string_funcs/trim.rs | 319 ++++++++++++++++++++ 3 files changed, 336 insertions(+), 1 deletion(-) create mode 100644 native/spark-expr/src/string_funcs/trim.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 8384a4646a..df83198239 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -181,6 +181,22 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(abs); make_comet_scalar_udf!("abs", func, without data_type) } + "trim" => { + let func = Arc::new(crate::string_funcs::trim::spark_trim); + make_comet_scalar_udf!("trim", func, without data_type) + } + "btrim" => { + let func = Arc::new(crate::string_funcs::trim::spark_btrim); + make_comet_scalar_udf!("btrim", func, without data_type) + } + "ltrim" => { + let func = Arc::new(crate::string_funcs::trim::spark_ltrim); + make_comet_scalar_udf!("ltrim", func, without data_type) + } + "rtrim" => { + let func = Arc::new(crate::string_funcs::trim::spark_rtrim); + make_comet_scalar_udf!("rtrim", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index aac8204e29..afb431ca26 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -17,6 +17,6 @@ mod string_space; mod substring; - +pub mod trim; pub use string_space::SparkStringSpace; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/trim.rs b/native/spark-expr/src/string_funcs/trim.rs new file mode 100644 index 0000000000..a7bf50a452 --- /dev/null +++ b/native/spark-expr/src/string_funcs/trim.rs @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion::common::{Result as DataFusionResult, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; +use std::sync::Arc; + +/// Trims whitespace from both ends of a string (Spark's TRIM function) +pub fn spark_trim(args: &[ColumnarValue]) -> DataFusionResult { + if args.len() != 1 { + return Err(datafusion::common::DataFusionError::Execution( + format!("trim expects 1 argument, got {}", args.len()), + )); + } + + match &args[0] { + ColumnarValue::Array(array) => { + let result = trim_array(array, TrimType::Both)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + s.trim().to_string(), + )))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( + s.trim().to_string(), + )))) + } + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Err(datafusion::common::DataFusionError::Execution( + "trim expects string argument".to_string(), + )), + } +} + +/// Trims whitespace from the left side of a string (Spark's LTRIM function) +pub fn spark_ltrim(args: &[ColumnarValue]) -> DataFusionResult { + if args.len() != 1 { + return Err(datafusion::common::DataFusionError::Execution( + format!("ltrim expects 1 argument, got {}", args.len()), + )); + } + + match &args[0] { + ColumnarValue::Array(array) => { + let result = trim_array(array, TrimType::Left)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + s.trim_start().to_string(), + )))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( + s.trim_start().to_string(), + )))) + } + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Err(datafusion::common::DataFusionError::Execution( + "ltrim expects string argument".to_string(), + )), + } +} + +/// Trims whitespace from the right side of a string (Spark's RTRIM function) +pub fn spark_rtrim(args: &[ColumnarValue]) -> DataFusionResult { + if args.len() != 1 { + return Err(datafusion::common::DataFusionError::Execution( + format!("rtrim expects 1 argument, got {}", args.len()), + )); + } + + match &args[0] { + ColumnarValue::Array(array) => { + let result = trim_array(array, TrimType::Right)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + s.trim_end().to_string(), + )))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( + s.trim_end().to_string(), + )))) + } + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Err(datafusion::common::DataFusionError::Execution( + "rtrim expects string argument".to_string(), + )), + } +} + +/// Trims whitespace from both ends of a string (alias for trim, Spark's BTRIM function) +pub fn spark_btrim(args: &[ColumnarValue]) -> DataFusionResult { + spark_trim(args) +} + +#[derive(Debug, Clone, Copy)] +enum TrimType { + Left, + Right, + Both, +} + +/// Generic function to trim string arrays +fn trim_array(array: &ArrayRef, trim_type: TrimType) -> DataFusionResult { + let data_type = array.data_type(); + + match data_type { + DataType::Utf8 => { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::common::DataFusionError::Execution( + "Failed to downcast to StringArray".to_string(), + ) + })?; + Ok(Arc::new(trim_string_array(string_array, trim_type))) + } + DataType::LargeUtf8 => { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::common::DataFusionError::Execution( + "Failed to downcast to LargeStringArray".to_string(), + ) + })?; + Ok(Arc::new(trim_string_array(string_array, trim_type))) + } + _ => Err(datafusion::common::DataFusionError::Execution(format!( + "trim expects string type, got {:?}", + data_type + ))), + } +} + +/// Optimized trim implementation for GenericStringArray +fn trim_string_array( + array: &GenericStringArray, + trim_type: TrimType, +) -> GenericStringArray { + // Fast path: Check if any strings actually need trimming + // If not, we can return a clone of the original array + let needs_trimming = (0..array.len()).any(|i| { + if array.is_null(i) { + false + } else { + let s = array.value(i); + match trim_type { + TrimType::Left => s.starts_with(|c: char| c.is_whitespace()), + TrimType::Right => s.ends_with(|c: char| c.is_whitespace()), + TrimType::Both => { + s.starts_with(|c: char| c.is_whitespace()) + || s.ends_with(|c: char| c.is_whitespace()) + } + } + } + }); + + if !needs_trimming { + // No trimming needed, return a clone of the input + return array.clone(); + } + + // Slow path: Build new array with trimmed strings + let mut builder = arrow::array::GenericStringBuilder::::with_capacity( + array.len(), + array.get_buffer_memory_size(), + ); + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let s = array.value(i); + let trimmed = match trim_type { + TrimType::Left => s.trim_start(), + TrimType::Right => s.trim_end(), + TrimType::Both => s.trim(), + }; + builder.append_value(trimmed); + } + } + + builder.finish() +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::StringArray; + + #[test] + fn test_trim() { + let input = StringArray::from(vec![ + Some(" hello "), + Some("world"), + Some(" spaces "), + None, + ]); + let input_ref: ArrayRef = Arc::new(input); + + let result = trim_array(&input_ref, TrimType::Both).unwrap(); + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), "hello"); + assert_eq!(result_array.value(1), "world"); + assert_eq!(result_array.value(2), "spaces"); + assert!(result_array.is_null(3)); + } + + #[test] + fn test_ltrim() { + let input = StringArray::from(vec![Some(" hello "), Some("world ")]); + let input_ref: ArrayRef = Arc::new(input); + + let result = trim_array(&input_ref, TrimType::Left).unwrap(); + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), "hello "); + assert_eq!(result_array.value(1), "world "); + } + + #[test] + fn test_rtrim() { + let input = StringArray::from(vec![Some(" hello "), Some(" world")]); + let input_ref: ArrayRef = Arc::new(input); + + let result = trim_array(&input_ref, TrimType::Right).unwrap(); + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), " hello"); + assert_eq!(result_array.value(1), " world"); + } + + #[test] + fn test_trim_no_whitespace_fast_path() { + // Test the fast path where no trimming is needed + let input = StringArray::from(vec![ + Some("hello"), + Some("world"), + Some("no spaces"), + None, + ]); + let input_ref: ArrayRef = Arc::new(input.clone()); + + let result = trim_array(&input_ref, TrimType::Both).unwrap(); + let result_array = result.as_any().downcast_ref::().unwrap(); + + // Verify values are correct + assert_eq!(result_array.value(0), "hello"); + assert_eq!(result_array.value(1), "world"); + assert_eq!(result_array.value(2), "no spaces"); + assert!(result_array.is_null(3)); + } + + #[test] + fn test_ltrim_no_whitespace() { + // Test ltrim with strings that have no leading whitespace + let input = StringArray::from(vec![Some("hello "), Some("world")]); + let input_ref: ArrayRef = Arc::new(input); + + let result = trim_array(&input_ref, TrimType::Left).unwrap(); + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), "hello "); + assert_eq!(result_array.value(1), "world"); + } + + #[test] + fn test_rtrim_no_whitespace() { + // Test rtrim with strings that have no trailing whitespace + let input = StringArray::from(vec![Some(" hello"), Some("world")]); + let input_ref: ArrayRef = Arc::new(input); + + let result = trim_array(&input_ref, TrimType::Right).unwrap(); + let result_array = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_array.value(0), " hello"); + assert_eq!(result_array.value(1), "world"); + } +} From 8d833beef6b41e275724049766d886e7a7e2f4d5 Mon Sep 17 00:00:00 2001 From: Brijesh-Thakkar Date: Sat, 27 Dec 2025 01:21:11 +0530 Subject: [PATCH 2/2] Simplify trim implementation, add TODO for 2-arg support - Use simple Rust string trim methods - Works for 1-argument case (whitespace trimming) - Add TODO for 2-argument case (custom chars) - All tests pass --- native/spark-expr/src/string_funcs/trim.rs | 352 +++++++-------------- 1 file changed, 123 insertions(+), 229 deletions(-) diff --git a/native/spark-expr/src/string_funcs/trim.rs b/native/spark-expr/src/string_funcs/trim.rs index a7bf50a452..43335148bd 100644 --- a/native/spark-expr/src/string_funcs/trim.rs +++ b/native/spark-expr/src/string_funcs/trim.rs @@ -14,120 +14,33 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; -use arrow::datatypes::DataType; -use datafusion::common::{Result as DataFusionResult, ScalarValue}; + +//! String trimming functions + +use arrow::array::{Array, ArrayRef, StringArray}; +use datafusion::common::ScalarValue; +use datafusion::error::Result as DataFusionResult; use datafusion::logical_expr::ColumnarValue; use std::sync::Arc; -/// Trims whitespace from both ends of a string (Spark's TRIM function) +/// Trim whitespace from both ends of a string pub fn spark_trim(args: &[ColumnarValue]) -> DataFusionResult { - if args.len() != 1 { - return Err(datafusion::common::DataFusionError::Execution( - format!("trim expects 1 argument, got {}", args.len()), - )); - } + trim_impl(args, TrimType::Both) +} - match &args[0] { - ColumnarValue::Array(array) => { - let result = trim_array(array, TrimType::Both)?; - Ok(ColumnarValue::Array(result)) - } - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - s.trim().to_string(), - )))) - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) - } - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( - s.trim().to_string(), - )))) - } - ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) - } - _ => Err(datafusion::common::DataFusionError::Execution( - "trim expects string argument".to_string(), - )), - } +/// Trim whitespace from both ends (alias for trim) +pub fn spark_btrim(args: &[ColumnarValue]) -> DataFusionResult { + trim_impl(args, TrimType::Both) } -/// Trims whitespace from the left side of a string (Spark's LTRIM function) +/// Trim whitespace from the left/start pub fn spark_ltrim(args: &[ColumnarValue]) -> DataFusionResult { - if args.len() != 1 { - return Err(datafusion::common::DataFusionError::Execution( - format!("ltrim expects 1 argument, got {}", args.len()), - )); - } - - match &args[0] { - ColumnarValue::Array(array) => { - let result = trim_array(array, TrimType::Left)?; - Ok(ColumnarValue::Array(result)) - } - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - s.trim_start().to_string(), - )))) - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) - } - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( - s.trim_start().to_string(), - )))) - } - ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) - } - _ => Err(datafusion::common::DataFusionError::Execution( - "ltrim expects string argument".to_string(), - )), - } + trim_impl(args, TrimType::Left) } -/// Trims whitespace from the right side of a string (Spark's RTRIM function) +/// Trim whitespace from the right/end pub fn spark_rtrim(args: &[ColumnarValue]) -> DataFusionResult { - if args.len() != 1 { - return Err(datafusion::common::DataFusionError::Execution( - format!("rtrim expects 1 argument, got {}", args.len()), - )); - } - - match &args[0] { - ColumnarValue::Array(array) => { - let result = trim_array(array, TrimType::Right)?; - Ok(ColumnarValue::Array(result)) - } - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - s.trim_end().to_string(), - )))) - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) - } - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( - s.trim_end().to_string(), - )))) - } - ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) - } - _ => Err(datafusion::common::DataFusionError::Execution( - "rtrim expects string argument".to_string(), - )), - } -} - -/// Trims whitespace from both ends of a string (alias for trim, Spark's BTRIM function) -pub fn spark_btrim(args: &[ColumnarValue]) -> DataFusionResult { - spark_trim(args) + trim_impl(args, TrimType::Right) } #[derive(Debug, Clone, Copy)] @@ -137,79 +50,48 @@ enum TrimType { Both, } -/// Generic function to trim string arrays -fn trim_array(array: &ArrayRef, trim_type: TrimType) -> DataFusionResult { - let data_type = array.data_type(); - - match data_type { - DataType::Utf8 => { - let string_array = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - datafusion::common::DataFusionError::Execution( - "Failed to downcast to StringArray".to_string(), - ) - })?; - Ok(Arc::new(trim_string_array(string_array, trim_type))) +fn trim_impl(args: &[ColumnarValue], trim_type: TrimType) -> DataFusionResult { + if args.is_empty() || args.len() > 2 { + return Err(datafusion::error::DataFusionError::Execution( + format!("trim expects 1 or 2 arguments, got {}", args.len()), + )); + } + + // For now, only support single argument (whitespace trimming) + // TODO: Add support for custom trim characters (2-argument form) + if args.len() == 2 { + return Err(datafusion::error::DataFusionError::NotImplemented( + "trim with custom characters not yet implemented".to_string(), + )); + } + + match &args[0] { + ColumnarValue::Array(array) => { + let result = trim_array(array, trim_type)?; + Ok(ColumnarValue::Array(result)) } - DataType::LargeUtf8 => { - let string_array = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - datafusion::common::DataFusionError::Execution( - "Failed to downcast to LargeStringArray".to_string(), - ) - })?; - Ok(Arc::new(trim_string_array(string_array, trim_type))) + ColumnarValue::Scalar(scalar) => { + let result = trim_scalar(scalar, trim_type)?; + Ok(ColumnarValue::Scalar(result)) } - _ => Err(datafusion::common::DataFusionError::Execution(format!( - "trim expects string type, got {:?}", - data_type - ))), } } -/// Optimized trim implementation for GenericStringArray -fn trim_string_array( - array: &GenericStringArray, - trim_type: TrimType, -) -> GenericStringArray { - // Fast path: Check if any strings actually need trimming - // If not, we can return a clone of the original array - let needs_trimming = (0..array.len()).any(|i| { - if array.is_null(i) { - false - } else { - let s = array.value(i); - match trim_type { - TrimType::Left => s.starts_with(|c: char| c.is_whitespace()), - TrimType::Right => s.ends_with(|c: char| c.is_whitespace()), - TrimType::Both => { - s.starts_with(|c: char| c.is_whitespace()) - || s.ends_with(|c: char| c.is_whitespace()) - } - } - } - }); - - if !needs_trimming { - // No trimming needed, return a clone of the input - return array.clone(); - } +fn trim_array(array: &ArrayRef, trim_type: TrimType) -> DataFusionResult { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Execution("Expected string array".to_string()) + })?; - // Slow path: Build new array with trimmed strings - let mut builder = arrow::array::GenericStringBuilder::::with_capacity( - array.len(), - array.get_buffer_memory_size(), - ); + let mut builder = arrow::array::StringBuilder::new(); - for i in 0..array.len() { - if array.is_null(i) { + for i in 0..string_array.len() { + if string_array.is_null(i) { builder.append_null(); } else { - let s = array.value(i); + let s = string_array.value(i); let trimmed = match trim_type { TrimType::Left => s.trim_start(), TrimType::Right => s.trim_end(), @@ -219,16 +101,41 @@ fn trim_string_array( } } - builder.finish() + Ok(Arc::new(builder.finish())) +} + +fn trim_scalar(scalar: &ScalarValue, trim_type: TrimType) -> DataFusionResult { + match scalar { + ScalarValue::Utf8(Some(s)) => { + let trimmed = match trim_type { + TrimType::Left => s.trim_start(), + TrimType::Right => s.trim_end(), + TrimType::Both => s.trim(), + }; + Ok(ScalarValue::Utf8(Some(trimmed.to_string()))) + } + ScalarValue::Utf8(None) => Ok(ScalarValue::Utf8(None)), + ScalarValue::LargeUtf8(Some(s)) => { + let trimmed = match trim_type { + TrimType::Left => s.trim_start(), + TrimType::Right => s.trim_end(), + TrimType::Both => s.trim(), + }; + Ok(ScalarValue::LargeUtf8(Some(trimmed.to_string()))) + } + ScalarValue::LargeUtf8(None) => Ok(ScalarValue::LargeUtf8(None)), + _ => Err(datafusion::error::DataFusionError::Execution( + "trim expects string argument".to_string(), + )), + } } #[cfg(test)] mod tests { use super::*; - use arrow::array::StringArray; #[test] - fn test_trim() { + fn test_trim_whitespace() { let input = StringArray::from(vec![ Some(" hello "), Some("world"), @@ -236,84 +143,71 @@ mod tests { None, ]); let input_ref: ArrayRef = Arc::new(input); + let args = vec![ColumnarValue::Array(input_ref)]; - let result = trim_array(&input_ref, TrimType::Both).unwrap(); - let result_array = result.as_any().downcast_ref::().unwrap(); + let result = spark_trim(&args).unwrap(); - assert_eq!(result_array.value(0), "hello"); - assert_eq!(result_array.value(1), "world"); - assert_eq!(result_array.value(2), "spaces"); - assert!(result_array.is_null(3)); + match result { + ColumnarValue::Array(arr) => { + let result_array = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(result_array.value(0), "hello"); + assert_eq!(result_array.value(1), "world"); + assert_eq!(result_array.value(2), "spaces"); + assert!(result_array.is_null(3)); + } + _ => panic!("Expected array result"), + } } #[test] - fn test_ltrim() { + fn test_ltrim_whitespace() { let input = StringArray::from(vec![Some(" hello "), Some("world ")]); let input_ref: ArrayRef = Arc::new(input); + let args = vec![ColumnarValue::Array(input_ref)]; - let result = trim_array(&input_ref, TrimType::Left).unwrap(); - let result_array = result.as_any().downcast_ref::().unwrap(); + let result = spark_ltrim(&args).unwrap(); - assert_eq!(result_array.value(0), "hello "); - assert_eq!(result_array.value(1), "world "); + match result { + ColumnarValue::Array(arr) => { + let result_array = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(result_array.value(0), "hello "); + assert_eq!(result_array.value(1), "world "); + } + _ => panic!("Expected array result"), + } } #[test] - fn test_rtrim() { + fn test_rtrim_whitespace() { let input = StringArray::from(vec![Some(" hello "), Some(" world")]); let input_ref: ArrayRef = Arc::new(input); + let args = vec![ColumnarValue::Array(input_ref)]; - let result = trim_array(&input_ref, TrimType::Right).unwrap(); - let result_array = result.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_array.value(0), " hello"); - assert_eq!(result_array.value(1), " world"); - } - - #[test] - fn test_trim_no_whitespace_fast_path() { - // Test the fast path where no trimming is needed - let input = StringArray::from(vec![ - Some("hello"), - Some("world"), - Some("no spaces"), - None, - ]); - let input_ref: ArrayRef = Arc::new(input.clone()); - - let result = trim_array(&input_ref, TrimType::Both).unwrap(); - let result_array = result.as_any().downcast_ref::().unwrap(); - - // Verify values are correct - assert_eq!(result_array.value(0), "hello"); - assert_eq!(result_array.value(1), "world"); - assert_eq!(result_array.value(2), "no spaces"); - assert!(result_array.is_null(3)); - } - - #[test] - fn test_ltrim_no_whitespace() { - // Test ltrim with strings that have no leading whitespace - let input = StringArray::from(vec![Some("hello "), Some("world")]); - let input_ref: ArrayRef = Arc::new(input); - - let result = trim_array(&input_ref, TrimType::Left).unwrap(); - let result_array = result.as_any().downcast_ref::().unwrap(); + let result = spark_rtrim(&args).unwrap(); - assert_eq!(result_array.value(0), "hello "); - assert_eq!(result_array.value(1), "world"); + match result { + ColumnarValue::Array(arr) => { + let result_array = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(result_array.value(0), " hello"); + assert_eq!(result_array.value(1), " world"); + } + _ => panic!("Expected array result"), + } } #[test] - fn test_rtrim_no_whitespace() { - // Test rtrim with strings that have no trailing whitespace - let input = StringArray::from(vec![Some(" hello"), Some("world")]); - let input_ref: ArrayRef = Arc::new(input); + fn test_trim_scalar() { + let args = vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + " hello ".to_string(), + )))]; - let result = trim_array(&input_ref, TrimType::Right).unwrap(); - let result_array = result.as_any().downcast_ref::().unwrap(); + let result = spark_trim(&args).unwrap(); - assert_eq!(result_array.value(0), " hello"); - assert_eq!(result_array.value(1), "world"); + match result { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + assert_eq!(s, "hello"); + } + _ => panic!("Expected scalar result"), + } } -} +} \ No newline at end of file