Skip to content

Commit e1d54d2

Browse files
authored
refactor: fix some aggregation ops in the sqlglot compiler (#2382)
This change fixes several aggregation-related test failures in #2248 by including the following operators:
- `agg_ops.DiffOp`: Fixes test_date_series_diff_agg
- `agg_ops.AllOp/AnyOp`: Fixes test_list_apply_callable
- `agg_ops.QuantileOp`: Fixes test_dataframe_aggregates_median
- `agg_ops.ProductOp`: Fixes test_dataframe_groupby_analytic

Fixes internal issue 417774347 🦕
1 parent 173efd9 commit e1d54d2

File tree

12 files changed

+122
-74
lines changed

12 files changed

+122
-74
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from bigframes.core import window_spec
2424
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
2525
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
26+
from bigframes.core.compile.sqlglot.expressions import constants
2627
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2728
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2829
from bigframes.operations import aggregations as agg_ops
@@ -44,9 +45,13 @@ def _(
4445
column: typed_expr.TypedExpr,
4546
window: typing.Optional[window_spec.WindowSpec] = None,
4647
) -> sge.Expression:
47-
# BQ will return null for empty column, result would be false in pandas.
48-
result = apply_window_if_present(sge.func("LOGICAL_AND", column.expr), window)
49-
return sge.func("IFNULL", result, sge.true())
48+
expr = column.expr
49+
if column.dtype != dtypes.BOOL_DTYPE:
50+
expr = sge.NEQ(this=expr, expression=sge.convert(0))
51+
expr = apply_window_if_present(sge.func("LOGICAL_AND", expr), window)
52+
53+
# BQ will return null for empty column, result would be true in pandas.
54+
return sge.func("COALESCE", expr, sge.convert(True))
5055

5156

5257
@UNARY_OP_REGISTRATION.register(agg_ops.AnyOp)
@@ -56,6 +61,8 @@ def _(
5661
window: typing.Optional[window_spec.WindowSpec] = None,
5762
) -> sge.Expression:
5863
expr = column.expr
64+
if column.dtype != dtypes.BOOL_DTYPE:
65+
expr = sge.NEQ(this=expr, expression=sge.convert(0))
5966
expr = apply_window_if_present(sge.func("LOGICAL_OR", expr), window)
6067

6168
# BQ will return null for empty column, result would be false in pandas.
@@ -326,6 +333,15 @@ def _(
326333
unit=sge.Identifier(this="MICROSECOND"),
327334
)
328335

336+
if column.dtype == dtypes.DATE_DTYPE:
337+
date_diff = sge.DateDiff(
338+
this=column.expr, expression=shifted, unit=sge.Identifier(this="DAY")
339+
)
340+
return sge.Cast(
341+
this=sge.Floor(this=date_diff * constants._DAY_TO_MICROSECONDS),
342+
to="INT64",
343+
)
344+
329345
raise TypeError(f"Cannot perform diff on type {column.dtype}")
330346

331347

@@ -410,24 +426,28 @@ def _(
410426
column: typed_expr.TypedExpr,
411427
window: typing.Optional[window_spec.WindowSpec] = None,
412428
) -> sge.Expression:
429+
expr = column.expr
430+
if column.dtype == dtypes.BOOL_DTYPE:
431+
expr = sge.Cast(this=expr, to="INT64")
432+
413433
# Need to short-circuit as log with zeroes is illegal sql
414-
is_zero = sge.EQ(this=column.expr, expression=sge.convert(0))
434+
is_zero = sge.EQ(this=expr, expression=sge.convert(0))
415435

416436
# There is no product sql aggregate function, so must implement as a sum of logs, and then
417437
# apply power after. Note, log and power base must be equal! This impl uses log base 2.
418-
logs = (
419-
sge.Case()
420-
.when(is_zero, sge.convert(0))
421-
.else_(sge.func("LN", sge.func("ABS", column.expr)))
438+
logs = sge.If(
439+
this=is_zero,
440+
true=sge.convert(0),
441+
false=sge.func("LOG", sge.convert(2), sge.func("ABS", expr)),
422442
)
423443
logs_sum = apply_window_if_present(sge.func("SUM", logs), window)
424-
magnitude = sge.func("EXP", logs_sum)
444+
magnitude = sge.func("POWER", sge.convert(2), logs_sum)
425445

426446
# Can't determine sign from logs, so have to determine parity of count of negative inputs
427447
is_negative = (
428448
sge.Case()
429449
.when(
430-
sge.LT(this=sge.func("SIGN", column.expr), expression=sge.convert(0)),
450+
sge.EQ(this=sge.func("SIGN", expr), expression=sge.convert(-1)),
431451
sge.convert(1),
432452
)
433453
.else_(sge.convert(0))
@@ -445,11 +465,7 @@ def _(
445465
.else_(
446466
sge.Mul(
447467
this=magnitude,
448-
expression=sge.If(
449-
this=sge.EQ(this=negative_count_parity, expression=sge.convert(1)),
450-
true=sge.convert(-1),
451-
false=sge.convert(1),
452-
),
468+
expression=sge.func("POWER", sge.convert(-1), negative_count_parity),
453469
)
454470
)
455471
)
@@ -499,14 +515,18 @@ def _(
499515
column: typed_expr.TypedExpr,
500516
window: typing.Optional[window_spec.WindowSpec] = None,
501517
) -> sge.Expression:
502-
# TODO: Support interpolation argument
503-
# TODO: Support percentile_disc
504-
result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
518+
expr = column.expr
519+
if column.dtype == dtypes.BOOL_DTYPE:
520+
expr = sge.Cast(this=expr, to="INT64")
521+
522+
result: sge.Expression = sge.func("PERCENTILE_CONT", expr, sge.convert(op.q))
505523
if window is None:
506-
# PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause.
524+
# PERCENTILE_CONT is a navigation function, not an aggregate function,
525+
# so it always needs an OVER clause.
507526
result = sge.Window(this=result)
508527
else:
509528
result = apply_window_if_present(result, window)
529+
510530
if op.should_floor_result:
511531
result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
512532
return result

bigframes/core/compile/sqlglot/expressions/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
2121
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")
2222
_NEG_INF = sge.Cast(this=sge.convert("-Infinity"), to="FLOAT64")
23+
_DAY_TO_MICROSECONDS = sge.convert(86400000000)
2324

2425
# Approx Highest number you can pass in to EXP function and get a valid FLOAT64 result
2526
# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10)
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`bool_col`
3+
`bool_col`,
4+
`int64_col`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
7-
COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_1`
8+
COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_2`,
9+
COALESCE(LOGICAL_AND(`int64_col` <> 0), TRUE) AS `bfcol_3`
810
FROM `bfcte_0`
911
)
1012
SELECT
11-
`bfcol_1` AS `bool_col`
13+
`bfcol_2` AS `bool_col`,
14+
`bfcol_3` AS `int64_col`
1215
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_partition_out.sql

Lines changed: 0 additions & 14 deletions
This file was deleted.

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_out.sql renamed to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all_w_window/out.sql

File renamed without changes.
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`bool_col`
3+
`bool_col`,
4+
`int64_col`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
7-
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_1`
8+
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_2`,
9+
COALESCE(LOGICAL_OR(`int64_col` <> 0), FALSE) AS `bfcol_3`
810
FROM `bfcte_0`
911
)
1012
SELECT
11-
`bfcol_1` AS `bool_col`
13+
`bfcol_2` AS `bool_col`,
14+
`bfcol_3` AS `int64_col`
1215
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any/window_out.sql renamed to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any_w_window/out.sql

File renamed without changes.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`date_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CAST(FLOOR(
9+
DATE_DIFF(`date_col`, LAG(`date_col`, 1) OVER (ORDER BY `date_col` ASC NULLS LAST), DAY) * 86400000000
10+
) AS INT64) AS `bfcol_1`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_1` AS `diff_date`
15+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
77
CASE
88
WHEN LOGICAL_OR(`int64_col` = 0)
99
THEN 0
10-
ELSE EXP(SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END)) * IF(MOD(SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END), 2) = 1, -1, 1)
10+
ELSE POWER(2, SUM(IF(`int64_col` = 0, 0, LOG(ABS(`int64_col`), 2)))) * POWER(-1, MOD(SUM(CASE WHEN SIGN(`int64_col`) = -1 THEN 1 ELSE 0 END), 2))
1111
END AS `bfcol_1`
1212
FROM `bfcte_0`
1313
)

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ WITH `bfcte_0` AS (
99
CASE
1010
WHEN LOGICAL_OR(`int64_col` = 0) OVER (PARTITION BY `string_col`)
1111
THEN 0
12-
ELSE EXP(
13-
SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END) OVER (PARTITION BY `string_col`)
14-
) * IF(
12+
ELSE POWER(
13+
2,
14+
SUM(IF(`int64_col` = 0, 0, LOG(ABS(`int64_col`), 2))) OVER (PARTITION BY `string_col`)
15+
) * POWER(
16+
-1,
1517
MOD(
16-
SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`),
18+
SUM(CASE WHEN SIGN(`int64_col`) = -1 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`),
1719
2
18-
) = 1,
19-
-1,
20-
1
20+
)
2121
)
2222
END AS `bfcol_2`
2323
FROM `bfcte_0`

0 commit comments

Comments
 (0)