Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions daft/udf/udf_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Expression | T:
bound_method = self._cls._daft_bind_method(self._method)
return bound_method(*args, **kwargs)

# When building expression-based UDFs, we must avoid incorrectly sharing call-site state across multiple uses of the same function.
call_seq = getattr(self, "_daft_call_seq", 0)
setattr(self, "_daft_call_seq", call_seq + 1)
call_id = f"{self.func_id}-{call_seq}"

check_serializable(
self._method,
"Daft functions must be serializable. If your function accesses a non-serializable global or nonlocal variable to avoid reinitialization, use `@daft.cls` with a setup method instead.",
Expand All @@ -247,7 +252,7 @@ def method(s: C, *args: P.args, **kwargs: P.kwargs) -> list[Any]:

expr = Expression._from_pyexpr(
row_wise_udf(
self.func_id,
call_id,
self.name,
self._cls,
method,
Expand All @@ -266,7 +271,7 @@ def method(s: C, *args: P.args, **kwargs: P.kwargs) -> list[Any]:
elif self.is_batch:
expr = Expression._from_pyexpr(
batch_udf(
self.func_id,
call_id,
self.name,
self._cls,
self._method,
Expand All @@ -286,7 +291,7 @@ def method(s: C, *args: P.args, **kwargs: P.kwargs) -> list[Any]:
else:
expr = Expression._from_pyexpr(
row_wise_udf(
self.func_id,
call_id,
self.name,
self._cls,
self._method,
Expand Down
23 changes: 23 additions & 0 deletions tests/udf/test_row_wise_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,26 @@ def stringify_and_sum(a: int, b: int) -> str:
dynamic_batching_df = df.select("*", stringify_and_sum(col("x"), col("y")).alias("sum"))
dynamic_batching_df = dynamic_batching_df.collect().sort("id")
assert non_dynamic_batching_df.to_pydict() == dynamic_batching_df.to_pydict()


def test_row_wise_udf_kwargs_prefix_suffix_literals_and_exprs():
@daft.func
def format_number(value: int, prefix: str = "$", suffix: str = "") -> str:
return f"{prefix}{value}{suffix}"

df = daft.from_pydict({"amount": [10, 20, 30]})
df = df.with_column("dollar", format_number(df["amount"]))
df = df.with_column("euro", format_number(df["amount"], prefix="€", suffix=" EUR"))
df = df.with_column(
"customized",
format_number(df["amount"], suffix=df["amount"].cast(DataType.string())),
)

result = df.to_pydict()
expected = {
"amount": [10, 20, 30],
"dollar": ["$10", "$20", "$30"],
"euro": ["€10 EUR", "€20 EUR", "€30 EUR"],
"customized": ["$1010", "$2020", "$3030"],
}
assert result == expected
Loading