diff --git a/Cargo.lock b/Cargo.lock index fd18cd3524..81c73e6f31 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2903,6 +2903,7 @@ dependencies = [ "daft-arrow", "daft-core", "daft-dsl", + "heck 0.5.0", "itertools 0.14.0", "num-traits", "regex", diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 79740752e9..d030366400 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1827,6 +1827,76 @@ def capitalize(self) -> Expression: return capitalize(self) + def to_camel_case(self) -> Expression: + """Convert a string to lower camel case. + + Tip: See Also + [`daft.functions.to_camel_case`](https://docs.daft.ai/en/stable/api/functions/to_camel_case/) + """ + from daft.functions import to_camel_case + + return to_camel_case(self) + + def to_upper_camel_case(self) -> Expression: + """Convert a string to upper camel case. + + Tip: See Also + [`daft.functions.to_upper_camel_case`](https://docs.daft.ai/en/stable/api/functions/to_upper_camel_case/) + """ + from daft.functions import to_upper_camel_case + + return to_upper_camel_case(self) + + def to_snake_case(self) -> Expression: + """Convert a string to snake case. + + Tip: See Also + [`daft.functions.to_snake_case`](https://docs.daft.ai/en/stable/api/functions/to_snake_case/) + """ + from daft.functions import to_snake_case + + return to_snake_case(self) + + def to_upper_snake_case(self) -> Expression: + """Convert a string to upper snake case. + + Tip: See Also + [`daft.functions.to_upper_snake_case`](https://docs.daft.ai/en/stable/api/functions/to_upper_snake_case/) + """ + from daft.functions import to_upper_snake_case + + return to_upper_snake_case(self) + + def to_kebab_case(self) -> Expression: + """Convert a string to kebab case. + + Tip: See Also + [`daft.functions.to_kebab_case`](https://docs.daft.ai/en/stable/api/functions/to_kebab_case/) + """ + from daft.functions import to_kebab_case + + return to_kebab_case(self) + + def to_upper_kebab_case(self) -> Expression: + """Convert a string to upper kebab case. + + Tip: See Also + [`daft.functions.to_upper_kebab_case`](https://docs.daft.ai/en/stable/api/functions/to_upper_kebab_case/) + """ + from daft.functions import to_upper_kebab_case + + return to_upper_kebab_case(self) + + def to_title_case(self) -> Expression: + """Convert a string to title case. + + Tip: See Also + [`daft.functions.to_title_case`](https://docs.daft.ai/en/stable/api/functions/to_title_case/) + """ + from daft.functions import to_title_case + + return to_title_case(self) + def left(self, nchars: int | Expression) -> Expression: """Gets the n (from nchars) left-most characters of each string. diff --git a/daft/functions/__init__.py b/daft/functions/__init__.py index c210a4932c..45a73d95d3 100644 --- a/daft/functions/__init__.py +++ b/daft/functions/__init__.py @@ -181,6 +181,13 @@ rstrip, reverse, capitalize, + to_camel_case, + to_upper_camel_case, + to_snake_case, + to_upper_snake_case, + to_kebab_case, + to_upper_kebab_case, + to_title_case, left, right, rpad, @@ -408,11 +415,18 @@ "tan", "tanh", "time", + "to_camel_case", "to_date", "to_datetime", + "to_kebab_case", "to_list", + "to_snake_case", "to_struct", + "to_title_case", "to_unix_epoch", + "to_upper_camel_case", + "to_upper_kebab_case", + "to_upper_snake_case", "tokenize_decode", "tokenize_encode", "total_days", diff --git a/daft/functions/str.py b/daft/functions/str.py index 72536f2643..36655cac52 100644 --- a/daft/functions/str.py +++ b/daft/functions/str.py @@ -396,6 +396,69 @@ def capitalize(expr: Expression) -> Expression: return Expression._call_builtin_scalar_fn("capitalize", expr) +def to_camel_case(expr: Expression) -> Expression: + """Convert a string to lower camel case. + + Returns: + Expression: a String expression converted to lower camel case + """ + return Expression._call_builtin_scalar_fn("to_camel_case", expr) + + +def to_upper_camel_case(expr: Expression) -> Expression: + """Convert a string to upper camel case. + + Returns: + Expression: a String expression converted to upper camel case + """ + return Expression._call_builtin_scalar_fn("to_upper_camel_case", expr) + + +def to_snake_case(expr: Expression) -> Expression: + """Convert a string to snake case. + + Returns: + Expression: a String expression converted to snake case + """ + return Expression._call_builtin_scalar_fn("to_snake_case", expr) + + +def to_upper_snake_case(expr: Expression) -> Expression: + """Convert a string to upper snake case. + + Returns: + Expression: a String expression converted to upper snake case + """ + return Expression._call_builtin_scalar_fn("to_upper_snake_case", expr) + + +def to_kebab_case(expr: Expression) -> Expression: + """Convert a string to kebab case. + + Returns: + Expression: a String expression converted to kebab case + """ + return Expression._call_builtin_scalar_fn("to_kebab_case", expr) + + +def to_upper_kebab_case(expr: Expression) -> Expression: + """Convert a string to upper kebab case. + + Returns: + Expression: a String expression converted to upper kebab case + """ + return Expression._call_builtin_scalar_fn("to_upper_kebab_case", expr) + + +def to_title_case(expr: Expression) -> Expression: + """Convert a string to title case. + + Returns: + Expression: a String expression converted to title case + """ + return Expression._call_builtin_scalar_fn("to_title_case", expr) + + def left(expr: Expression, nchars: int | Expression) -> Expression: """Gets the n (from nchars) left-most characters of each string. diff --git a/daft/series.py b/daft/series.py index bd36c325d8..0dd248f130 100644 --- a/daft/series.py +++ b/daft/series.py @@ -936,6 +936,27 @@ def reverse(self) -> Series: def capitalize(self) -> Series: return self._eval_expressions("capitalize") + def to_camel_case(self) -> Series: + return self._eval_expressions("to_camel_case") + + def to_upper_camel_case(self) -> Series: + return self._eval_expressions("to_upper_camel_case") + + def to_snake_case(self) -> Series: + return self._eval_expressions("to_snake_case") + + def to_upper_snake_case(self) -> Series: + return self._eval_expressions("to_upper_snake_case") + + def to_kebab_case(self) -> Series: + return self._eval_expressions("to_kebab_case") + + def to_upper_kebab_case(self) -> Series: + return self._eval_expressions("to_upper_kebab_case") + + def to_title_case(self) -> Series: + return self._eval_expressions("to_title_case") + def left(self, nchars: Series) -> Series: return self._eval_expressions("left", nchars) diff --git a/src/daft-functions-utf8/Cargo.toml b/src/daft-functions-utf8/Cargo.toml index f299e0ef9f..26c055e2e3 100644 --- a/src/daft-functions-utf8/Cargo.toml +++ b/src/daft-functions-utf8/Cargo.toml @@ -7,6 +7,7 @@ chrono-tz = {workspace = true} common-error = {path = "../common/error", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} +heck = "0.5.0" itertools = {workspace = true} num-traits = {workspace = true} regex = {workspace = true} diff --git a/src/daft-functions-utf8/src/case.rs b/src/daft-functions-utf8/src/case.rs new file mode 100644 index 0000000000..f1e73167c2 --- /dev/null +++ b/src/daft-functions-utf8/src/case.rs @@ -0,0 +1,100 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::{IntoSeries, Series}, +}; +use daft_dsl::{ + ExprRef, + functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn}, +}; +use heck::{ + ToKebabCase, ToLowerCamelCase, ToShoutyKebabCase, ToShoutySnakeCase, ToSnakeCase, ToTitleCase, + ToUpperCamelCase, +}; +use serde::{Deserialize, Serialize}; + +use crate::utils::{Utf8ArrayUtils, unary_utf8_evaluate, unary_utf8_to_field}; + +macro_rules! define_case_udf { + ($struct:ident, $fn_name:ident, $method:ident, $docstring:literal) => { + #[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] + pub struct $struct; + + #[typetag::serde] + impl ScalarUDF for $struct { + fn name(&self) -> &'static str { + stringify!($fn_name) + } + + fn call(&self, inputs: FunctionArgs) -> DaftResult { + unary_utf8_evaluate(inputs, |s| { + s.with_utf8_array(|arr| { + Ok(arr + .unary_broadcasted_op(|val| val.$method().into())? + .into_series()) + }) + }) + } + + fn get_return_field( + &self, + inputs: FunctionArgs, + schema: &Schema, + ) -> DaftResult { + unary_utf8_to_field(inputs, schema, self.name(), DataType::Utf8) + } + + fn docstring(&self) -> &'static str { + $docstring + } + } + + #[must_use] + pub fn $fn_name(input: ExprRef) -> ExprRef { + ScalarFn::builtin($struct, vec![input]).into() + } + }; +} + +define_case_udf!( + CamelCase, + to_camel_case, + to_lower_camel_case, + "Converts a string to lower camel case." +); +define_case_udf!( + UpperCamelCase, + to_upper_camel_case, + to_upper_camel_case, + "Converts a string to upper camel case." +); +define_case_udf!( + SnakeCase, + to_snake_case, + to_snake_case, + "Converts a string to snake case." +); +define_case_udf!( + UpperSnakeCase, + to_upper_snake_case, + to_shouty_snake_case, + "Converts a string to upper snake case." +); +define_case_udf!( + KebabCase, + to_kebab_case, + to_kebab_case, + "Converts a string to kebab case." +); +define_case_udf!( + UpperKebabCase, + to_upper_kebab_case, + to_shouty_kebab_case, + "Converts a string to upper kebab case." +); +define_case_udf!( + TitleCase, + to_title_case, + to_title_case, + "Converts a string to title case." +); diff --git a/src/daft-functions-utf8/src/lib.rs b/src/daft-functions-utf8/src/lib.rs index 58f83af2a4..f7c195fdd8 100644 --- a/src/daft-functions-utf8/src/lib.rs +++ b/src/daft-functions-utf8/src/lib.rs @@ -1,6 +1,7 @@ #![allow(deprecated, reason = "arrow2 migration")] mod capitalize; +mod case; mod contains; mod count_matches; mod endswith; @@ -33,6 +34,7 @@ mod upper; pub(crate) mod utils; pub use capitalize::*; +pub use case::*; pub use contains::*; pub use count_matches::*; pub use endswith::*; @@ -66,12 +68,14 @@ pub struct Utf8Functions; impl daft_dsl::functions::FunctionModule for Utf8Functions { fn register(parent: &mut daft_dsl::functions::FunctionRegistry) { + parent.add_fn(CamelCase); parent.add_fn(Capitalize); parent.add_fn(Contains); parent.add_fn(CountMatches); parent.add_fn(EndsWith); parent.add_fn(Find); parent.add_fn(ILike); + parent.add_fn(KebabCase); parent.add_fn(Left); parent.add_fn(LengthBytes); parent.add_fn(Like); @@ -83,19 +87,24 @@ impl daft_dsl::functions::FunctionModule for Utf8Functions { parent.add_fn(RegexpExtract); parent.add_fn(RegexpExtractAll); parent.add_fn(RegexpMatch); + parent.add_fn(RegexpReplace); + parent.add_fn(RegexpSplit); parent.add_fn(Repeat); parent.add_fn(Replace); - parent.add_fn(RegexpReplace); parent.add_fn(Reverse); parent.add_fn(Right); parent.add_fn(RPad); parent.add_fn(RStrip); + parent.add_fn(SnakeCase); parent.add_fn(Split); - parent.add_fn(RegexpSplit); parent.add_fn(StartsWith); parent.add_fn(Substr); + parent.add_fn(TitleCase); parent.add_fn(ToDate); parent.add_fn(ToDatetime); parent.add_fn(Upper); + parent.add_fn(UpperCamelCase); + parent.add_fn(UpperKebabCase); + parent.add_fn(UpperSnakeCase); } } diff --git a/tests/dataframe/test_string_case.py b/tests/dataframe/test_string_case.py new file mode 100644 index 0000000000..f35d7bb739 --- /dev/null +++ b/tests/dataframe/test_string_case.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import pytest + +import daft +from daft import col +from daft.expressions import Expression + + +@pytest.mark.parametrize( + ("name", "build_expr", "expected"), + [ + ("camel_case", Expression.to_camel_case, ["helloWorld", "helloWorld", "helloWorld"]), + ("upper_camel_case", Expression.to_upper_camel_case, ["HelloWorld", "HelloWorld", "HelloWorld"]), + ("snake_case", Expression.to_snake_case, ["hello_world", "hello_world", "hello_world"]), + ("upper_snake_case", Expression.to_upper_snake_case, ["HELLO_WORLD", "HELLO_WORLD", "HELLO_WORLD"]), + ("kebab_case", Expression.to_kebab_case, ["hello-world", "hello-world", "hello-world"]), + ("upper_kebab_case", Expression.to_upper_kebab_case, ["HELLO-WORLD", "HELLO-WORLD", "HELLO-WORLD"]), + ("title_case", Expression.to_title_case, ["Hello World", "Hello World", "Hello World"]), + ], +) +def test_dataframe_string_case(name, build_expr, expected) -> None: + df = daft.from_pydict({"text": ["helloWorld", "hello-world", "HelloWorld"]}) + expr = build_expr(col("text")) + result = df.select(expr.alias(name)).to_pydict()[name] + assert result == expected diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py index 97f840acce..81e694bcc7 100644 --- a/tests/series/test_utf8_ops.py +++ b/tests/series/test_utf8_ops.py @@ -10,6 +10,7 @@ import daft.exceptions from daft import DataType, Series +from daft.series import SeriesStringNamespace @pytest.mark.parametrize( @@ -412,6 +413,24 @@ def test_series_utf8_capitalize(data, expected) -> None: assert result.to_pylist() == expected +@pytest.mark.parametrize( + ("convert", "expected"), + [ + (SeriesStringNamespace.to_camel_case, ["helloWorld", "helloWorld", "helloWorld", None]), + (SeriesStringNamespace.to_upper_camel_case, ["HelloWorld", "HelloWorld", "HelloWorld", None]), + (SeriesStringNamespace.to_snake_case, ["hello_world", "hello_world", "hello_world", None]), + (SeriesStringNamespace.to_upper_snake_case, ["HELLO_WORLD", "HELLO_WORLD", "HELLO_WORLD", None]), + (SeriesStringNamespace.to_kebab_case, ["hello-world", "hello-world", "hello-world", None]), + (SeriesStringNamespace.to_upper_kebab_case, ["HELLO-WORLD", "HELLO-WORLD", "HELLO-WORLD", None]), + (SeriesStringNamespace.to_title_case, ["Hello World", "Hello World", "Hello World", None]), + ], +) +def test_series_utf8_case_conversions(convert, expected) -> None: + s = Series.from_arrow(pa.array(["helloWorld", "hello-world", "HelloWorld", None])) + result = convert(s.str) + assert result.to_pylist() == expected + + @pytest.mark.parametrize( ["data", "pattern", "expected"], [