Skip to content

Commit 3e6dfe7

Browse files
authored
refactor: add _join_condition for all types (#1880)
Fixes internal issue 427501553
1 parent f30f750 commit 3e6dfe7

File tree

11 files changed

+290
-5
lines changed

11 files changed

+290
-5
lines changed

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def compile_scalar_expression(
3939

4040
@compile_scalar_expression.register
4141
def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression:
42-
return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True))
42+
return sge.Column(this=sge.to_identifier(expr.id.sql, quoted=True))
4343

4444

4545
@compile_scalar_expression.register

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,4 +491,89 @@ def _join_condition(
491491
right: typed_expr.TypedExpr,
492492
joins_nulls: bool,
493493
) -> typing.Union[sge.EQ, sge.And]:
494-
return sge.EQ(this=left.expr, expression=right.expr)
494+
"""Generates a join condition to match pandas's null-handling logic.
495+
496+
Pandas treats null values as distinct from each other, leading to a
497+
cross-join-like behavior for null keys. In contrast, BigQuery SQL treats
498+
null values as equal, leading to a inner-join-like behavior.
499+
500+
This function generates the appropriate SQL condition to replicate the
501+
desired pandas behavior in BigQuery.
502+
503+
Args:
504+
left: The left-side join key.
505+
right: The right-side join key.
506+
joins_nulls: If True, generates complex logic to handle nulls/NaNs.
507+
Otherwise, uses a simple equality check where appropriate.
508+
"""
509+
is_floating_types = (
510+
left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE
511+
)
512+
if not is_floating_types and not joins_nulls:
513+
return sge.EQ(this=left.expr, expression=right.expr)
514+
515+
is_numeric_types = dtypes.is_numeric(
516+
left.dtype, include_bool=False
517+
) and dtypes.is_numeric(right.dtype, include_bool=False)
518+
if is_numeric_types:
519+
return _join_condition_for_numeric(left, right)
520+
else:
521+
return _join_condition_for_others(left, right)
522+
523+
524+
def _join_condition_for_others(
525+
left: typed_expr.TypedExpr,
526+
right: typed_expr.TypedExpr,
527+
) -> sge.And:
528+
"""Generates a join condition for non-numeric types to match pandas's
529+
null-handling logic.
530+
"""
531+
left_str = _cast(left.expr, "STRING")
532+
right_str = _cast(right.expr, "STRING")
533+
left_0 = sge.func("COALESCE", left_str, _literal("0", dtypes.STRING_DTYPE))
534+
left_1 = sge.func("COALESCE", left_str, _literal("1", dtypes.STRING_DTYPE))
535+
right_0 = sge.func("COALESCE", right_str, _literal("0", dtypes.STRING_DTYPE))
536+
right_1 = sge.func("COALESCE", right_str, _literal("1", dtypes.STRING_DTYPE))
537+
return sge.And(
538+
this=sge.EQ(this=left_0, expression=right_0),
539+
expression=sge.EQ(this=left_1, expression=right_1),
540+
)
541+
542+
543+
def _join_condition_for_numeric(
544+
left: typed_expr.TypedExpr,
545+
right: typed_expr.TypedExpr,
546+
) -> sge.And:
547+
"""Generates a join condition for non-numeric types to match pandas's
548+
null-handling logic. Specifically for FLOAT types, Pandas treats NaN aren't
549+
equal so need to coalesce as well with different constants.
550+
"""
551+
is_floating_types = (
552+
left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE
553+
)
554+
left_0 = sge.func("COALESCE", left.expr, _literal(0, left.dtype))
555+
left_1 = sge.func("COALESCE", left.expr, _literal(1, left.dtype))
556+
right_0 = sge.func("COALESCE", right.expr, _literal(0, right.dtype))
557+
right_1 = sge.func("COALESCE", right.expr, _literal(1, right.dtype))
558+
if not is_floating_types:
559+
return sge.And(
560+
this=sge.EQ(this=left_0, expression=right_0),
561+
expression=sge.EQ(this=left_1, expression=right_1),
562+
)
563+
564+
left_2 = sge.If(
565+
this=sge.IsNan(this=left.expr), true=_literal(2, left.dtype), false=left_0
566+
)
567+
left_3 = sge.If(
568+
this=sge.IsNan(this=left.expr), true=_literal(3, left.dtype), false=left_1
569+
)
570+
right_2 = sge.If(
571+
this=sge.IsNan(this=right.expr), true=_literal(2, right.dtype), false=right_0
572+
)
573+
right_3 = sge.If(
574+
this=sge.IsNan(this=right.expr), true=_literal(3, right.dtype), false=right_1
575+
)
576+
return sge.And(
577+
this=sge.EQ(this=left_2, expression=right_2),
578+
expression=sge.EQ(this=left_3, expression=right_3),
579+
)

bigframes/dtypes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,9 @@ def is_json_encoding_type(type_: ExpressionType) -> bool:
341341
return type_ != GEO_DTYPE
342342

343343

344-
def is_numeric(type_: ExpressionType) -> bool:
345-
return type_ in NUMERIC_BIGFRAMES_TYPES_PERMISSIVE
344+
def is_numeric(type_: ExpressionType, include_bool: bool = True) -> bool:
345+
is_numeric = type_ in NUMERIC_BIGFRAMES_TYPES_PERMISSIVE
346+
return is_numeric if include_bool else is_numeric and type_ != BOOL_DTYPE
346347

347348

348349
def is_iterable(type_: ExpressionType) -> bool:

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ WITH `bfcte_1` AS (
2323
*
2424
FROM `bfcte_2`
2525
LEFT JOIN `bfcte_3`
26-
ON `bfcol_2` = `bfcol_6`
26+
ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0)
27+
AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1)
2728
)
2829
SELECT
2930
`bfcol_3` AS `int64_col`,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`bool_col` AS `bfcol_4`,
14+
`rowindex` AS `bfcol_5`
15+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
16+
), `bfcte_3` AS (
17+
SELECT
18+
`bfcol_5` AS `bfcol_6`,
19+
`bfcol_4` AS `bfcol_7`
20+
FROM `bfcte_0`
21+
), `bfcte_4` AS (
22+
SELECT
23+
*
24+
FROM `bfcte_2`
25+
INNER JOIN `bfcte_3`
26+
ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0')
27+
AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1')
28+
)
29+
SELECT
30+
`bfcol_2` AS `rowindex_x`,
31+
`bfcol_3` AS `bool_col`,
32+
`bfcol_6` AS `rowindex_y`
33+
FROM `bfcte_4`
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`float64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`float64_col` AS `bfcol_4`,
14+
`rowindex` AS `bfcol_5`
15+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
16+
), `bfcte_3` AS (
17+
SELECT
18+
`bfcol_5` AS `bfcol_6`,
19+
`bfcol_4` AS `bfcol_7`
20+
FROM `bfcte_0`
21+
), `bfcte_4` AS (
22+
SELECT
23+
*
24+
FROM `bfcte_2`
25+
INNER JOIN `bfcte_3`
26+
ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0))
27+
AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1))
28+
)
29+
SELECT
30+
`bfcol_2` AS `rowindex_x`,
31+
`bfcol_3` AS `float64_col`,
32+
`bfcol_6` AS `rowindex_y`
33+
FROM `bfcte_4`
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`int64_col` AS `bfcol_4`,
14+
`rowindex` AS `bfcol_5`
15+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
16+
), `bfcte_3` AS (
17+
SELECT
18+
`bfcol_5` AS `bfcol_6`,
19+
`bfcol_4` AS `bfcol_7`
20+
FROM `bfcte_0`
21+
), `bfcte_4` AS (
22+
SELECT
23+
*
24+
FROM `bfcte_2`
25+
INNER JOIN `bfcte_3`
26+
ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0)
27+
AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1)
28+
)
29+
SELECT
30+
`bfcol_2` AS `rowindex_x`,
31+
`bfcol_3` AS `int64_col`,
32+
`bfcol_6` AS `rowindex_y`
33+
FROM `bfcte_4`
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`numeric_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_2` AS (
7+
SELECT
8+
`bfcol_1` AS `bfcol_2`,
9+
`bfcol_0` AS `bfcol_3`
10+
FROM `bfcte_1`
11+
), `bfcte_0` AS (
12+
SELECT
13+
`numeric_col` AS `bfcol_4`,
14+
`rowindex` AS `bfcol_5`
15+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
16+
), `bfcte_3` AS (
17+
SELECT
18+
`bfcol_5` AS `bfcol_6`,
19+
`bfcol_4` AS `bfcol_7`
20+
FROM `bfcte_0`
21+
), `bfcte_4` AS (
22+
SELECT
23+
*
24+
FROM `bfcte_2`
25+
INNER JOIN `bfcte_3`
26+
ON COALESCE(`bfcol_3`, CAST(0 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(0 AS NUMERIC))
27+
AND COALESCE(`bfcol_3`, CAST(1 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(1 AS NUMERIC))
28+
)
29+
SELECT
30+
`bfcol_2` AS `rowindex_x`,
31+
`bfcol_3` AS `numeric_col`,
32+
`bfcol_6` AS `rowindex_y`
33+
FROM `bfcte_4`
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`rowindex` AS `bfcol_0`,
4+
`string_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_0` AS (
7+
SELECT
8+
`rowindex` AS `bfcol_2`,
9+
`string_col` AS `bfcol_3`
10+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
11+
), `bfcte_2` AS (
12+
SELECT
13+
`bfcol_2` AS `bfcol_4`,
14+
`bfcol_3` AS `bfcol_5`
15+
FROM `bfcte_0`
16+
), `bfcte_3` AS (
17+
SELECT
18+
*
19+
FROM `bfcte_1`
20+
INNER JOIN `bfcte_2`
21+
ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0')
22+
AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1')
23+
)
24+
SELECT
25+
`bfcol_0` AS `rowindex_x`,
26+
`bfcol_1` AS `string_col`,
27+
`bfcol_4` AS `rowindex_y`
28+
FROM `bfcte_3`
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
WITH `bfcte_1` AS (
2+
SELECT
3+
`rowindex` AS `bfcol_0`,
4+
`time_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_0` AS (
7+
SELECT
8+
`rowindex` AS `bfcol_2`,
9+
`time_col` AS `bfcol_3`
10+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
11+
), `bfcte_2` AS (
12+
SELECT
13+
`bfcol_2` AS `bfcol_4`,
14+
`bfcol_3` AS `bfcol_5`
15+
FROM `bfcte_0`
16+
), `bfcte_3` AS (
17+
SELECT
18+
*
19+
FROM `bfcte_1`
20+
INNER JOIN `bfcte_2`
21+
ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0')
22+
AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1')
23+
)
24+
SELECT
25+
`bfcol_0` AS `rowindex_x`,
26+
`bfcol_1` AS `time_col`,
27+
`bfcol_4` AS `rowindex_y`
28+
FROM `bfcte_3`

0 commit comments

Comments
 (0)