diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 1231ad7809..98c4780aab 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1960,6 +1960,16 @@ def rstrip(self) -> Expression: return rstrip(self) + def trim(self) -> Expression: + """Strip whitespace from both sides of a UTF-8 string. + + Tip: See Also + [`daft.functions.trim`](https://docs.daft.ai/en/stable/api/functions/trim/) + """ + from daft.functions import trim + + return trim(self) + def reverse(self) -> Expression: """Reverse a UTF-8 string. diff --git a/daft/functions/__init__.py b/daft/functions/__init__.py index e0ae46e12d..866a6ff3a5 100644 --- a/daft/functions/__init__.py +++ b/daft/functions/__init__.py @@ -187,6 +187,8 @@ upper, lstrip, rstrip, + strip, + trim, reverse, capitalize, to_camel_case, @@ -424,6 +426,7 @@ "stddev", "strftime", "string_agg", + "strip", "substr", "sum", "tan", @@ -450,6 +453,7 @@ "total_minutes", "total_nanoseconds", "total_seconds", + "trim", "try_compress", "try_decode", "try_decompress", diff --git a/daft/functions/str.py b/daft/functions/str.py index 36655cac52..9957c07c00 100644 --- a/daft/functions/str.py +++ b/daft/functions/str.py @@ -336,6 +336,77 @@ def rstrip(expr: Expression) -> Expression: return Expression._call_builtin_scalar_fn("rstrip", expr) +def trim(expr: Expression) -> Expression: + """Strip whitespace from both sides of string. + + Returns: + Expression: a String expression which is `self` with leading and trailing whitespace stripped + + Examples: + >>> import daft + >>> from daft.functions import trim + >>> df = daft.from_pydict({"x": ["foo", "bar", " baz "]}) + >>> df = df.select(trim(df["x"])) + >>> df.show() + ╭────────╮ + │ x │ + │ --- │ + │ String │ + ╞════════╡ + │ foo │ + ├╌╌╌╌╌╌╌╌┤ + │ bar │ + ├╌╌╌╌╌╌╌╌┤ + │ baz │ + ╰────────╯ + + (Showing first 3 of 3 rows) + + """ + return Expression._call_builtin_scalar_fn("trim", expr) + + +def strip(expr: Expression, mode: Literal["left", "right", "both"] = "both") -> Expression: + """Strip whitespace from string. + + Args: + expr: The expression to strip whitespace from. + mode: The mode to use for stripping whitespace. Can be "left", "right", or "both". Defaults to "both". + + Returns: + Expression: a String expression which is `self` with whitespace stripped according to the mode + + Examples: + >>> import daft + >>> from daft.functions import strip + >>> df = daft.from_pydict({"x": ["foo", "bar", " baz "]}) + >>> df = df.select(strip(df["x"], mode="both")) + >>> df.show() + ╭────────╮ + │ x │ + │ --- │ + │ String │ + ╞════════╡ + │ foo │ + ├╌╌╌╌╌╌╌╌┤ + │ bar │ + ├╌╌╌╌╌╌╌╌┤ + │ baz │ + ╰────────╯ + + (Showing first 3 of 3 rows) + + """ + if mode == "left": + return lstrip(expr) + elif mode == "right": + return rstrip(expr) + elif mode == "both": + return trim(expr) + else: + raise ValueError(f"Invalid mode: {mode}. Must be one of 'left', 'right', or 'both'.") + + def reverse(expr: Expression) -> Expression: """Reverse a UTF-8 string. diff --git a/daft/series.py b/daft/series.py index 8256b2a972..31b67e5ed4 100644 --- a/daft/series.py +++ b/daft/series.py @@ -927,6 +927,9 @@ def lstrip(self) -> Series: def rstrip(self) -> Series: return self._eval_expressions("rstrip") + def trim(self) -> Series: + return self._eval_expressions("trim") + def reverse(self) -> Series: return self._eval_expressions("reverse") diff --git a/src/daft-functions-utf8/src/lib.rs b/src/daft-functions-utf8/src/lib.rs index 59273f110c..6d67f98155 100644 --- a/src/daft-functions-utf8/src/lib.rs +++ b/src/daft-functions-utf8/src/lib.rs @@ -28,6 +28,7 @@ mod startswith; mod substr; mod to_date; mod to_datetime; +mod trim; mod upper; pub(crate) mod utils; @@ -60,6 +61,7 @@ pub use startswith::*; pub use substr::*; pub use to_date::*; pub use to_datetime::*; +pub use trim::*; pub use upper::*; pub struct Utf8Functions; @@ -93,6 +95,7 @@ impl daft_dsl::functions::FunctionModule for Utf8Functions { parent.add_fn(Right); parent.add_fn(RPad); parent.add_fn(RStrip); + parent.add_fn(Trim); parent.add_fn(SnakeCase); parent.add_fn(Split); parent.add_fn(StartsWith); diff --git a/src/daft-functions-utf8/src/trim.rs b/src/daft-functions-utf8/src/trim.rs new file mode 100644 index 0000000000..81243215ea --- /dev/null +++ b/src/daft-functions-utf8/src/trim.rs @@ -0,0 +1,52 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::{IntoSeries, Series}, +}; +use daft_dsl::{ + ExprRef, + functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn}, +}; +use serde::{Deserialize, Serialize}; + +use crate::utils::{Utf8ArrayUtils, unary_utf8_evaluate, unary_utf8_to_field}; + +#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Trim; + +#[typetag::serde] +impl ScalarUDF for Trim { + fn name(&self) -> &'static str { + "trim" + } + + fn call( + &self, + inputs: daft_dsl::functions::FunctionArgs, + _ctx: &daft_dsl::functions::scalar::EvalContext, + ) -> DaftResult { + unary_utf8_evaluate(inputs, |s| { + s.with_utf8_array(|arr| { + arr.unary_broadcasted_op(|val| val.trim().into()) + .map(IntoSeries::into_series) + }) + }) + } + + fn get_return_field( + &self, + inputs: FunctionArgs, + schema: &Schema, + ) -> DaftResult { + unary_utf8_to_field(inputs, schema, self.name(), DataType::Utf8) + } + + fn docstring(&self) -> &'static str { + "Removes leading and trailing whitespace from the string" + } +} + +#[must_use] +pub fn trim(input: ExprRef) -> ExprRef { + ScalarFn::builtin(Trim {}, vec![input]).into() +} diff --git a/tests/recordbatch/utf8/test_trim.py b/tests/recordbatch/utf8/test_trim.py new file mode 100644 index 0000000000..9d8020842a --- /dev/null +++ b/tests/recordbatch/utf8/test_trim.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from daft.expressions import col +from daft.recordbatch import MicroPartition + + +def test_utf8_trim(): + table = MicroPartition.from_pydict({"col": ["\ta\t", None, "\nb\n", "\vc\t", "\td ", "\ne", "f\n", "g"]}) + result = table.eval_expression_list([col("col").trim()]) + assert result.to_pydict() == {"col": ["a", None, "b", "c", "d", "e", "f", "g"]}