Skip to content

Commit 700bfae

Browse files
committed
refactoring
1 parent 522b388 commit 700bfae

File tree

2 files changed

+185
-154
lines changed
  • bigframes/core/compile/sqlglot/expressions
  • tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_week

2 files changed

+185
-154
lines changed

bigframes/core/compile/sqlglot/expressions/datetime_ops.py

Lines changed: 183 additions & 152 deletions
Original file line number | Diff line number | Diff line change
@@ -478,135 +478,99 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
478478
from weeks to years.
479479
"""
480480
rule_code = op.freq.rule_code
481+
482+
if rule_code == "W-SUN":
483+
return _integer_label_to_datetime_op_weekly_freq(x, y, op)
484+
485+
if rule_code in ("ME", "M"):
486+
return _integer_label_to_datetime_op_monthly_freq(x, y, op)
487+
488+
if rule_code in ("QE-DEC", "Q-DEC"):
489+
return _integer_label_to_datetime_op_quarterly_freq(x, y, op)
490+
491+
if rule_code in ("YE-DEC", "A-DEC", "Y-DEC"):
492+
return _integer_label_to_datetime_op_yearly_freq(x, y, op)
493+
494+
raise ValueError(rule_code)
495+
496+
497+
def _integer_label_to_datetime_op_weekly_freq(
498+
x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
499+
) -> sge.Expression:
481500
n = op.freq.n
482-
if rule_code == "W-SUN": # Weekly
483-
us = n * 7 * 24 * 60 * 60 * 1000000
484-
first = sge.func(
485-
"UNIX_MICROS",
486-
sge.Add(
487-
this=sge.TimestampTrunc(
488-
this=sge.Cast(this=y.expr, to="TIMESTAMP"),
489-
unit=sge.Var(this="WEEK(MONDAY)"),
490-
),
491-
expression=sge.Interval(
492-
this=sge.convert(6), unit=sge.Identifier(this="DAY")
493-
),
494-
),
495-
)
496-
x_label = sge.Cast(
497-
this=sge.func(
498-
"TIMESTAMP_MICROS",
499-
sge.Cast(
500-
this=sge.Add(
501-
this=sge.Mul(
502-
this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
503-
expression=sge.convert(us),
504-
),
505-
expression=sge.Cast(this=first, to="BIGNUMERIC"),
506-
),
507-
to="INT64",
508-
),
509-
),
510-
to=sqlglot_types.from_bigframes_dtype(y.dtype),
511-
)
512-
elif rule_code in ("ME", "M"): # Monthly
513-
one = sge.convert(1)
514-
twelve = sge.convert(12)
515-
first = sge.Sub( # type: ignore
516-
this=sge.Add(
517-
this=sge.Mul(
518-
this=sge.Extract(this="YEAR", expression=y.expr),
519-
expression=twelve,
520-
),
521-
expression=sge.Extract(this="MONTH", expression=y.expr),
501+
# Calculate microseconds for the weekly interval.
502+
us = n * 7 * 24 * 60 * 60 * 1000000
503+
first = sge.func(
504+
"UNIX_MICROS",
505+
sge.Add(
506+
this=sge.TimestampTrunc(
507+
this=sge.Cast(this=y.expr, to="TIMESTAMP"),
508+
unit=sge.Var(this="WEEK(MONDAY)"),
522509
),
523-
expression=one,
524-
)
525-
x_val = sge.Add(
526-
this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
527-
)
528-
year = sge.Cast(
529-
this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, twelve)),
530-
to="INT64",
531-
)
532-
month = sge.Add(this=sge.Mod(this=x_val, expression=twelve), expression=one)
533-
next_year = sge.Case(
534-
ifs=[
535-
sge.If(
536-
this=sge.EQ(this=month, expression=twelve),
537-
true=sge.Add(this=year, expression=one),
538-
)
539-
],
540-
default=year,
541-
)
542-
next_month = sge.Case(
543-
ifs=[
544-
sge.If(
545-
this=sge.EQ(this=month, expression=twelve),
546-
true=one,
547-
)
548-
],
549-
default=sge.Add(this=month, expression=one),
550-
)
551-
next_month_date = sge.func(
552-
"TIMESTAMP",
553-
sge.Anonymous(
554-
this="DATETIME",
555-
expressions=[
556-
next_year,
557-
next_month,
558-
one,
559-
sge.convert(0),
560-
sge.convert(0),
561-
sge.convert(0),
562-
],
510+
expression=sge.Interval(
511+
this=sge.convert(6), unit=sge.Identifier(this="DAY")
563512
),
564-
)
565-
x_label = sge.Sub( # type: ignore
566-
this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
567-
)
568-
elif rule_code in ("QE-DEC", "Q-DEC"): # Quarterly
569-
one = sge.convert(1)
570-
three = sge.convert(3)
571-
four = sge.convert(4)
572-
twelve = sge.convert(12)
573-
first = sge.Sub( # type: ignore
574-
this=sge.Add(
575-
this=sge.Mul(
576-
this=sge.Extract(this="YEAR", expression=y.expr),
577-
expression=four,
513+
),
514+
)
515+
return sge.Cast(
516+
this=sge.func(
517+
"TIMESTAMP_MICROS",
518+
sge.Cast(
519+
this=sge.Add(
520+
this=sge.Mul(
521+
this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
522+
expression=sge.convert(us),
523+
),
524+
expression=sge.Cast(this=first, to="BIGNUMERIC"),
578525
),
579-
expression=sge.Extract(this="QUARTER", expression=y.expr),
526+
to="INT64",
580527
),
581-
expression=one,
582-
)
583-
x_val = sge.Add(
584-
this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
585-
)
586-
year = sge.Cast(
587-
this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, four)),
588-
to="INT64",
589-
)
590-
month = sge.Mul( # type: ignore
591-
this=sge.Paren(
592-
this=sge.Add(this=sge.Mod(this=x_val, expression=four), expression=one)
528+
),
529+
to=sqlglot_types.from_bigframes_dtype(y.dtype),
530+
)
531+
532+
533+
def _integer_label_to_datetime_op_monthly_freq(
534+
x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
535+
) -> sge.Expression:
536+
n = op.freq.n
537+
one = sge.convert(1)
538+
twelve = sge.convert(12)
539+
first = sge.Sub( # type: ignore
540+
this=sge.Add(
541+
this=sge.Mul(
542+
this=sge.Extract(this="YEAR", expression=y.expr),
543+
expression=twelve,
593544
),
594-
expression=three,
595-
)
596-
next_year = sge.Case(
597-
ifs=[
598-
sge.If(
599-
this=sge.EQ(this=month, expression=twelve),
600-
true=sge.Add(this=year, expression=one),
601-
)
602-
],
603-
default=year,
604-
)
605-
next_month = sge.Case(
606-
ifs=[sge.If(this=sge.EQ(this=month, expression=twelve), true=one)],
607-
default=sge.Add(this=month, expression=one),
608-
)
609-
next_month_date = sge.Anonymous(
545+
expression=sge.Extract(this="MONTH", expression=y.expr),
546+
),
547+
expression=one,
548+
)
549+
x_val = sge.Add(
550+
this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
551+
)
552+
year = sge.Cast(
553+
this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, twelve)),
554+
to="INT64",
555+
)
556+
month = sge.Add(this=sge.Mod(this=x_val, expression=twelve), expression=one)
557+
558+
next_year = sge.Case(
559+
ifs=[
560+
sge.If(
561+
this=sge.EQ(this=month, expression=twelve),
562+
true=sge.Add(this=year, expression=one),
563+
)
564+
],
565+
default=year,
566+
)
567+
next_month = sge.Case(
568+
ifs=[sge.If(this=sge.EQ(this=month, expression=twelve), true=one)],
569+
default=sge.Add(this=month, expression=one),
570+
)
571+
next_month_date = sge.func(
572+
"TIMESTAMP",
573+
sge.Anonymous(
610574
this="DATETIME",
611575
expressions=[
612576
next_year,
@@ -616,34 +580,101 @@ def _integer_label_to_datetime_op_non_fixed_frequency(
616580
sge.convert(0),
617581
sge.convert(0),
618582
],
619-
)
620-
x_label = sge.Sub( # type: ignore
621-
this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
622-
)
623-
elif rule_code in ("YE-DEC", "A-DEC", "Y-DEC"): # Yearly
624-
one = sge.convert(1)
625-
first = sge.Extract(this="YEAR", expression=y.expr)
626-
x_val = sge.Add(
627-
this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
628-
)
629-
next_year = sge.Add(this=x_val, expression=one) # type: ignore
630-
next_month_date = sge.func(
631-
"TIMESTAMP",
632-
sge.Anonymous(
633-
this="DATETIME",
634-
expressions=[
635-
next_year,
636-
one,
637-
one,
638-
sge.convert(0),
639-
sge.convert(0),
640-
sge.convert(0),
641-
],
583+
),
584+
)
585+
x_label = sge.Sub( # type: ignore
586+
this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
587+
)
588+
return sge.Cast(this=x_label, to=sqlglot_types.from_bigframes_dtype(y.dtype))
589+
590+
591+
def _integer_label_to_datetime_op_quarterly_freq(
592+
x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
593+
) -> sge.Expression:
594+
n = op.freq.n
595+
one = sge.convert(1)
596+
three = sge.convert(3)
597+
four = sge.convert(4)
598+
twelve = sge.convert(12)
599+
first = sge.Sub( # type: ignore
600+
this=sge.Add(
601+
this=sge.Mul(
602+
this=sge.Extract(this="YEAR", expression=y.expr),
603+
expression=four,
642604
),
643-
)
644-
x_label = sge.Sub( # type: ignore
645-
this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
646-
)
647-
else:
648-
raise ValueError(rule_code)
605+
expression=sge.Extract(this="QUARTER", expression=y.expr),
606+
),
607+
expression=one,
608+
)
609+
x_val = sge.Add(
610+
this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
611+
)
612+
year = sge.Cast(
613+
this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, four)),
614+
to="INT64",
615+
)
616+
month = sge.Mul( # type: ignore
617+
this=sge.Paren(
618+
this=sge.Add(this=sge.Mod(this=x_val, expression=four), expression=one)
619+
),
620+
expression=three,
621+
)
622+
623+
next_year = sge.Case(
624+
ifs=[
625+
sge.If(
626+
this=sge.EQ(this=month, expression=twelve),
627+
true=sge.Add(this=year, expression=one),
628+
)
629+
],
630+
default=year,
631+
)
632+
next_month = sge.Case(
633+
ifs=[sge.If(this=sge.EQ(this=month, expression=twelve), true=one)],
634+
default=sge.Add(this=month, expression=one),
635+
)
636+
next_month_date = sge.Anonymous(
637+
this="DATETIME",
638+
expressions=[
639+
next_year,
640+
next_month,
641+
one,
642+
sge.convert(0),
643+
sge.convert(0),
644+
sge.convert(0),
645+
],
646+
)
647+
x_label = sge.Sub( # type: ignore
648+
this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
649+
)
650+
return sge.Cast(this=x_label, to=sqlglot_types.from_bigframes_dtype(y.dtype))
651+
652+
653+
def _integer_label_to_datetime_op_yearly_freq(
654+
x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
655+
) -> sge.Expression:
656+
n = op.freq.n
657+
one = sge.convert(1)
658+
first = sge.Extract(this="YEAR", expression=y.expr)
659+
x_val = sge.Add(
660+
this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
661+
)
662+
next_year = sge.Add(this=x_val, expression=one) # type: ignore
663+
next_month_date = sge.func(
664+
"TIMESTAMP",
665+
sge.Anonymous(
666+
this="DATETIME",
667+
expressions=[
668+
next_year,
669+
one,
670+
one,
671+
sge.convert(0),
672+
sge.convert(0),
673+
sge.convert(0),
674+
],
675+
),
676+
)
677+
x_label = sge.Sub( # type: ignore
678+
this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
679+
)
649680
return sge.Cast(this=x_label, to=sqlglot_types.from_bigframes_dtype(y.dtype))

tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_week/out.sql

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,11 @@ WITH `bfcte_0` AS (
66
), `bfcte_1` AS (
77
SELECT
88
*,
9-
CAST(CAST(TIMESTAMP_MICROS(
9+
CAST(TIMESTAMP_MICROS(
1010
CAST(CAST(`rowindex` AS BIGNUMERIC) * 604800000000 + CAST(UNIX_MICROS(
1111
TIMESTAMP_TRUNC(CAST(`timestamp_col` AS TIMESTAMP), WEEK(MONDAY)) + INTERVAL 6 DAY
1212
) AS BIGNUMERIC) AS INT64)
13-
) AS TIMESTAMP) AS TIMESTAMP) AS `bfcol_2`
13+
) AS TIMESTAMP) AS `bfcol_2`
1414
FROM `bfcte_0`
1515
)
1616
SELECT

0 commit comments

Comments
 (0)