73 changes: 55 additions & 18 deletions bigframes/dataframe.py
@@ -76,6 +76,7 @@
import bigframes.exceptions as bfe
import bigframes.formatting_helpers as formatter
import bigframes.functions
from bigframes.functions import function_typing
import bigframes.operations as ops
import bigframes.operations.aggregations as agg_ops
import bigframes.operations.ai
@@ -4801,37 +4802,73 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
)

# Apply the function
result_series = rows_as_json_series._apply_unary_op(
ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
)
if args:
result_series = rows_as_json_series._apply_nary_op(
ops.NaryRemoteFunctionOp(function_def=func.udf_def),
list(args),
)
else:
result_series = rows_as_json_series._apply_unary_op(
ops.RemoteFunctionOp(
function_def=func.udf_def, apply_on_null=True
)
)
else:
# This is a special case where we are providing not-pandas-like
# extension. If the bigquery function can take one or more
# params then we assume that here the user intention is to use
# the column values of the dataframe as arguments to the
# function. For this to work the following condition must be
# true:
# 1. The number or input params in the function must be same
# as the number of columns in the dataframe
# params (excluding the args) then we assume that here the user
# intention is to use the column values of the dataframe as
# arguments to the function. For this to work the following
# condition must be true:
# 1. The number of input params (excluding the args) in the
# function must be the same as the number of columns in the
# dataframe.
# 2. The dtypes of the columns in the dataframe must be
# compatible with the data types of the input params
# compatible with the data types of the input params.
# 3. The order of the columns in the dataframe must correspond
# to the order of the input params in the function
# to the order of the input params in the function.
udf_input_dtypes = func.udf_def.signature.bf_input_types
if len(udf_input_dtypes) != len(self.columns):
if not args and len(udf_input_dtypes) != len(self.columns):
raise ValueError(
f"BigFrames BigQuery function takes {len(udf_input_dtypes)}"
f" arguments but DataFrame has {len(self.columns)} columns."
f"Parameter count mismatch: BigFrames BigQuery function"
f" expected {len(udf_input_dtypes)} parameters but"
f" received {len(self.columns)} DataFrame columns."
)
if udf_input_dtypes != tuple(self.dtypes.to_list()):
if args and len(udf_input_dtypes) != len(self.columns) + len(args):
raise ValueError(
f"BigFrames BigQuery function takes arguments of types "
f"{udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
f"Parameter count mismatch: BigFrames BigQuery function"
f" expected {len(udf_input_dtypes)} parameters but"
f" received {len(self.columns) + len(args)} values"
f" ({len(self.columns)} DataFrame coulmns and"
f" {len(args)} args)."
)
end_slice = -len(args) if args else None
if udf_input_dtypes[:end_slice] != tuple(self.dtypes.to_list()):
raise ValueError(
f"Data type mismatch for DataFrame columns:"
f" Expected {udf_input_dtypes[:end_slice]}"
f" Received {tuple(self.dtypes)}."
)
if args:
bq_types = (
function_typing.sdk_type_from_python_type(type(arg))
for arg in args
)
args_dtype = tuple(
function_typing.sdk_type_to_bf_type(bq_type)
for bq_type in bq_types
)
if udf_input_dtypes[end_slice:] != args_dtype:
raise ValueError(
f"Data type mismatch for 'args' parameter:"
f" Expected {udf_input_dtypes[end_slice:]}"
f" Received {args_dtype}."
)

series_list = [self[col] for col in self.columns]
op_list = series_list[1:] + list(args)
result_series = series_list[0]._apply_nary_op(
ops.NaryRemoteFunctionOp(function_def=func.udf_def), series_list[1:]
ops.NaryRemoteFunctionOp(function_def=func.udf_def), op_list
)
result_series.name = None

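For reviewers, a minimal usage sketch of the new args support in DataFrame.apply(axis=1), mirroring the system tests later in this diff; session, dataset_id, and df (a DataFrame with two Int64 columns) are assumed to already exist:

def the_sum(s1, s2, x):
    return s1 + s2 + x

# Deploy as a managed UDF taking three int parameters.
the_sum_mf = session.udf(
    input_types=[int, int, int],
    output_type=int,
    dataset=dataset_id,
)(the_sum)

# The first two parameters bind to the DataFrame columns and are
# validated against udf_input_dtypes[:end_slice]; the trailing
# parameter binds to the scalar passed via args and is validated
# against udf_input_dtypes[end_slice:].
result = df.apply(the_sum_mf, axis=1, args=(1,))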
13 changes: 9 additions & 4 deletions bigframes/functions/_function_session.py
@@ -959,11 +959,16 @@ def _convert_row_processor_sig(
) -> Optional[inspect.Signature]:
import bigframes.series as bf_series

if len(signature.parameters) == 1:
only_param = next(iter(signature.parameters.values()))
param_type = only_param.annotation
if len(signature.parameters) >= 1:
first_param = next(iter(signature.parameters.values()))
param_type = first_param.annotation
if (param_type == bf_series.Series) or (param_type == pandas.Series):
msg = bfe.format_message("input_types=Series is in preview.")
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
return signature.replace(parameters=[only_param.replace(annotation=str)])
return signature.replace(
parameters=[
p.replace(annotation=str) if i == 0 else p
for i, p in enumerate(signature.parameters.values())
]
)
return None
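
A standalone sketch of the signature rewrite above, using only the standard-library inspect module; the function analyze here is hypothetical:

import inspect

import pandas

def analyze(s: pandas.Series, x: bool, y: float) -> str:
    ...

sig = inspect.signature(analyze)
# Re-annotate only the first parameter (the row arrives serialized as
# a JSON string); the remaining parameters keep their annotations.
new_sig = sig.replace(
    parameters=[
        p.replace(annotation=str) if i == 0 else p
        for i, p in enumerate(sig.parameters.values())
    ]
)
print(new_sig)  # (s: str, x: bool, y: float) -> str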
7 changes: 0 additions & 7 deletions bigframes/functions/function.py
@@ -178,13 +178,6 @@ def read_gbq_function(
ValueError, f"Unknown function '{routine_ref}'."
)

if is_row_processor and len(routine.arguments) > 1:
raise bf_formatting.create_exception_with_feedback_link(
ValueError,
"A multi-input function cannot be a row processor. A row processor function "
"takes in a single input representing the row.",
)

if is_row_processor:
return _try_import_row_routine(routine, session)
else:
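With this check removed, a deployed multi-parameter row processor can be read back and applied with trailing scalar args, as exercised by the tests later in this diff. A hedged sketch; the routine path and df are hypothetical:

# The first parameter is treated as the serialized row; the remaining
# parameters are supplied positionally through args.
fn = session.read_gbq_function(
    "my-project.my_dataset.analyze",  # hypothetical routine path
    is_row_processor=True,
)
result = df.apply(fn, axis=1, args=(True, 10.0))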
26 changes: 25 additions & 1 deletion bigframes/functions/function_template.py
@@ -195,7 +195,9 @@ def udf_http_row_processor(request):
calls = request_json["calls"]
replies = []
for call in calls:
reply = convert_to_bq_json(output_type, udf(get_pd_series(call[0])))
reply = convert_to_bq_json(
output_type, udf(get_pd_series(call[0]), *call[1:])
)
if type(reply) is list:
# Since the BQ remote function does not support array yet,
# return a json serialized version of the reply.
@@ -332,6 +334,28 @@ def generate_managed_function_code(
f"""def bigframes_handler(str_arg):
return {udf_name}({get_pd_series.__name__}(str_arg))"""
)

sig = inspect.signature(def_)
params = list(sig.parameters.values())
additional_params = params[1:]

# Build the parameter list for the new handler function definition.
# e.g., "str_arg, y: bool, z"
handler_def_parts = ["str_arg"]
handler_def_parts.extend(str(p) for p in additional_params)
handler_def_str = ", ".join(handler_def_parts)

# Build the argument list for the call to the original UDF.
# e.g., "get_pd_series(str_arg), y, z"
udf_call_parts = [f"{get_pd_series.__name__}(str_arg)"]
udf_call_parts.extend(p.name for p in additional_params)
udf_call_str = ", ".join(udf_call_parts)

bigframes_handler_code = textwrap.dedent(
f"""def bigframes_handler({handler_def_str}):
return {udf_name}({udf_call_str})"""
)

else:
udf_code = ""
bigframes_handler_code = textwrap.dedent(
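To see what the generated wrapper looks like, a hedged sketch of the string-building logic above applied to a hypothetical row-processor UDF named analyze (get_pd_series is the helper referenced in the template):

import inspect

def analyze(s, x: bool, y: float):  # hypothetical UDF; s is the row
    ...

params = list(inspect.signature(analyze).parameters.values())[1:]
handler_def_str = ", ".join(["str_arg"] + [str(p) for p in params])
udf_call_str = ", ".join(
    ["get_pd_series(str_arg)"] + [p.name for p in params]
)
print(f"def bigframes_handler({handler_def_str}):\n"
      f"    return analyze({udf_call_str})")
# Prints:
# def bigframes_handler(str_arg, x: bool, y: float):
#     return analyze(get_pd_series(str_arg), x, y)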
117 changes: 114 additions & 3 deletions tests/system/large/functions/test_managed_function.py
@@ -468,20 +468,20 @@ def foo(x, y, z):
# Fails to apply on dataframe with incompatible number of columns.
with pytest.raises(
ValueError,
match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
match="^Parameter count mismatch:.* expected 3 parameters but received 2 DataFrame.*",
):
bf_df[["Id", "Age"]].apply(foo, axis=1)

with pytest.raises(
ValueError,
match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$",
match="^Parameter count mismatch:.* expected 3 parameters but received 4 DataFrame.*",
):
bf_df.assign(Country="lalaland").apply(foo, axis=1)

# Fails to apply on dataframe with incompatible column datatypes.
with pytest.raises(
ValueError,
match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
match="^Data type mismatch for DataFrame columns: Expected .* Received .*",
):
bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)

@@ -965,6 +965,117 @@ def float_parser(row):
)


def test_managed_function_df_apply_axis_1_args(session, dataset_id, scalars_dfs):
columns = ["int64_col", "int64_too"]
scalars_df, scalars_pandas_df = scalars_dfs

try:

def the_sum(s1, s2, x):
return s1 + s2 + x

the_sum_mf = session.udf(
input_types=[int, int, int],
output_type=int,
dataset=dataset_id,
name=prefixer.create_prefix(),
)(the_sum)

args1 = (1,)

# Fails to apply on dataframe with incompatible number of columns.
with pytest.raises(
ValueError,
match="^Parameter count mismatch:.* expected 3 parameters but received 4 values.*",
):
scalars_df[columns + ["float64_col"]].apply(the_sum_mf, axis=1, args=args1)

# Fails to apply on dataframe with incompatible column datatypes.
with pytest.raises(
ValueError,
match="^Data type mismatch for DataFrame columns: Expected .* Received .*",
):
scalars_df[columns].assign(
int64_col=lambda df: df["int64_col"].astype("Float64")
).apply(the_sum_mf, axis=1, args=args1)

# Fails to apply on dataframe with incompatible args datatypes.
with pytest.raises(
ValueError,
match="^Data type mismatch for 'args' parameter: Expected .* Received .*",
):
scalars_df[columns].apply(the_sum_mf, axis=1, args=(1.3,))

bf_result = (
scalars_df[columns]
.dropna()
.apply(the_sum_mf, axis=1, args=args1)
.to_pandas()
)
pd_result = scalars_pandas_df[columns].dropna().apply(sum, axis=1, args=args1)

pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)

finally:
# clean up the gcp assets created for the managed function.
cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)


def test_managed_function_df_apply_axis_1_series_args(session, dataset_id, scalars_dfs):
columns = ["int64_col", "float64_col"]
scalars_df, scalars_pandas_df = scalars_dfs

try:

def analyze(s, x, y):
value = f"value is {s['int64_col']} and {s['float64_col']}"
if x:
return f"{value}, x is True!"
if y > 0:
return f"{value}, x is False, y is positive!"
return f"{value}, x is False, y is non-positive!"

analyze_mf = session.udf(
input_types=[bigframes.series.Series, bool, float],
output_type=str,
dataset=dataset_id,
name=prefixer.create_prefix(),
)(analyze)

args1 = (True, 10.0)
bf_result = (
scalars_df[columns]
.dropna()
.apply(analyze_mf, axis=1, args=args1)
.to_pandas()
)
pd_result = (
scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args1)
)

pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)

args2 = (False, -10.0)
analyze_mf_ref = session.read_gbq_function(
analyze_mf.bigframes_bigquery_function, is_row_processor=True
)
bf_result = (
scalars_df[columns]
.dropna()
.apply(analyze_mf_ref, axis=1, args=args2)
.to_pandas()
)
pd_result = (
scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args2)
)

pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)

finally:
# clean up the gcp assets created for the managed function.
cleanup_function_assets(analyze_mf, session.bqclient, ignore_failures=False)


def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs):
try:
