52 changes: 34 additions & 18 deletions bigframes/dataframe.py
@@ -4801,37 +4801,53 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
                 )

                 # Apply the function
-                result_series = rows_as_json_series._apply_unary_op(
-                    ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
-                )
+                if args:
+                    result_series = rows_as_json_series._apply_nary_op(
+                        ops.NaryRemoteFunctionOp(function_def=func.udf_def),
+                        list(args),
+                    )
+                else:
+                    result_series = rows_as_json_series._apply_unary_op(
+                        ops.RemoteFunctionOp(
+                            function_def=func.udf_def, apply_on_null=True
+                        )
+                    )
             else:
                 # This is a special case where we are providing not-pandas-like
                 # extension. If the bigquery function can take one or more
-                # params then we assume that here the user intention is to use
-                # the column values of the dataframe as arguments to the
-                # function. For this to work the following condition must be
-                # true:
-                # 1. The number or input params in the function must be same
-                #    as the number of columns in the dataframe
+                # params (excluding the args) then we assume that here the
+                # user intention is to use the column values of the dataframe
+                # as arguments to the function. For this to work the
+                # following conditions must be true:
+                # 1. The number of input params (excluding the args) in the
+                #    function must be the same as the number of columns in
+                #    the dataframe.
                 # 2. The dtypes of the columns in the dataframe must be
-                #    compatible with the data types of the input params
+                #    compatible with the data types of the input params.
                 # 3. The order of the columns in the dataframe must correspond
-                #    to the order of the input params in the function
+                #    to the order of the input params in the function.
                 udf_input_dtypes = func.udf_def.signature.bf_input_types
-                if len(udf_input_dtypes) != len(self.columns):
+                if len(udf_input_dtypes) != len(self.columns) + len(args):
                     raise ValueError(
-                        f"BigFrames BigQuery function takes {len(udf_input_dtypes)}"
-                        f" arguments but DataFrame has {len(self.columns)} columns."
+                        f"Column count mismatch: BigFrames BigQuery function"
+                        f" expected {len(udf_input_dtypes) - len(args)} columns"
+                        f" from DataFrame but received {len(self.columns)}."
                     )
-                if udf_input_dtypes != tuple(self.dtypes.to_list()):
+                end_slice = -len(args) if args else None
+                if udf_input_dtypes[:end_slice] != tuple(self.dtypes.to_list()):
                     raise ValueError(
-                        f"BigFrames BigQuery function takes arguments of types "
-                        f"{udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
+                        f"Data type mismatch: BigFrames BigQuery function takes"
+                        f" arguments of types {udf_input_dtypes} but DataFrame"
+                        f" dtypes are {tuple(self.dtypes)}."
                     )

                 series_list = [self[col] for col in self.columns]
+                if args:
+                    op_list = series_list[1:] + list(args)
+                else:
+                    op_list = series_list[1:]
                 result_series = series_list[0]._apply_nary_op(
-                    ops.NaryRemoteFunctionOp(function_def=func.udf_def), series_list[1:]
+                    ops.NaryRemoteFunctionOp(function_def=func.udf_def), op_list
                 )
                 result_series.name = None
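For context, a standalone sketch of the argument accounting this hunk introduces, with plain strings standing in for BigFrames dtype objects; check_apply_arity is a hypothetical helper, not part of the library:

def check_apply_arity(udf_input_dtypes, df_dtypes, args):
    # Columns plus positional args must together cover every UDF parameter.
    if len(udf_input_dtypes) != len(df_dtypes) + len(args):
        raise ValueError(
            f"Column count mismatch: expected {len(udf_input_dtypes) - len(args)}"
            f" columns from DataFrame but received {len(df_dtypes)}."
        )
    # Only the leading params (those not bound by args) must match the
    # DataFrame dtypes; the trailing params are satisfied by args.
    end_slice = -len(args) if args else None
    if udf_input_dtypes[:end_slice] != tuple(df_dtypes):
        raise ValueError("Data type mismatch")

check_apply_arity(("Int64", "Int64", "Int64"), ("Int64", "Int64"), args=(1,))  # passes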
18 changes: 11 additions & 7 deletions bigframes/functions/_function_session.py
@@ -959,11 +959,15 @@ def _convert_row_processor_sig(
 ) -> Optional[inspect.Signature]:
     import bigframes.series as bf_series

-    if len(signature.parameters) == 1:
-        only_param = next(iter(signature.parameters.values()))
-        param_type = only_param.annotation
-        if (param_type == bf_series.Series) or (param_type == pandas.Series):
-            msg = bfe.format_message("input_types=Series is in preview.")
-            warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
-            return signature.replace(parameters=[only_param.replace(annotation=str)])
+    first_param = next(iter(signature.parameters.values()))
+    param_type = first_param.annotation
+    if (param_type == bf_series.Series) or (param_type == pandas.Series):
+        msg = bfe.format_message("input_types=Series is in preview.")
+        warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
+        return signature.replace(
+            parameters=[
+                p.replace(annotation=str) if i == 0 else p
+                for i, p in enumerate(signature.parameters.values())
+            ]
+        )
     return None
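For illustration, a minimal standalone run of the same signature rewrite using only the standard library (analyze is a made-up row processor):

import inspect

def analyze(s, x: bool, y: float) -> str:
    return ""

sig = inspect.signature(analyze)
# Re-annotate only the first (row) parameter as str; keep the rest intact.
new_sig = sig.replace(
    parameters=[
        p.replace(annotation=str) if i == 0 else p
        for i, p in enumerate(sig.parameters.values())
    ]
)
print(new_sig)  # (s: str, x: bool, y: float) -> str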
7 changes: 0 additions & 7 deletions bigframes/functions/function.py
@@ -178,13 +178,6 @@ def read_gbq_function(
             ValueError, f"Unknown function '{routine_ref}'."
         )

-    if is_row_processor and len(routine.arguments) > 1:
-        raise bf_formatting.create_exception_with_feedback_link(
-            ValueError,
-            "A multi-input function cannot be a row processor. A row processor function "
-            "takes in a single input representing the row.",
-        )
-
     if is_row_processor:
         return _try_import_row_routine(routine, session)
     else:
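With this guard removed, a multi-parameter routine can be read back as a row processor: the first parameter receives the serialized row and the rest are bound through `args`. A usage sketch (session, bf_df, and the routine path are placeholders; the pattern follows the system test added below):

func_ref = session.read_gbq_function(
    "my_project.my_dataset.analyze",  # hypothetical multi-input routine
    is_row_processor=True,
)
# Remaining parameters are supplied positionally via `args`.
result = bf_df.apply(func_ref, axis=1, args=(True, 10.0))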
26 changes: 25 additions & 1 deletion bigframes/functions/function_template.py
@@ -195,7 +195,9 @@ def udf_http_row_processor(request):
         calls = request_json["calls"]
         replies = []
         for call in calls:
-            reply = convert_to_bq_json(output_type, udf(get_pd_series(call[0])))
+            reply = convert_to_bq_json(
+                output_type, udf(get_pd_series(call[0]), *call[1:])
+            )
             if type(reply) is list:
                 # Since the BQ remote function does not support array yet,
                 # return a json serialized version of the reply.
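To make the fan-out concrete, a local sketch with stubs in place of the real deserializer and UDF (the payload shape — row JSON first, then the extra scalars — is inferred from the code above):

import json

def get_pd_series(row_json):  # stub for the real row deserializer
    return json.loads(row_json)

def udf(row, x, y):  # stand-in row processor taking two extra args
    return f"row={row['int64_col']}, x={x}, y={y}"

calls = [['{"int64_col": 1}', True, 10.0], ['{"int64_col": 2}', False, -1.0]]
replies = [udf(get_pd_series(call[0]), *call[1:]) for call in calls]
print(replies)  # ['row=1, x=True, y=10.0', 'row=2, x=False, y=-1.0']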
@@ -332,6 +334,28 @@ def generate_managed_function_code(
f"""def bigframes_handler(str_arg):
return {udf_name}({get_pd_series.__name__}(str_arg))"""
)

sig = inspect.signature(def_)
params = list(sig.parameters.values())
additional_params = params[1:]

# Build the parameter list for the new handler function definition.
# e.g., "str_arg, y: bool, z"
handler_def_parts = ["str_arg"]
handler_def_parts.extend(str(p) for p in additional_params)
handler_def_str = ", ".join(handler_def_parts)

# Build the argument list for the call to the original UDF.
# e.g., "get_pd_series(str_arg), y, z"
udf_call_parts = [f"{get_pd_series.__name__}(str_arg)"]
udf_call_parts.extend(p.name for p in additional_params)
udf_call_str = ", ".join(udf_call_parts)

bigframes_handler_code = textwrap.dedent(
f"""def bigframes_handler({handler_def_str}):
return {udf_name}({udf_call_str})"""
)

else:
udf_code = ""
bigframes_handler_code = textwrap.dedent(
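For reference, what the string building above yields for a hypothetical UDF def analyze(s, x: bool, y: float), reproduced with the standard library only (udf_name and get_pd_series are hardcoded as plain names for the sketch):

import inspect
import textwrap

def analyze(s, x: bool, y: float):
    return ""

params = list(inspect.signature(analyze).parameters.values())[1:]
handler_def_str = ", ".join(["str_arg"] + [str(p) for p in params])
udf_call_str = ", ".join(["get_pd_series(str_arg)"] + [p.name for p in params])
print(
    textwrap.dedent(
        f"""def bigframes_handler({handler_def_str}):
    return analyze({udf_call_str})"""
    )
)
# Output:
# def bigframes_handler(str_arg, x: bool, y: float):
#     return analyze(get_pd_series(str_arg), x, y)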
110 changes: 107 additions & 3 deletions tests/system/large/functions/test_managed_function.py
@@ -468,20 +468,20 @@ def foo(x, y, z):
     # Fails to apply on dataframe with incompatible number of columns.
     with pytest.raises(
         ValueError,
-        match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
+        match="^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 2\\.$",
     ):
         bf_df[["Id", "Age"]].apply(foo, axis=1)

     with pytest.raises(
         ValueError,
-        match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$",
+        match="^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 4\\.$",
     ):
         bf_df.assign(Country="lalaland").apply(foo, axis=1)

     # Fails to apply on dataframe with incompatible column datatypes.
     with pytest.raises(
         ValueError,
-        match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
+        match="^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
     ):
         bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)

@@ -965,6 +965,110 @@ def float_parser(row):
     )


+def test_managed_function_df_apply_axis_1_args(session, dataset_id, scalars_dfs):
+    columns = ["int64_col", "int64_too"]
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    try:
+
+        def the_sum(s1, s2, x):
+            return s1 + s2 + x
+
+        the_sum_mf = session.udf(
+            input_types=[int, int, int],
+            output_type=int,
+            dataset=dataset_id,
+            name=prefixer.create_prefix(),
+        )(the_sum)
+
+        args1 = (1,)
+
+        # Fails to apply on dataframe with incompatible number of columns.
+        with pytest.raises(
+            ValueError,
+            match="^Column count mismatch: BigFrames BigQuery function expected 2 columns from DataFrame but received 3\\.$",
+        ):
+            scalars_df[columns + ["float64_col"]].apply(the_sum_mf, axis=1, args=args1)
+
+        # Fails to apply on dataframe with incompatible column datatypes.
+        with pytest.raises(
+            ValueError,
+            match="^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
+        ):
+            scalars_df[columns].assign(
+                int64_col=lambda df: df["int64_col"].astype("Float64")
+            ).apply(the_sum_mf, axis=1, args=args1)
+
+        bf_result = (
+            scalars_df[columns]
+            .dropna()
+            .apply(the_sum_mf, axis=1, args=args1)
+            .to_pandas()
+        )
+        pd_result = scalars_pandas_df[columns].dropna().apply(sum, axis=1, args=args1)
+
+        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
+
+    finally:
+        # clean up the gcp assets created for the managed function.
+        cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
+
+
+def test_managed_function_df_apply_axis_1_series_args(session, dataset_id, scalars_dfs):
+    columns = ["int64_col", "float64_col"]
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    try:
+
+        def analyze(s, x, y):
+            value = f"value is {s['int64_col']} and {s['float64_col']}"
+            if x:
+                return f"{value}, x is True!"
+            if y > 0:
+                return f"{value}, x is False, y is positive!"
+            return f"{value}, x is False, y is non-positive!"
+
+        analyze_mf = session.udf(
+            input_types=[bigframes.series.Series, bool, float],
+            output_type=str,
+            dataset=dataset_id,
+            name=prefixer.create_prefix(),
+        )(analyze)
+
+        args1 = (True, 10.0)
+        bf_result = (
+            scalars_df[columns]
+            .dropna()
+            .apply(analyze_mf, axis=1, args=args1)
+            .to_pandas()
+        )
+        pd_result = (
+            scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args1)
+        )
+
+        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
+
+        args2 = (False, -10.0)
+        analyze_mf_ref = session.read_gbq_function(
+            analyze_mf.bigframes_bigquery_function, is_row_processor=True
+        )
+        bf_result = (
+            scalars_df[columns]
+            .dropna()
+            .apply(analyze_mf_ref, axis=1, args=args2)
+            .to_pandas()
+        )
+        pd_result = (
+            scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args2)
+        )
+
+        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
+
+    finally:
+        # clean up the gcp assets created for the managed function.
+        cleanup_function_assets(analyze_mf, session.bqclient, ignore_failures=False)
+
+
 def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs):
     try:
