Skip to content

Commit 59e4d0e

Browse files
committed
complete most implementations
1 parent 06a0a65 commit 59e4d0e

File tree

5 files changed

+100
-61
lines changed

5 files changed

+100
-61
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,15 +248,15 @@ def compile_join(
248248
def compile_isin_join(
249249
self, node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
250250
) -> ir.SQLGlotIR:
251-
conditions = tuple(
251+
conditions = (
252252
typed_expr.TypedExpr(
253253
scalar_compiler.compile_scalar_expression(node.left_col),
254-
left.output_type,
254+
node.left_col.output_type,
255255
),
256256
typed_expr.TypedExpr(
257257
scalar_compiler.compile_scalar_expression(node.right_col),
258-
right.output_type,
259-
),
258+
node.right_col.output_type,
259+
)
260260
)
261261

262262
return left.isin_join(

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -341,50 +341,61 @@ def isin_join(
341341
right: SQLGlotIR,
342342
indicator_col: str,
343343
conditions: tuple[typed_expr.TypedExpr, typed_expr.TypedExpr],
344-
*,
345344
joins_nulls: bool = True,
346345
) -> SQLGlotIR:
347346
"""Joins the current query with another SQLGlotIR instance."""
348-
# TODO: Optimization similar to Ibis:
349-
# if isinstance(values, ArrayValue):
350-
# return ops.ArrayContains(values, self).to_expr()
351-
# elif isinstance(values, Column):
352-
# return ops.InSubquery(values.as_table(), needle=self).to_expr()
353-
# else:
354-
# return ops.InValues(self, values).to_expr()
355-
356-
raise NotImplementedError
357-
# left_cte_name = sge.to_identifier(
358-
# next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
359-
# )
360-
# right_cte_name = sge.to_identifier(
361-
# next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
362-
# )
363-
364-
# left_select = _select_to_cte(self.expr, left_cte_name)
365-
# right_select = _select_to_cte(right.expr, right_cte_name)
366-
367-
# left_ctes = left_select.args.pop("with", [])
368-
# right_ctes = right_select.args.pop("with", [])
369-
# merged_ctes = [*left_ctes, *right_ctes]
370-
371-
372-
373-
# join_conditions = [
374-
# _join_condition(left, right, joins_nulls) for left, right in conditions
375-
# ]
376-
# join_on = sge.And(expressions=join_conditions) if join_conditions else None
377-
378-
# join_type_str = join_type if join_type != "outer" else "full outer"
379-
# new_expr = (
380-
# sge.Select()
381-
# .select(sge.Star())
382-
# .from_(sge.Table(this=left_cte_name))
383-
# .join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str)
384-
# )
385-
# new_expr.set("with", sge.With(expressions=merged_ctes))
386-
387-
# return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
347+
left_cte_name = sge.to_identifier(
348+
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
349+
)
350+
right_cte_name = sge.to_identifier(
351+
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
352+
)
353+
354+
left_select = _select_to_cte(self.expr, left_cte_name)
355+
right_select = _select_to_cte(right.expr, right_cte_name)
356+
357+
left_ctes = left_select.args.pop("with", [])
358+
right_ctes = right_select.args.pop("with", [])
359+
merged_ctes = [*left_ctes, *right_ctes]
360+
361+
left_condition = typed_expr.TypedExpr(
362+
sge.Column(this=conditions[0].expr, table=left_cte_name),
363+
conditions[0].dtype,
364+
)
365+
right_condition = typed_expr.TypedExpr(
366+
sge.Column(this=conditions[1].expr, table=right_cte_name),
367+
conditions[1].dtype,
368+
)
369+
370+
new_column: sge.Expression
371+
if joins_nulls:
372+
new_column = sge.Exists(
373+
this=sge.Select()
374+
.select(sge.convert(1))
375+
.from_(sge.Table(this=right_cte_name))
376+
.where(
377+
_join_condition(left_condition, right_condition, joins_nulls=True)
378+
)
379+
)
380+
else:
381+
new_column = sge.In(
382+
this=left_condition.expr,
383+
expressions=[right_condition.expr],
384+
)
385+
386+
new_column = sge.Alias(
387+
this=new_column,
388+
alias=sge.to_identifier(indicator_col, quoted=self.quoted),
389+
)
390+
391+
new_expr = (
392+
sge.Select()
393+
.select(sge.Column(this=sge.Star(), table=left_cte_name), new_column)
394+
.from_(sge.Table(this=left_cte_name))
395+
)
396+
new_expr.set("with", sge.With(expressions=merged_ctes))
397+
398+
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
388399

389400
def explode(
390401
self,

tests/unit/core/compile/sqlglot/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def scalar_types_table_schema() -> typing.Sequence[bigquery.SchemaField]:
8585
bigquery.SchemaField("numeric_col", "NUMERIC"),
8686
bigquery.SchemaField("float64_col", "FLOAT"),
8787
bigquery.SchemaField("rowindex", "INTEGER"),
88-
bigquery.SchemaField("rowindex_2", "INTEGER"),
88+
bigquery.SchemaField("rowindex_2", "INTEGER", mode="REQUIRED"),
8989
bigquery.SchemaField("string_col", "STRING"),
9090
bigquery.SchemaField("time_col", "TIME"),
9191
bigquery.SchemaField("timestamp_col", "TIMESTAMP"),
Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,37 @@
1-
WITH `bfcte_0` AS (
1+
WITH `bfcte_1` AS (
22
SELECT
3-
*
4-
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` FLOAT64, `bfcol_1` INT64>>[STRUCT(314159.0, 0), STRUCT(2.0, 1), STRUCT(3.0, 2), STRUCT(CAST(NULL AS FLOAT64), 3)])
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`int64_too` AS `bfcol_4`
14+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
15+
), `bfcte_3` AS (
16+
SELECT
17+
`bfcol_4`
18+
FROM `bfcte_0`
19+
GROUP BY
20+
`bfcol_4`
21+
), `bfcte_4` AS (
22+
SELECT
23+
`bfcte_2`.*,
24+
EXISTS(
25+
SELECT
26+
1
27+
FROM `bfcte_3`
28+
WHERE
29+
COALESCE(`bfcte_2`.`bfcol_3`, 0) = COALESCE(`bfcte_3`.`bfcol_4`, 0)
30+
AND COALESCE(`bfcte_2`.`bfcol_3`, 1) = COALESCE(`bfcte_3`.`bfcol_4`, 1)
31+
) AS `bfcol_5`
32+
FROM `bfcte_2`
533
)
634
SELECT
7-
`bfcol_0` AS `0`
8-
FROM `bfcte_0`
9-
ORDER BY
10-
`bfcol_1` ASC NULLS LAST
35+
`bfcol_2` AS `rowindex`,
36+
`bfcol_5` AS `int64_col`
37+
FROM `bfcte_4`

tests/unit/core/compile/sqlglot/test_compile_isin.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import pandas as pd
1615
import pytest
1716

18-
import bigframes
1917
import bigframes.pandas as bpd
2018

2119
pytest.importorskip("pytest_snapshot")
2220

2321

24-
def test_compile_isin(
25-
scalar_types_df: bpd.DataFrame, compiler_session: bigframes.Session, snapshot
26-
):
27-
data = [314159, 2.0, 3, pd.NA]
28-
s = bpd.Series(data, session=compiler_session)
29-
bf_isin = scalar_types_df["int64_col"].isin(s).to_frame()
22+
def test_compile_isin(scalar_types_df: bpd.DataFrame, snapshot):
23+
bf_isin = scalar_types_df["int64_col"].isin(scalar_types_df["int64_too"]).to_frame()
24+
snapshot.assert_match(bf_isin.sql, "out.sql")
25+
26+
27+
def test_compile_isin_not_nullable(scalar_types_df: bpd.DataFrame, snapshot):
28+
bf_isin = (
29+
scalar_types_df["rowindex_2"].isin(scalar_types_df["rowindex_2"]).to_frame()
30+
)
3031
snapshot.assert_match(bf_isin.sql, "out.sql")

0 commit comments

Comments
 (0)