Skip to content

Commit 5a1d1de

Browse files
authored
refactor: add compile_isin_join (#1886)
1 parent 41e8f33 commit 5a1d1de

File tree

6 files changed

+191
-1
lines changed

6 files changed

+191
-1
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,28 @@ def compile_join(
244244
joins_nulls=node.joins_nulls,
245245
)
246246

247+
@_compile_node.register
248+
def compile_isin_join(
249+
self, node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
250+
) -> ir.SQLGlotIR:
251+
conditions = (
252+
typed_expr.TypedExpr(
253+
scalar_compiler.compile_scalar_expression(node.left_col),
254+
node.left_col.output_type,
255+
),
256+
typed_expr.TypedExpr(
257+
scalar_compiler.compile_scalar_expression(node.right_col),
258+
node.right_col.output_type,
259+
),
260+
)
261+
262+
return left.isin_join(
263+
right,
264+
indicator_col=node.indicator_col.sql,
265+
conditions=conditions,
266+
joins_nulls=node.joins_nulls,
267+
)
268+
247269
@_compile_node.register
248270
def compile_concat(
249271
self, node: nodes.ConcatNode, *children: ir.SQLGlotIR

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,68 @@ def join(
336336

337337
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
338338

339+
def isin_join(
340+
self,
341+
right: SQLGlotIR,
342+
indicator_col: str,
343+
conditions: tuple[typed_expr.TypedExpr, typed_expr.TypedExpr],
344+
joins_nulls: bool = True,
345+
) -> SQLGlotIR:
346+
"""Joins the current query with another SQLGlotIR instance."""
347+
left_cte_name = sge.to_identifier(
348+
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
349+
)
350+
351+
left_select = _select_to_cte(self.expr, left_cte_name)
352+
# Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
353+
right_select = right.expr
354+
355+
left_ctes = left_select.args.pop("with", [])
356+
right_ctes = right_select.args.pop("with", [])
357+
merged_ctes = [*left_ctes, *right_ctes]
358+
359+
left_condition = typed_expr.TypedExpr(
360+
sge.Column(this=conditions[0].expr, table=left_cte_name),
361+
conditions[0].dtype,
362+
)
363+
364+
new_column: sge.Expression
365+
if joins_nulls:
366+
right_table_name = sge.to_identifier(
367+
next(self.uid_gen.get_uid_stream("bft_")), quoted=self.quoted
368+
)
369+
right_condition = typed_expr.TypedExpr(
370+
sge.Column(this=conditions[1].expr, table=right_table_name),
371+
conditions[1].dtype,
372+
)
373+
new_column = sge.Exists(
374+
this=sge.Select()
375+
.select(sge.convert(1))
376+
.from_(sge.Alias(this=right_select.subquery(), alias=right_table_name))
377+
.where(
378+
_join_condition(left_condition, right_condition, joins_nulls=True)
379+
)
380+
)
381+
else:
382+
new_column = sge.In(
383+
this=left_condition.expr,
384+
expressions=[right_select.subquery()],
385+
)
386+
387+
new_column = sge.Alias(
388+
this=new_column,
389+
alias=sge.to_identifier(indicator_col, quoted=self.quoted),
390+
)
391+
392+
new_expr = (
393+
sge.Select()
394+
.select(sge.Column(this=sge.Star(), table=left_cte_name), new_column)
395+
.from_(sge.Table(this=left_cte_name))
396+
)
397+
new_expr.set("with", sge.With(expressions=merged_ctes))
398+
399+
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
400+
339401
def explode(
340402
self,
341403
column_names: tuple[str, ...],

tests/unit/core/compile/sqlglot/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def scalar_types_table_schema() -> typing.Sequence[bigquery.SchemaField]:
8585
bigquery.SchemaField("numeric_col", "NUMERIC"),
8686
bigquery.SchemaField("float64_col", "FLOAT"),
8787
bigquery.SchemaField("rowindex", "INTEGER"),
88-
bigquery.SchemaField("rowindex_2", "INTEGER"),
88+
bigquery.SchemaField("rowindex_2", "INTEGER", mode="REQUIRED"),
8989
bigquery.SchemaField("string_col", "STRING"),
9090
bigquery.SchemaField("time_col", "TIME"),
9191
bigquery.SchemaField("timestamp_col", "TIMESTAMP"),
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`int64_too` AS `bfcol_4`
14+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
15+
), `bfcte_3` AS (
16+
SELECT
17+
`bfcte_2`.*,
18+
EXISTS(
19+
SELECT
20+
1
21+
FROM (
22+
SELECT
23+
`bfcol_4`
24+
FROM `bfcte_0`
25+
GROUP BY
26+
`bfcol_4`
27+
) AS `bft_0`
28+
WHERE
29+
COALESCE(`bfcte_2`.`bfcol_3`, 0) = COALESCE(`bft_0`.`bfcol_4`, 0)
30+
AND COALESCE(`bfcte_2`.`bfcol_3`, 1) = COALESCE(`bft_0`.`bfcol_4`, 1)
31+
) AS `bfcol_5`
32+
FROM `bfcte_2`
33+
)
34+
SELECT
35+
`bfcol_2` AS `rowindex`,
36+
`bfcol_5` AS `int64_col`
37+
FROM `bfcte_3`
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`rowindex` AS `bfcol_0`,
4+
`rowindex_2` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_0` AS `bfcol_2`,
9+
`bfcol_1` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`rowindex_2` AS `bfcol_4`
14+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
15+
), `bfcte_3` AS (
16+
SELECT
17+
`bfcte_2`.*,
18+
`bfcte_2`.`bfcol_3` IN ((
19+
SELECT
20+
`bfcol_4`
21+
FROM `bfcte_0`
22+
GROUP BY
23+
`bfcol_4`
24+
)) AS `bfcol_5`
25+
FROM `bfcte_2`
26+
)
27+
SELECT
28+
`bfcol_2` AS `rowindex`,
29+
`bfcol_5` AS `rowindex_2`
30+
FROM `bfcte_3`
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
17+
import pytest
18+
19+
import bigframes.pandas as bpd
20+
21+
pytest.importorskip("pytest_snapshot")
22+
23+
if sys.version_info < (3, 12):
24+
pytest.skip(
25+
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
26+
allow_module_level=True,
27+
)
28+
29+
30+
def test_compile_isin(scalar_types_df: bpd.DataFrame, snapshot):
31+
bf_isin = scalar_types_df["int64_col"].isin(scalar_types_df["int64_too"]).to_frame()
32+
snapshot.assert_match(bf_isin.sql, "out.sql")
33+
34+
35+
def test_compile_isin_not_nullable(scalar_types_df: bpd.DataFrame, snapshot):
36+
bf_isin = (
37+
scalar_types_df["rowindex_2"].isin(scalar_types_df["rowindex_2"]).to_frame()
38+
)
39+
snapshot.assert_match(bf_isin.sql, "out.sql")

0 commit comments

Comments
 (0)