Skip to content

Commit 23d6fb4

Browse files
authored
refactor: add compile_join (#1851)
1 parent c289f70 commit 23d6fb4

File tree

6 files changed

+164
-19
lines changed

6 files changed

+164
-19
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
2424
from bigframes.core.compile import configs
25+
from bigframes.core.compile.sqlglot.expressions import typed_expr
2526
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2627
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2728
import bigframes.core.ordering as bf_ordering
@@ -218,6 +219,29 @@ def compile_filter(
218219
condition = scalar_compiler.compile_scalar_expression(node.predicate)
219220
return child.filter(condition)
220221

222+
@_compile_node.register
def compile_join(
    self, node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
) -> ir.SQLGlotIR:
    """Compile a join node into a join of its two compiled children.

    Args:
        node: The join node; its ``conditions``, ``type`` and ``joins_nulls``
            attributes are read here.
        left: Compiled IR of the left child.
        right: Compiled IR of the right child.

    Returns:
        The IR produced by ``left.join(right, ...)``.
    """

    def _as_typed(key):
        # Pair the compiled scalar expression with the key's output type.
        return typed_expr.TypedExpr(
            scalar_compiler.compile_scalar_expression(key), key.output_type
        )

    # Distinct loop names keep the join keys from shadowing the `left` /
    # `right` IR arguments that are used just below.
    conditions = tuple(
        (_as_typed(left_key), _as_typed(right_key))
        for left_key, right_key in node.conditions
    )

    return left.join(
        right,
        join_type=node.type,
        conditions=conditions,
        joins_nulls=node.joins_nulls,
    )
244+
221245
@_compile_node.register
222246
def compile_concat(
223247
self, node: nodes.ConcatNode, *children: ir.SQLGlotIR

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from bigframes import dtypes
2828
from bigframes.core import guid
29+
from bigframes.core.compile.sqlglot.expressions import typed_expr
2930
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
3031
import bigframes.core.local_data as local_data
3132
import bigframes.core.schema as bf_schema
@@ -212,7 +213,8 @@ def select(
212213
for id, expr in selected_cols
213214
]
214215

215-
new_expr = self._encapsulate_as_cte().select(*selections, append=False)
216+
new_expr, _ = self._encapsulate_as_cte()
217+
new_expr = new_expr.select(*selections, append=False)
216218
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
217219

218220
def order_by(
@@ -247,19 +249,52 @@ def project(
247249
)
248250
for id, expr in projected_cols
249251
]
250-
new_expr = self._encapsulate_as_cte().select(*projected_cols_expr, append=True)
252+
new_expr, _ = self._encapsulate_as_cte()
253+
new_expr = new_expr.select(*projected_cols_expr, append=True)
251254
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
252255

253256
def filter(
254257
self,
255258
condition: sge.Expression,
256259
) -> SQLGlotIR:
257260
"""Filters the query with the given condition."""
258-
new_expr = self._encapsulate_as_cte()
261+
new_expr, _ = self._encapsulate_as_cte()
259262
return SQLGlotIR(
260263
expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen
261264
)
262265

266+
def join(
    self,
    right: SQLGlotIR,
    join_type: typing.Literal["inner", "outer", "left", "right", "cross"],
    conditions: tuple[tuple[typed_expr.TypedExpr, typed_expr.TypedExpr], ...],
    *,
    joins_nulls: bool = True,
) -> SQLGlotIR:
    """Joins the current query with another SQLGlotIR instance.

    Both sides are first pushed into their own CTE via
    ``_encapsulate_as_cte``; the resulting query selects ``*`` from the
    left CTE joined to the right CTE, carrying the CTEs of both sides
    (left first, then right) in a single WITH clause.

    Args:
        right: The IR to join against.
        join_type: Kind of join; ``"outer"`` is emitted as ``"full outer"``,
            every other value is passed through unchanged.
        conditions: Pairs of (left key, right key) expressions. Each pair
            is turned into a predicate by ``_join_condition``. When empty,
            no ON clause is supplied.
        joins_nulls: Forwarded to ``_join_condition`` for every pair.

    Returns:
        A new SQLGlotIR wrapping the joined SELECT, sharing this
        instance's ``uid_gen``.
    """
    lhs_query, lhs_table = self._encapsulate_as_cte()
    rhs_query, rhs_table = right._encapsulate_as_cte()

    # Detach the WITH clause from each side so all CTEs can be re-attached
    # to the single outer query built below.
    lhs_ctes = lhs_query.args.pop("with", [])
    rhs_ctes = rhs_query.args.pop("with", [])
    all_ctes = [*lhs_ctes, *rhs_ctes]

    predicates = [
        _join_condition(lhs_key, rhs_key, joins_nulls)
        for lhs_key, rhs_key in conditions
    ]
    on_clause = sge.And(expressions=predicates) if predicates else None

    sql_join_type = "full outer" if join_type == "outer" else join_type
    joined_query = sge.Select().select(sge.Star()).from_(lhs_table)
    joined_query = joined_query.join(
        rhs_table, on=on_clause, join_type=sql_join_type
    )
    joined_query.set("with", sge.With(expressions=all_ctes))

    return SQLGlotIR(expr=joined_query, uid_gen=self.uid_gen)
297+
263298
def insert(
264299
self,
265300
destination: bigquery.TableReference,
@@ -320,12 +355,12 @@ def _explode_single_column(
320355
offset=offset,
321356
)
322357
selection = sge.Star(replace=[unnested_column_alias.as_(column)])
358+
323359
# TODO: "CROSS" if not keep_empty else "LEFT"
324360
# TODO: overlaps_with_parent to replace existing column.
325-
new_expr = (
326-
self._encapsulate_as_cte()
327-
.select(selection, append=False)
328-
.join(unnest_expr, join_type="CROSS")
361+
new_expr, _ = self._encapsulate_as_cte()
362+
new_expr = new_expr.select(selection, append=False).join(
363+
unnest_expr, join_type="CROSS"
329364
)
330365
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
331366

@@ -373,16 +408,15 @@ def _explode_multiple_columns(
373408
for column in columns
374409
]
375410
)
376-
new_expr = (
377-
self._encapsulate_as_cte()
378-
.select(selection, append=False)
379-
.join(unnest_expr, join_type="CROSS")
411+
new_expr, _ = self._encapsulate_as_cte()
412+
new_expr = new_expr.select(selection, append=False).join(
413+
unnest_expr, join_type="CROSS"
380414
)
381415
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
382416

383417
def _encapsulate_as_cte(
384418
self,
385-
) -> sge.Select:
419+
) -> typing.Tuple[sge.Select, sge.Table]:
386420
"""Transforms a given sge.Select query by pushing its main SELECT statement
387421
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
388422
for the new query."""
@@ -397,11 +431,10 @@ def _encapsulate_as_cte(
397431
alias=new_cte_name,
398432
)
399433
new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
400-
new_select_expr = (
401-
sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name))
402-
)
434+
new_table_expr = sge.Table(this=new_cte_name)
435+
new_select_expr = sge.Select().select(sge.Star()).from_(new_table_expr)
403436
new_select_expr.set("with", new_with_clause)
404-
return new_select_expr
437+
return new_select_expr, new_table_expr
405438

406439

407440
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
@@ -451,3 +484,11 @@ def _table(table: bigquery.TableReference) -> sge.Table:
451484
db=sg.to_identifier(table.dataset_id, quoted=True),
452485
catalog=sg.to_identifier(table.project, quoted=True),
453486
)
487+
488+
489+
def _join_condition(
    left: typed_expr.TypedExpr,
    right: typed_expr.TypedExpr,
    joins_nulls: bool,
) -> typing.Union[sge.EQ, sge.And]:
    """Build the join predicate for a single pair of join keys.

    Args:
        left: Key expression from the left side of the join.
        right: Key expression from the right side of the join.
        joins_nulls: Accepted for interface compatibility; it is not
            consulted by the current implementation.

    Returns:
        An equality comparison between the two key expressions.
    """
    # NOTE(review): `joins_nulls` is unused, so the predicate is a plain
    # equality regardless of its value. Presumably NULL keys should match
    # each other when it is True -- confirm and implement if so.
    lhs_sql, rhs_sql = left.expr, right.expr
    return sge.EQ(this=lhs_sql, expression=rhs_sql)

bigframes/dataframe.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3369,8 +3369,6 @@ def merge(
33693369
"right",
33703370
"cross",
33713371
] = "inner",
3372-
# TODO(garrettwu): Currently can take inner, outer, left and right. To support
3373-
# cross joins
33743372
on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
33753373
*,
33763374
left_on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`int64_col` AS `bfcol_4`,
14+
`int64_too` AS `bfcol_5`
15+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
16+
), `bfcte_3` AS (
17+
SELECT
18+
`bfcol_4` AS `bfcol_6`,
19+
`bfcol_5` AS `bfcol_7`
20+
FROM `bfcte_0`
21+
), `bfcte_4` AS (
22+
SELECT
23+
*
24+
FROM `bfcte_2`
25+
LEFT JOIN `bfcte_3`
26+
ON `bfcol_2` = `bfcol_6`
27+
)
28+
SELECT
29+
`bfcol_3` AS `int64_col`,
30+
`bfcol_7` AS `int64_too`
31+
FROM `bfcte_4`
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
import bigframes.pandas as bpd
18+
19+
pytest.importorskip("pytest_snapshot")
20+
21+
22+
def test_compile_join(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for a default ``DataFrame.join``."""
    lhs = scalar_types_df[["int64_col"]]
    # The right side is indexed by the join key so `join` matches on it.
    rhs = scalar_types_df.set_index("int64_col")[["int64_too"]]

    compiled_sql = lhs.join(rhs).sql

    snapshot.assert_match(compiled_sql, "out.sql")
27+
28+
29+
def test_compile_join_w_how(scalar_types_df: bpd.DataFrame):
    """Check the JOIN keyword (and ON clause presence) for each ``how``."""
    lhs = scalar_types_df[["int64_col"]]
    rhs = scalar_types_df.set_index("int64_col")[["int64_too"]]

    # Keyed joins: each `how` value must yield its SQL keyword plus an ON
    # clause. Order matches the original sequence of checks.
    keyed_cases = (
        ("left", "LEFT JOIN"),
        ("right", "RIGHT JOIN"),
        ("outer", "FULL OUTER JOIN"),
        ("inner", "INNER JOIN"),
    )
    for how, join_keyword in keyed_cases:
        compiled_sql = lhs.join(rhs, how=how).sql
        assert join_keyword in compiled_sql
        assert "ON" in compiled_sql

    # Cross joins go through `merge` and must not carry an ON clause.
    cross_sql = lhs.merge(rhs, how="cross").sql
    assert "CROSS JOIN" in cross_sql
    assert "ON" not in cross_sql

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4748,7 +4748,7 @@ def merge(
47484748
right:
47494749
Object to merge with.
47504750
how:
4751-
``{'left', 'right', 'outer', 'inner'}, default 'inner'``
4751+
``{'left', 'right', 'outer', 'inner', 'cross'}, default 'inner'``
47524752
Type of merge to be performed.
47534753
``left``: use only keys from left frame, similar to a SQL left outer join;
47544754
preserve key order.

0 commit comments

Comments
 (0)