Skip to content

Commit 15bc0ad

Browse files
committed
Improve handling of udf when user provides a class instead of bare function
1 parent ca1352e commit 15bc0ad

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

python/datafusion/udf.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,16 @@ def __init__(
8686
self,
8787
name: str | None,
8888
func: Callable[..., _R],
89-
input_types: list[pyarrow.DataType],
89+
input_types: pyarrow.DataType | list[pyarrow.DataType],
9090
return_type: _R,
9191
volatility: Volatility | str,
9292
) -> None:
9393
"""Instantiate a scalar user-defined function (UDF).
9494
9595
See helper method :py:func:`udf` for argument details.
9696
"""
97+
if isinstance(input_types, pyarrow.DataType):
98+
input_types = [input_types]
9799
self._udf = df_internal.ScalarUDF(
98100
name, func, input_types, return_type, str(volatility)
99101
)
@@ -133,7 +135,10 @@ def udf(
133135
if not callable(func):
134136
raise TypeError("`func` argument must be callable")
135137
if name is None:
136-
name = func.__qualname__.lower()
138+
if hasattr(func, "__qualname__"):
139+
name = func.__qualname__.lower()
140+
else:
141+
name = func.__class__.__name__.lower()
137142
return ScalarUDF(
138143
name=name,
139144
func=func,

0 commit comments

Comments
 (0)