Skip to content

Commit e156660

Browse files
authored
chore: Migrate IntegerLabelToDatetimeOp operator to SQLGlot (#2310)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes b/447388852 🦕
1 parent e1d54d2 commit e156660

File tree

8 files changed

+538
-22
lines changed
  • bigframes/core/compile/sqlglot/expressions
  • tests/unit/core/compile/sqlglot/expressions

8 files changed

+538
-22
lines changed

bigframes/core/compile/sqlglot/expressions/datetime_ops.py

Lines changed: 257 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,14 @@
1919
from bigframes import dtypes
2020
from bigframes import operations as ops
2121
from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS
22+
from bigframes.core.compile.sqlglot import sqlglot_types
2223
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2324
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2425

2526
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2627
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op
2728

2829

29-
def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
30-
if origin == "epoch":
31-
return sge.convert(0)
32-
elif origin == "start_day":
33-
return sge.func(
34-
"UNIX_MICROS",
35-
sge.Cast(
36-
this=sge.Cast(
37-
this=y.expr, to=sge.DataType(this=sge.DataType.Type.DATE)
38-
),
39-
to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ),
40-
),
41-
)
42-
elif origin == "start":
43-
return sge.func(
44-
"UNIX_MICROS",
45-
sge.Cast(this=y.expr, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)),
46-
)
47-
else:
48-
raise ValueError(f"Origin {origin} not supported")
49-
50-
5130
@register_binary_op(ops.DatetimeToIntegerLabelOp, pass_op=True)
5231
def datetime_to_integer_label_op(
5332
x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp
@@ -317,6 +296,20 @@ def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression:
317296
return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq))
318297

319298

299+
def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
    """Build the expression for the resampling origin selected by ``origin``.

    Args:
        y: Typed expression whose ``expr`` is used for the ``"start_day"``
            and ``"start"`` origins.
        origin: One of ``"epoch"``, ``"start_day"`` or ``"start"``.

    Returns:
        A literal ``0`` for ``"epoch"``; otherwise a ``UNIX_MICROS(...)`` call
        over ``y.expr`` cast to ``TIMESTAMP``. For ``"start_day"`` the value
        is cast to ``DATE`` first.

    Raises:
        ValueError: If ``origin`` is not one of the supported values.
    """
    if origin == "epoch":
        return sge.convert(0)
    if origin not in ("start_day", "start"):
        raise ValueError(f"Origin {origin} not supported")

    anchor = y.expr
    if origin == "start_day":
        # Going through DATE before TIMESTAMP drops the time-of-day part.
        anchor = sge.Cast(this=anchor, to="DATE")
    return sge.func("UNIX_MICROS", sge.Cast(this=anchor, to="TIMESTAMP"))
311+
312+
320313
@register_unary_op(ops.hour_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.hour_op`` into an ``EXTRACT`` of the ``HOUR`` part."""
    hour_part = sge.Identifier(this="HOUR")
    return sge.Extract(this=hour_part, expression=expr.expr)
@@ -436,3 +429,245 @@ def _(expr: TypedExpr, op: ops.UnixSeconds) -> sge.Expression:
436429
@register_unary_op(ops.year_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.year_op`` into an ``EXTRACT`` of the ``YEAR`` part."""
    year_part = sge.Identifier(this="YEAR")
    return sge.Extract(this=year_part, expression=expr.expr)
432+
433+
434+
@register_binary_op(ops.IntegerLabelToDatetimeOp, pass_op=True)
def integer_label_to_datetime_op(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Compile ``IntegerLabelToDatetimeOp`` by dispatching on the frequency.

    Frequencies for which ``op.freq.nanos`` can be read are handled by the
    fixed-frequency helper. Otherwise the helper is selected from
    ``op.freq.rule_code`` (weekly, monthly, quarterly or yearly).

    Args:
        x: First operand, passed through to the selected helper.
        y: Second operand, passed through to the selected helper.
        op: Operator instance; ``op.freq`` drives the dispatch.

    Returns:
        The expression built by the selected helper.

    Raises:
        ValueError: If ``op.freq.nanos`` is not available and the rule code
            is not one of the supported codes. A ``ValueError`` raised by a
            helper propagates unchanged.
    """
    # Only the `nanos` probe is guarded. Wrapping the whole helper call would
    # also swallow unrelated ValueErrors raised while building the expression
    # (for example an unsupported origin) and re-report them as an
    # unsupported rule code.
    try:
        op.freq.nanos  # Probe only; the helper reads the value itself.
    except ValueError:
        has_fixed_frequency = False
    else:
        has_fixed_frequency = True

    if has_fixed_frequency:
        return _integer_label_to_datetime_op_fixed_frequency(x, y, op)

    # Non-fixed frequency conversions for units ranging from weeks to years.
    rule_code = op.freq.rule_code

    if rule_code == "W-SUN":
        return _integer_label_to_datetime_op_weekly_freq(x, y, op)

    if rule_code in ("ME", "M"):
        return _integer_label_to_datetime_op_monthly_freq(x, y, op)

    if rule_code in ("QE-DEC", "Q-DEC"):
        return _integer_label_to_datetime_op_quarterly_freq(x, y, op)

    if rule_code in ("YE-DEC", "A-DEC", "Y-DEC"):
        return _integer_label_to_datetime_op_yearly_freq(x, y, op)

    # If the rule_code is not recognized, raise an error here.
    raise ValueError(f"Unsupported frequency rule code: {rule_code}")
460+
461+
462+
def _integer_label_to_datetime_op_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Build the label expression for a frequency that exposes ``nanos``.

    Covers fixed frequency conversions where the unit can range from
    microseconds (us) to days. The emitted expression has the shape::

        CAST(TIMESTAMP_MICROS(
            CAST(CAST(x AS BIGNUMERIC) * step + CAST(first AS BIGNUMERIC) AS INT64)
        ) AS <type of y>)

    where ``step`` is ``op.freq.nanos / 1000`` truncated to an integer and
    ``first`` comes from ``_calculate_resample_first(y, op.origin)``.
    """
    step_us = op.freq.nanos / 1000
    origin_us = _calculate_resample_first(y, op.origin)  # type: ignore

    # The sum is computed in BIGNUMERIC and only then cast back to INT64.
    scaled_label = sge.Mul(
        this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
        expression=sge.convert(int(step_us)),
    )
    total_us = sge.Add(
        this=scaled_label,
        expression=sge.Cast(this=origin_us, to="BIGNUMERIC"),
    )
    as_timestamp = sge.func("TIMESTAMP_MICROS", sge.Cast(this=total_us, to="INT64"))
    return sge.Cast(
        this=as_timestamp,
        to=sqlglot_types.from_bigframes_dtype(y.dtype),
    )
488+
489+
490+
def _integer_label_to_datetime_op_weekly_freq(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Build the label expression for the ``W-SUN`` rule code.

    The origin is ``y`` cast to ``TIMESTAMP``, truncated to ``WEEK(MONDAY)``
    and shifted forward by six days, expressed in Unix microseconds. The
    result is ``x`` times the width of ``op.freq.n`` weeks (in microseconds)
    plus that origin, converted back with ``TIMESTAMP_MICROS`` and cast to the
    SQL type derived from ``y.dtype``.
    """
    # Width of `op.freq.n` weeks in microseconds.
    step_us = op.freq.n * 7 * 24 * 60 * 60 * 1000000

    week_start = sge.TimestampTrunc(
        this=sge.Cast(this=y.expr, to="TIMESTAMP"),
        unit=sge.Var(this="WEEK(MONDAY)"),
    )
    six_days = sge.Interval(this=sge.convert(6), unit=sge.Identifier(this="DAY"))
    origin_us = sge.func(
        "UNIX_MICROS", sge.Add(this=week_start, expression=six_days)
    )

    # The sum is computed in BIGNUMERIC and only then cast back to INT64.
    scaled_label = sge.Mul(
        this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
        expression=sge.convert(step_us),
    )
    total_us = sge.Add(
        this=scaled_label,
        expression=sge.Cast(this=origin_us, to="BIGNUMERIC"),
    )
    as_timestamp = sge.func("TIMESTAMP_MICROS", sge.Cast(this=total_us, to="INT64"))
    return sge.Cast(
        this=as_timestamp,
        to=sqlglot_types.from_bigframes_dtype(y.dtype),
    )
524+
525+
526+
def _integer_label_to_datetime_op_monthly_freq(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Build the label expression for the monthly rule codes.

    ``y`` is reduced to a zero-based month ordinal
    (``YEAR * 12 + MONTH - 1``) and ``x * op.freq.n`` is added to it. The
    ordinal is split back into a year and a 1-based month, and the label is
    the first day of the following month minus one day, cast to the SQL type
    derived from ``y.dtype``.
    """
    step = op.freq.n
    lit_one = sge.convert(1)
    lit_twelve = sge.convert(12)

    # Zero-based month ordinal of `y`.
    year_in_months = sge.Mul(
        this=sge.Extract(this="YEAR", expression=y.expr),
        expression=lit_twelve,
    )
    month_of_year = sge.Extract(this="MONTH", expression=y.expr)
    origin_ordinal = sge.Sub(  # type: ignore
        this=sge.Add(this=year_in_months, expression=month_of_year),
        expression=lit_one,
    )

    # Month ordinal of the requested label.
    scaled_label = sge.Mul(this=x.expr, expression=sge.convert(step))
    label_ordinal = sge.Add(this=scaled_label, expression=origin_ordinal)

    label_year = sge.Cast(
        this=sge.Floor(this=sge.func("IEEE_DIVIDE", label_ordinal, lit_twelve)),
        to="INT64",
    )
    label_month = sge.Add(
        this=sge.Mod(this=label_ordinal, expression=lit_twelve),
        expression=lit_one,
    )

    # The month after `label_month`; December rolls over to January of the
    # next year.
    rollover_year_branch = sge.If(
        this=sge.EQ(this=label_month, expression=lit_twelve),
        true=sge.Add(this=label_year, expression=lit_one),
    )
    following_year = sge.Case(ifs=[rollover_year_branch], default=label_year)

    rollover_month_branch = sge.If(
        this=sge.EQ(this=label_month, expression=lit_twelve), true=lit_one
    )
    following_month = sge.Case(
        ifs=[rollover_month_branch],
        default=sge.Add(this=label_month, expression=lit_one),
    )

    first_of_following_month = sge.Anonymous(
        this="DATETIME",
        expressions=[
            following_year,
            following_month,
            lit_one,
            *(sge.convert(0) for _ in range(3)),
        ],
    )
    following_month_start = sge.func("TIMESTAMP", first_of_following_month)

    # Step back one day to land on the last day of the label's month.
    month_end = sge.Sub(  # type: ignore
        this=following_month_start,
        expression=sge.Interval(this=lit_one, unit="DAY"),
    )
    return sge.Cast(this=month_end, to=sqlglot_types.from_bigframes_dtype(y.dtype))
582+
583+
584+
def _integer_label_to_datetime_op_quarterly_freq(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Build the label expression for the quarterly rule codes.

    ``y`` is reduced to a zero-based quarter ordinal
    (``YEAR * 4 + QUARTER - 1``) and ``x * op.freq.n`` is added to it. The
    ordinal is split back into a year and the last month of the quarter
    (3, 6, 9 or 12), and the label is the first day of the following month
    minus one day, cast to the SQL type derived from ``y.dtype``.
    """
    step = op.freq.n
    lit_one = sge.convert(1)
    lit_three = sge.convert(3)
    lit_four = sge.convert(4)
    lit_twelve = sge.convert(12)

    # Zero-based quarter ordinal of `y`.
    year_in_quarters = sge.Mul(
        this=sge.Extract(this="YEAR", expression=y.expr),
        expression=lit_four,
    )
    quarter_of_year = sge.Extract(this="QUARTER", expression=y.expr)
    origin_ordinal = sge.Sub(  # type: ignore
        this=sge.Add(this=year_in_quarters, expression=quarter_of_year),
        expression=lit_one,
    )

    # Quarter ordinal of the requested label.
    scaled_label = sge.Mul(this=x.expr, expression=sge.convert(step))
    label_ordinal = sge.Add(this=scaled_label, expression=origin_ordinal)

    label_year = sge.Cast(
        this=sge.Floor(this=sge.func("IEEE_DIVIDE", label_ordinal, lit_four)),
        to="INT64",
    )
    # Last month of the quarter; the explicit Paren keeps "+ 1" grouped
    # before the multiplication by three.
    one_based_quarter = sge.Paren(
        this=sge.Add(
            this=sge.Mod(this=label_ordinal, expression=lit_four),
            expression=lit_one,
        )
    )
    quarter_end_month = sge.Mul(  # type: ignore
        this=one_based_quarter,
        expression=lit_three,
    )

    # The month after `quarter_end_month`; December rolls over to January of
    # the next year.
    rollover_year_branch = sge.If(
        this=sge.EQ(this=quarter_end_month, expression=lit_twelve),
        true=sge.Add(this=label_year, expression=lit_one),
    )
    following_year = sge.Case(ifs=[rollover_year_branch], default=label_year)

    rollover_month_branch = sge.If(
        this=sge.EQ(this=quarter_end_month, expression=lit_twelve), true=lit_one
    )
    following_month = sge.Case(
        ifs=[rollover_month_branch],
        default=sge.Add(this=quarter_end_month, expression=lit_one),
    )

    # NOTE(review): the monthly and yearly helpers wrap this DATETIME(...)
    # call in TIMESTAMP(...) before subtracting the interval; this helper
    # does not. Confirm the difference is intentional.
    following_month_start = sge.Anonymous(
        this="DATETIME",
        expressions=[
            following_year,
            following_month,
            lit_one,
            *(sge.convert(0) for _ in range(3)),
        ],
    )

    # Step back one day to land on the last day of the quarter.
    quarter_end = sge.Sub(  # type: ignore
        this=following_month_start,
        expression=sge.Interval(this=lit_one, unit="DAY"),
    )
    return sge.Cast(
        this=quarter_end, to=sqlglot_types.from_bigframes_dtype(y.dtype)
    )
644+
645+
646+
def _integer_label_to_datetime_op_yearly_freq(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Build the label expression for the yearly rule codes.

    The label's year is ``x * op.freq.n`` plus the year extracted from ``y``.
    The label is January 1st of the year after that minus one day, cast to
    the SQL type derived from ``y.dtype``.
    """
    step = op.freq.n
    lit_one = sge.convert(1)

    origin_year = sge.Extract(this="YEAR", expression=y.expr)
    scaled_label = sge.Mul(this=x.expr, expression=sge.convert(step))
    label_year = sge.Add(this=scaled_label, expression=origin_year)
    following_year = sge.Add(this=label_year, expression=lit_one)  # type: ignore

    # DATETIME(year + 1, 1, 1, 0, 0, 0): start of the following year.
    first_of_following_year = sge.Anonymous(
        this="DATETIME",
        expressions=[
            following_year,
            lit_one,
            lit_one,
            *(sge.convert(0) for _ in range(3)),
        ],
    )
    following_year_start = sge.func("TIMESTAMP", first_of_following_year)

    # Step back one day to land on the last day of the label's year.
    year_end = sge.Sub(  # type: ignore
        this=following_year_start,
        expression=sge.Interval(this=lit_one, unit="DAY"),
    )
    return sge.Cast(this=year_end, to=sqlglot_types.from_bigframes_dtype(y.dtype))
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`rowindex`,
4+
`timestamp_col`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
CAST(TIMESTAMP_MICROS(
10+
CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64)
11+
) AS TIMESTAMP) AS `bfcol_2`,
12+
CAST(DATETIME(
13+
CASE
14+
WHEN (
15+
MOD(
16+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
17+
4
18+
) + 1
19+
) * 3 = 12
20+
THEN CAST(FLOOR(
21+
IEEE_DIVIDE(
22+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
23+
4
24+
)
25+
) AS INT64) + 1
26+
ELSE CAST(FLOOR(
27+
IEEE_DIVIDE(
28+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
29+
4
30+
)
31+
) AS INT64)
32+
END,
33+
CASE
34+
WHEN (
35+
MOD(
36+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
37+
4
38+
) + 1
39+
) * 3 = 12
40+
THEN 1
41+
ELSE (
42+
MOD(
43+
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
44+
4
45+
) + 1
46+
) * 3 + 1
47+
END,
48+
1,
49+
0,
50+
0,
51+
0
52+
) - INTERVAL 1 DAY AS TIMESTAMP) AS `bfcol_3`
53+
FROM `bfcte_0`
54+
)
55+
SELECT
56+
`bfcol_2` AS `fixed_freq`,
57+
`bfcol_3` AS `non_fixed_freq`
58+
FROM `bfcte_1`
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`rowindex`,
4+
`timestamp_col`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
CAST(TIMESTAMP_MICROS(
10+
CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64)
11+
) AS TIMESTAMP) AS `bfcol_2`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_2` AS `fixed_freq`
16+
FROM `bfcte_1`

0 commit comments

Comments
 (0)