Skip to content

Commit 1df253b

Browse files
authored
fix(udf): ensure per-call kwargs in udf v2 are uniquely bound per call site (#6079)
Fix row-wise/batch UDF v2 so that per-call keyword arguments (including Expression kwargs) are correctly honored and not incorrectly shared across call sites. Add a regression test that mirrors the reported `format_number` example using default, literal, and expression overrides. The v2 UDF wrapper (`daft.udf.udf_v2.Func.__call__`) used a single `func_id` derived from the decorated function to identify all UDF expressions produced by that function. This `func_id` was passed through to the Rust `row_wise_udf` / `batch_udf` builders and ultimately into the logical plan as part of `RowWisePyFn` / batch UDF metadata. Because all logical UDF nodes shared the same `func_id` regardless of their concrete arguments, they could be treated as the *same* expression by downstream components (e.g. optimizations, caching, or expression reuse keyed by this identifier). As a result, multiple calls like: ```python @daft.func def format_number(value: int, prefix: str = "$", suffix: str = "") -> str: return f"{prefix}{value}{suffix}" format_number(df["amount"]) format_number(df["amount"], prefix="€", suffix=" EUR") format_number(df["amount"], suffix=df["amount"].cast(daft.DataType.string())) ``` could end up sharing underlying UDF state keyed only by `func_id`, so that overrides for `prefix` / `suffix` were not reliably respected per call site. Introduce a per-call identifier in `Func.__call__` so that each logical UDF call site is uniquely identified, while still keeping the stable human-readable name for display: - Add a monotonically increasing `_daft_call_seq` counter on `Func` instances. - For each call that involves Expression arguments, derive a `call_id = f"{self.func_id}-{call_seq}"`. - Pass `call_id` instead of `self.func_id` as the `func_id` argument when constructing the underlying `row_wise_udf` / `batch_udf` expressions (for generator, batch, and regular row-wise variants). This keeps the original `name` used for plan display intact, but guarantees that each distinct call site (with its own bound `args`/`kwargs`) has a unique function identifier, preventing unintended sharing across calls. ## Changes Made <!-- Describe what changes were made and why. Include implementation details if necessary. --> ## Related Issues ```python import daft @daft.func def format_number(value: int, prefix: str = "$", suffix: str = "") -> str: return f"{prefix}{value}{suffix}" df = daft.from_pydict({"amount": [10, 20, 30]}) df = df.with_column("dollar", format_number(df["amount"])) df = df.with_column("euro", format_number(df["amount"], prefix="€", suffix=" EUR")) df = df.with_column("customized", format_number(df["amount"], suffix=df["amount"].cast(daft.DataType.string()))) df.show() ``` The result is error: ``` ╭────────┬─────────┬─────────┬────────────╮ │ amount ┆ dollar ┆ euro ┆ customized │ │ --- ┆ --- ┆ --- ┆ --- │ │ Int64 ┆ String ┆ String ┆ String │ ╞════════╪═════════╪═════════╪════════════╡ │ 10 ┆ €10 EUR ┆ €10 EUR ┆ $1010 │ ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ │ 20 ┆ €20 EUR ┆ €20 EUR ┆ $2020 │ ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ │ 30 ┆ €30 EUR ┆ €30 EUR ┆ $3030 │ ╰────────┴─────────┴─────────┴────────────╯ (Showing first 3 of 3 rows) ``` <!-- Link to related GitHub issues, e.g., "Closes #123" -->
1 parent 25c189f commit 1df253b

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

daft/udf/udf_v2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Expression | T:
230230
bound_method = self._cls._daft_bind_method(self._method)
231231
return bound_method(*args, **kwargs)
232232

233+
# When building expression-based UDFs, we must avoid incorrectly sharing call-site state across multiple uses of the same function.
234+
call_seq = getattr(self, "_daft_call_seq", 0)
235+
setattr(self, "_daft_call_seq", call_seq + 1)
236+
call_id = f"{self.func_id}-{call_seq}"
237+
233238
check_serializable(
234239
self._method,
235240
"Daft functions must be serializable. If your function accesses a non-serializable global or nonlocal variable to avoid reinitialization, use `@daft.cls` with a setup method instead.",
@@ -247,7 +252,7 @@ def method(s: C, *args: P.args, **kwargs: P.kwargs) -> list[Any]:
247252

248253
expr = Expression._from_pyexpr(
249254
row_wise_udf(
250-
self.func_id,
255+
call_id,
251256
self.name,
252257
self._cls,
253258
method,
@@ -266,7 +271,7 @@ def method(s: C, *args: P.args, **kwargs: P.kwargs) -> list[Any]:
266271
elif self.is_batch:
267272
expr = Expression._from_pyexpr(
268273
batch_udf(
269-
self.func_id,
274+
call_id,
270275
self.name,
271276
self._cls,
272277
self._method,
@@ -286,7 +291,7 @@ def method(s: C, *args: P.args, **kwargs: P.kwargs) -> list[Any]:
286291
else:
287292
expr = Expression._from_pyexpr(
288293
row_wise_udf(
289-
self.func_id,
294+
call_id,
290295
self.name,
291296
self._cls,
292297
self._method,

tests/udf/test_row_wise_udf.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,26 @@ def stringify_and_sum(a: int, b: int) -> str:
469469
dynamic_batching_df = df.select("*", stringify_and_sum(col("x"), col("y")).alias("sum"))
470470
dynamic_batching_df = dynamic_batching_df.collect().sort("id")
471471
assert non_dynamic_batching_df.to_pydict() == dynamic_batching_df.to_pydict()
472+
473+
474+
def test_row_wise_udf_kwargs_prefix_suffix_literals_and_exprs():
475+
@daft.func
476+
def format_number(value: int, prefix: str = "$", suffix: str = "") -> str:
477+
return f"{prefix}{value}{suffix}"
478+
479+
df = daft.from_pydict({"amount": [10, 20, 30]})
480+
df = df.with_column("dollar", format_number(df["amount"]))
481+
df = df.with_column("euro", format_number(df["amount"], prefix="€", suffix=" EUR"))
482+
df = df.with_column(
483+
"customized",
484+
format_number(df["amount"], suffix=df["amount"].cast(DataType.string())),
485+
)
486+
487+
result = df.to_pydict()
488+
expected = {
489+
"amount": [10, 20, 30],
490+
"dollar": ["$10", "$20", "$30"],
491+
"euro": ["€10 EUR", "€20 EUR", "€30 EUR"],
492+
"customized": ["$1010", "$2020", "$3030"],
493+
}
494+
assert result == expected

0 commit comments

Comments
 (0)