diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 27973ef8b5..4a2a5fb213 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -140,6 +140,19 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.Coalesce(this=left.expr, expressions=[right.expr]) +@register_binary_op(ops.BinaryRemoteFunctionOp, pass_op=True) +def _( + left: TypedExpr, right: TypedExpr, op: ops.BinaryRemoteFunctionOp +) -> sge.Expression: + routine_ref = op.function_def.routine_ref + # Quote project, dataset, and routine IDs to avoid keyword clashes. + func_name = ( + f"`{routine_ref.project}`.`{routine_ref.dataset_id}`.`{routine_ref.routine_id}`" + ) + + return sge.func(func_name, left.expr, right.expr) + + @register_nary_op(ops.case_when_op) def _(*cases_and_outputs: TypedExpr) -> sge.Expression: # Need to upcast BOOL to INT if any output is numeric diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_binary_remote_function_op/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_binary_remote_function_op/out.sql new file mode 100644 index 0000000000..7272a3a5be --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_binary_remote_function_op/out.sql @@ -0,0 +1,14 @@ +WITH `bfcte_0` AS ( + SELECT + `float64_col`, + `int64_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `my_project`.`my_dataset`.`my_routine`(`int64_col`, `float64_col`) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index 11daf6813a..5657874eb5 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -168,6 +168,43 @@ def test_astype_json_invalid( ) +def test_binary_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot): + from google.cloud import bigquery + + from bigframes.functions import udf_def + + bf_df = scalar_types_df[["int64_col", "float64_col"]] + op = ops.BinaryRemoteFunctionOp( + function_def=udf_def.BigqueryUdf( + routine_ref=bigquery.RoutineReference.from_string( + "my_project.my_dataset.my_routine" + ), + signature=udf_def.UdfSignature( + input_types=( + udf_def.UdfField( + "x", + bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.INT64 + ), + ), + udf_def.UdfField( + "y", + bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.FLOAT64 + ), + ), + ), + output_bq_type=bigquery.StandardSqlDataType( + type_kind=bigquery.StandardSqlTypeNames.FLOAT64 + ), + ), + ) + ) + sql = utils._apply_binary_op(bf_df, op, "int64_col", "float64_col") + + snapshot.assert_match(sql, "out.sql") + + def test_case_when_op(scalar_types_df: bpd.DataFrame, snapshot): ops_map = { "single_case": ops.case_when_op.as_expr(