Skip to content

Commit 72076c7

Browse files
authored
chore: compile concat nodes by sqlglot (#1824)
* chore: compile concat node * chore: compile concat nodes by sqlglot
1 parent aa32369 commit 72076c7

File tree

5 files changed

+203
-1
lines changed

5 files changed

+203
-1
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,17 @@ def compile_projection(
190190
)
191191
return child.project(projected_cols)
192192

193+
@_compile_node.register
def compile_concat(
    self, node: nodes.ConcatNode, *children: ir.SQLGlotIR
) -> ir.SQLGlotIR:
    """Compiles a ConcatNode by unioning the already-compiled child IRs."""
    # Child expressions are unioned positionally; the node supplies the
    # output column identifiers for the combined result.
    child_exprs = [compiled.expr for compiled in children]
    return ir.SQLGlotIR.from_union(
        child_exprs,
        output_ids=[out_id.sql for out_id in node.output_ids],
        uid_gen=self.uid_gen,
    )
203+
193204

194205
def _replace_unsupported_ops(node: nodes.BigFrameNode):
195206
node = nodes.bottom_up(node, rewrite.rewrite_slice)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,57 @@ def from_query_string(
149149
select_expr.set("with", sge.With(expressions=[cte]))
150150
return cls(expr=select_expr, uid_gen=uid_gen)
151151

152+
@classmethod
def from_union(
    cls,
    selects: typing.Sequence[sge.Select],
    output_ids: typing.Sequence[str],
    uid_gen: guid.SequentialUIDGenerator,
) -> SQLGlotIR:
    """Builds a SQLGlot expression representing the UNION ALL of multiple
    select expressions.

    Args:
        selects: Select expressions to union; at least two are required.
        output_ids: Output column names, applied positionally to each
            select's projection list so all union branches align.
        uid_gen: Generator used to mint unique CTE names.

    Returns:
        A new SQLGlotIR selecting ``*`` from the union of the inputs, with
        all input CTEs hoisted onto the final statement.
    """
    # `selects` is a Sequence, so len() applies directly — no copy needed.
    assert (
        len(selects) >= 2
    ), f"At least two select expressions must be provided, but got {selects}."

    existing_ctes: list[sge.CTE] = []
    union_selects: list[sge.Select] = []
    for select in selects:
        assert isinstance(
            select, sge.Select
        ), f"All provided expressions must be of type sge.Select, but got {type(select)}"

        # Hoist each input's WITH clause so every CTE ends up attached to
        # the final statement, then wrap the remaining select body in a
        # fresh, uniquely-named CTE of its own.
        select_expr = select.copy()
        existing_ctes.extend(select_expr.args.pop("with", []))

        new_cte_name = sge.to_identifier(
            next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
        )
        existing_ctes.append(sge.CTE(this=select_expr, alias=new_cte_name))

        # Re-alias each branch's columns positionally to the shared output
        # ids so the union branches are column-compatible.
        selections = [
            sge.Alias(
                this=expr.alias_or_name,
                alias=sge.to_identifier(output_id, quoted=cls.quoted),
            )
            for expr, output_id in zip(select_expr.expressions, output_ids)
        ]
        union_selects.append(
            sge.Select().select(*selections).from_(sge.Table(this=new_cte_name))
        )

    # UNION ALL (distinct=False) preserves duplicate rows, matching concat
    # semantics; copy=False avoids re-copying the freshly built selects.
    union_expr = sg.union(
        *union_selects,
        distinct=False,
        copy=False,
    )
    final_select_expr = sge.Select().select(sge.Star()).from_(union_expr.subquery())
    final_select_expr.set("with", sge.With(expressions=existing_ctes))
    return cls(expr=final_select_expr, uid_gen=uid_gen)
202+
152203
def select(
153204
self,
154205
selected_cols: tuple[tuple[str, sge.Expression], ...],
@@ -181,7 +232,7 @@ def project(
181232
)
182233
for id, expr in projected_cols
183234
]
184-
new_expr = self._encapsulate_as_cte().select(*projected_cols_expr, append=False)
235+
new_expr = self._encapsulate_as_cte().select(*projected_cols_expr, append=True)
185236
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
186237

187238
def insert(
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
*
4+
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` INT64, `bfcol_1` INT64, `bfcol_2` INT64, `bfcol_3` STRING, `bfcol_4` INT64>>[STRUCT(0, 123456789, 0, 'Hello, World!', 0), STRUCT(1, -987654321, 1, 'こんにちは', 1), STRUCT(2, 314159, 2, ' ¡Hola Mundo! ', 2), STRUCT(3, CAST(NULL AS INT64), 3, CAST(NULL AS STRING), 3), STRUCT(4, -234892, 4, 'Hello, World!', 4), STRUCT(5, 55555, 5, 'Güten Tag!', 5), STRUCT(6, 101202303, 6, 'capitalize, This ', 6), STRUCT(7, -214748367, 7, ' سلام', 7), STRUCT(8, 2, 8, 'T', 8)])
5+
), `bfcte_3` AS (
6+
SELECT
7+
`bfcol_0` AS `bfcol_5`,
8+
`bfcol_2` AS `bfcol_6`,
9+
`bfcol_1` AS `bfcol_7`,
10+
`bfcol_3` AS `bfcol_8`,
11+
`bfcol_4` AS `bfcol_9`
12+
FROM `bfcte_1`
13+
), `bfcte_5` AS (
14+
SELECT
15+
*,
16+
`bfcol_9` AS `bfcol_10`
17+
FROM `bfcte_3`
18+
), `bfcte_7` AS (
19+
SELECT
20+
`bfcol_5` AS `bfcol_11`,
21+
`bfcol_6` AS `bfcol_12`,
22+
`bfcol_7` AS `bfcol_13`,
23+
`bfcol_8` AS `bfcol_14`,
24+
`bfcol_10` AS `bfcol_15`
25+
FROM `bfcte_5`
26+
), `bfcte_9` AS (
27+
SELECT
28+
*,
29+
0 AS `bfcol_16`
30+
FROM `bfcte_7`
31+
), `bfcte_10` AS (
32+
SELECT
33+
`bfcol_11` AS `bfcol_17`,
34+
`bfcol_12` AS `bfcol_18`,
35+
`bfcol_13` AS `bfcol_19`,
36+
`bfcol_14` AS `bfcol_20`,
37+
`bfcol_16` AS `bfcol_21`,
38+
`bfcol_15` AS `bfcol_22`
39+
FROM `bfcte_9`
40+
), `bfcte_0` AS (
41+
SELECT
42+
*
43+
FROM UNNEST(ARRAY<STRUCT<`bfcol_23` INT64, `bfcol_24` INT64, `bfcol_25` INT64, `bfcol_26` STRING, `bfcol_27` INT64>>[STRUCT(0, 123456789, 0, 'Hello, World!', 0), STRUCT(1, -987654321, 1, 'こんにちは', 1), STRUCT(2, 314159, 2, ' ¡Hola Mundo! ', 2), STRUCT(3, CAST(NULL AS INT64), 3, CAST(NULL AS STRING), 3), STRUCT(4, -234892, 4, 'Hello, World!', 4), STRUCT(5, 55555, 5, 'Güten Tag!', 5), STRUCT(6, 101202303, 6, 'capitalize, This ', 6), STRUCT(7, -214748367, 7, ' سلام', 7), STRUCT(8, 2, 8, 'T', 8)])
44+
), `bfcte_2` AS (
45+
SELECT
46+
`bfcol_23` AS `bfcol_28`,
47+
`bfcol_25` AS `bfcol_29`,
48+
`bfcol_24` AS `bfcol_30`,
49+
`bfcol_26` AS `bfcol_31`,
50+
`bfcol_27` AS `bfcol_32`
51+
FROM `bfcte_0`
52+
), `bfcte_4` AS (
53+
SELECT
54+
*,
55+
`bfcol_32` AS `bfcol_33`
56+
FROM `bfcte_2`
57+
), `bfcte_6` AS (
58+
SELECT
59+
`bfcol_28` AS `bfcol_34`,
60+
`bfcol_29` AS `bfcol_35`,
61+
`bfcol_30` AS `bfcol_36`,
62+
`bfcol_31` AS `bfcol_37`,
63+
`bfcol_33` AS `bfcol_38`
64+
FROM `bfcte_4`
65+
), `bfcte_8` AS (
66+
SELECT
67+
*,
68+
1 AS `bfcol_39`
69+
FROM `bfcte_6`
70+
), `bfcte_11` AS (
71+
SELECT
72+
`bfcol_34` AS `bfcol_40`,
73+
`bfcol_35` AS `bfcol_41`,
74+
`bfcol_36` AS `bfcol_42`,
75+
`bfcol_37` AS `bfcol_43`,
76+
`bfcol_39` AS `bfcol_44`,
77+
`bfcol_38` AS `bfcol_45`
78+
FROM `bfcte_8`
79+
), `bfcte_12` AS (
80+
SELECT
81+
*
82+
FROM (
83+
SELECT
84+
bfcol_17 AS `bfcol_46`,
85+
bfcol_18 AS `bfcol_47`,
86+
bfcol_19 AS `bfcol_48`,
87+
bfcol_20 AS `bfcol_49`,
88+
bfcol_21 AS `bfcol_50`,
89+
bfcol_22 AS `bfcol_51`
90+
FROM `bfcte_10`
91+
UNION ALL
92+
SELECT
93+
bfcol_40 AS `bfcol_46`,
94+
bfcol_41 AS `bfcol_47`,
95+
bfcol_42 AS `bfcol_48`,
96+
bfcol_43 AS `bfcol_49`,
97+
bfcol_44 AS `bfcol_50`,
98+
bfcol_45 AS `bfcol_51`
99+
FROM `bfcte_11`
100+
)
101+
)
102+
SELECT
103+
`bfcol_46` AS `rowindex`,
104+
`bfcol_47` AS `rowindex_1`,
105+
`bfcol_48` AS `int64_col`,
106+
`bfcol_49` AS `string_col`
107+
FROM `bfcte_12`

tests/unit/core/compile/sqlglot/snapshots/test_compile_projection/test_compile_projection/out.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ WITH `bfcte_0` AS (
88
FROM `test-project`.`test_dataset`.`test_table`
99
), `bfcte_1` AS (
1010
SELECT
11+
*,
1112
`bfcol_0` AS `bfcol_5`,
1213
`bfcol_2` AS `bfcol_6`,
1314
`bfcol_3` AS `bfcol_7`,
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pandas as pd
16+
import pytest
17+
18+
import bigframes
19+
import bigframes.pandas as bpd
20+
21+
pytest.importorskip("pytest_snapshot")
22+
23+
24+
def test_compile_concat(
    scalars_types_pandas_df: pd.DataFrame, compiler_session: bigframes.Session, snapshot
):
    """Snapshot-tests that concatenating a DataFrame with itself compiles to SQL."""
    # TODO: concat two copies of the same dataframe, whose SQL does not get reused.
    # TODO: concat dataframes from a gbq table to trigger the window compiler.
    df1 = bpd.DataFrame(scalars_types_pandas_df, session=compiler_session)
    df1 = df1[["rowindex", "int64_col", "string_col"]]
    concat_df = bpd.concat([df1, df1])
    # Compare the compiled SQL against the stored golden file (out.sql).
    snapshot.assert_match(concat_df.sql, "out.sql")

0 commit comments

Comments
 (0)