Skip to content

Commit 06a0a65

Browse files
committed
refactor: add compile_isin_join
1 parent 1a0f710 commit 06a0a65

File tree

4 files changed

+112
-0
lines changed

4 files changed

+112
-0
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,28 @@ def compile_join(
244244
joins_nulls=node.joins_nulls,
245245
)
246246

247+
@_compile_node.register
248+
def compile_isin_join(
249+
self, node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
250+
) -> ir.SQLGlotIR:
251+
conditions = tuple(
252+
typed_expr.TypedExpr(
253+
scalar_compiler.compile_scalar_expression(node.left_col),
254+
left.output_type,
255+
),
256+
typed_expr.TypedExpr(
257+
scalar_compiler.compile_scalar_expression(node.right_col),
258+
right.output_type,
259+
),
260+
)
261+
262+
return left.isin_join(
263+
right,
264+
indicator_col=node.indicator_col.sql,
265+
conditions=conditions,
266+
joins_nulls=node.joins_nulls,
267+
)
268+
247269
@_compile_node.register
248270
def compile_concat(
249271
self, node: nodes.ConcatNode, *children: ir.SQLGlotIR

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,56 @@ def join(
336336

337337
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
338338

339+
def isin_join(
340+
self,
341+
right: SQLGlotIR,
342+
indicator_col: str,
343+
conditions: tuple[typed_expr.TypedExpr, typed_expr.TypedExpr],
344+
*,
345+
joins_nulls: bool = True,
346+
) -> SQLGlotIR:
347+
"""Joins the current query with another SQLGlotIR instance."""
348+
# TODO: Optimization similar to Ibis:
349+
# if isinstance(values, ArrayValue):
350+
# return ops.ArrayContains(values, self).to_expr()
351+
# elif isinstance(values, Column):
352+
# return ops.InSubquery(values.as_table(), needle=self).to_expr()
353+
# else:
354+
# return ops.InValues(self, values).to_expr()
355+
356+
raise NotImplementedError
357+
# left_cte_name = sge.to_identifier(
358+
# next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
359+
# )
360+
# right_cte_name = sge.to_identifier(
361+
# next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
362+
# )
363+
364+
# left_select = _select_to_cte(self.expr, left_cte_name)
365+
# right_select = _select_to_cte(right.expr, right_cte_name)
366+
367+
# left_ctes = left_select.args.pop("with", [])
368+
# right_ctes = right_select.args.pop("with", [])
369+
# merged_ctes = [*left_ctes, *right_ctes]
370+
371+
372+
373+
# join_conditions = [
374+
# _join_condition(left, right, joins_nulls) for left, right in conditions
375+
# ]
376+
# join_on = sge.And(expressions=join_conditions) if join_conditions else None
377+
378+
# join_type_str = join_type if join_type != "outer" else "full outer"
379+
# new_expr = (
380+
# sge.Select()
381+
# .select(sge.Star())
382+
# .from_(sge.Table(this=left_cte_name))
383+
# .join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str)
384+
# )
385+
# new_expr.set("with", sge.With(expressions=merged_ctes))
386+
387+
# return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
388+
339389
def explode(
340390
self,
341391
column_names: tuple[str, ...],
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
*
4+
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` FLOAT64, `bfcol_1` INT64>>[STRUCT(314159.0, 0), STRUCT(2.0, 1), STRUCT(3.0, 2), STRUCT(CAST(NULL AS FLOAT64), 3)])
5+
)
6+
SELECT
7+
`bfcol_0` AS `0`
8+
FROM `bfcte_0`
9+
ORDER BY
10+
`bfcol_1` ASC NULLS LAST
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pandas as pd
16+
import pytest
17+
18+
import bigframes
19+
import bigframes.pandas as bpd
20+
21+
pytest.importorskip("pytest_snapshot")
22+
23+
24+
def test_compile_isin(
25+
scalar_types_df: bpd.DataFrame, compiler_session: bigframes.Session, snapshot
26+
):
27+
data = [314159, 2.0, 3, pd.NA]
28+
s = bpd.Series(data, session=compiler_session)
29+
bf_isin = scalar_types_df["int64_col"].isin(s).to_frame()
30+
snapshot.assert_match(bf_isin.sql, "out.sql")

0 commit comments

Comments
 (0)