Skip to content

Commit 787e47c

Browse files
committed
chore: Migrate DatetimeToIntegerLabelOp operator to SQLGlot
1 parent 6e73d77 commit 787e47c

File tree

3 files changed

+377
-0
lines changed

3 files changed

+377
-0
lines changed

bigframes/core/compile/sqlglot/expressions/datetime_ops.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,272 @@
2121
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2222

2323
# Decorator shorthands bound to the module-level `scalar_op_compiler` object.
_scalar_op_compiler = scalar_compiler.scalar_op_compiler
register_unary_op = _scalar_op_compiler.register_unary_op
register_binary_op = _scalar_op_compiler.register_binary_op
25+
26+
27+
def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
    """Build the integer expression for the first resampling boundary.

    Args:
        y: Typed expression whose ``expr`` supplies the reference timestamp.
        origin: One of ``"epoch"``, ``"start_day"`` or ``"start"``.

    Returns:
        The literal ``0`` for ``"epoch"``. Otherwise ``UNIX_MICROS(...)``
        applied to ``y`` cast to TIMESTAMPTZ; for ``"start_day"`` the value
        is cast to DATE first.

    Raises:
        ValueError: If ``origin`` is not one of the supported values.
    """
    if origin == "epoch":
        return sge.convert(0)
    if origin not in ("start_day", "start"):
        raise ValueError(f"Origin {origin} not supported")

    anchor = y.expr
    if origin == "start_day":
        # Go through DATE before the timestamp cast so the anchor is the
        # day of `y` rather than its exact instant.
        anchor = sge.Cast(
            this=anchor, to=sge.DataType(this=sge.DataType.Type.DATE)
        )
    anchor_ts = sge.Cast(
        this=anchor, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
    )
    return sge.func("UNIX_MICROS", anchor_ts)
47+
48+
49+
@register_binary_op(ops.DatetimeToIntegerLabelOp, pass_op=True)
def datetime_to_integer_label_op(
    x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp
) -> sge.Expression:
    """Compile ``DatetimeToIntegerLabelOp`` into an integer bucket-label expression.

    Dispatches to the fixed-frequency builder when ``op.freq.nanos`` is
    available and to the non-fixed-frequency builder otherwise.

    Args:
        x: Typed expression for the datetime value being labelled.
        y: Typed expression for the reference (first) datetime.
        op: The operator instance; ``op.freq`` and ``op.origin`` are read.

    Returns:
        The SQLGlot expression produced by the selected builder.

    Raises:
        ValueError: Propagated from the selected builder, e.g. for an
            unsupported origin or an unsupported frequency rule code.
    """
    # Probe only the `nanos` attribute: a ValueError here is what marks the
    # frequency as non-fixed. Keeping the builder call outside the `try`
    # prevents unrelated ValueErrors (such as an unsupported origin) from
    # being swallowed and re-reported as a bare rule code.
    try:
        _ = op.freq.nanos
    except ValueError:
        return _datetime_to_integer_label_non_fixed_frequency(x, y, op)
    return _datetime_to_integer_label_fixed_frequency(x, y, op)
58+
59+
60+
def _datetime_to_integer_label_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp
) -> sge.Expression:
    """Build the bucket label for a fixed frequency (microseconds up to days).

    The result is
    ``CAST(FLOOR(IEEE_DIVIDE(UNIX_MICROS(x) - first, width_us)) AS INT64)``
    where ``first`` comes from ``_calculate_resample_first(y, op.origin)`` and
    ``width_us`` is ``op.freq.nanos`` converted to whole microseconds.
    """
    # Read `nanos` before anything else so a failure there surfaces first.
    bucket_width_us = int(op.freq.nanos / 1000)

    timestamptz = sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
    value_us = sge.func("UNIX_MICROS", sge.Cast(this=x.expr, to=timestamptz))
    origin_us = _calculate_resample_first(y, op.origin)

    elapsed_us = sge.Sub(this=value_us, expression=origin_us)
    bucket = sge.func("IEEE_DIVIDE", elapsed_us, sge.convert(bucket_width_us))
    return sge.Cast(this=sge.Floor(this=bucket), to=sge.DataType.build("INT64"))
84+
85+
86+
def _resample_label_case(
    x_int: sge.Expression, first: sge.Expression, divisor: int
) -> sge.Expression:
    """Build ``CASE WHEN x = first THEN 0 ELSE CAST(FLOOR(IEEE_DIVIDE(x - first - 1, divisor)) AS INT64) + 1 END``."""
    return sge.Case(
        ifs=[
            sge.If(
                this=sge.EQ(this=x_int, expression=first),
                true=sge.convert(0),
            )
        ],
        default=sge.Add(
            this=sge.Cast(
                this=sge.Floor(
                    this=sge.func(
                        "IEEE_DIVIDE",
                        sge.Sub(
                            this=sge.Sub(this=x_int, expression=first),
                            expression=sge.convert(1),
                        ),
                        sge.convert(divisor),
                    )
                ),
                to=sge.DataType.build("INT64"),
            ),
            expression=sge.convert(1),
        ),
    )


def _year_scaled_index(
    expr: sge.Expression, part: str, parts_per_year: int
) -> sge.Expression:
    """Build ``(EXTRACT(YEAR FROM expr) * parts_per_year + EXTRACT(part FROM expr) - 1)``."""
    return sge.Paren(
        this=sge.Add(
            this=sge.Mul(
                this=sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr),
                expression=sge.convert(parts_per_year),
            ),
            expression=sge.Sub(
                this=sge.Extract(this=sge.Identifier(this=part), expression=expr),
                expression=sge.convert(1),
            ),
        )
    )


def _week_end_unix_micros(expr: sge.Expression) -> sge.Expression:
    """Build ``UNIX_MICROS(CAST(TIMESTAMP_TRUNC(expr, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMPTZ))``."""
    week_start = sge.TimestampTrunc(this=expr, unit=sge.Var(this="WEEK(MONDAY)"))
    week_end = sge.Add(
        this=week_start,
        expression=sge.Interval(
            this=sge.convert(6), unit=sge.Identifier(this="DAY")
        ),
    )
    return sge.func(
        "UNIX_MICROS",
        sge.Cast(this=week_end, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)),
    )


def _datetime_to_integer_label_non_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp
) -> sge.Expression:
    """Build the bucket label for a non-fixed frequency (weeks to years).

    ``x`` and ``y`` are each reduced to an integer position, then combined
    with the same CASE expression: ``0`` when the positions are equal,
    otherwise ``FLOOR((x - y - 1) / step) + 1``. The position and step depend
    on ``op.freq.rule_code``:

    * ``"W-SUN"``: microseconds at the Monday-truncated week plus six days;
      the step is ``n`` weeks expressed in microseconds.
    * ``"ME"``: ``year * 12 + month - 1``; the step is ``n``.
    * ``"QE-DEC"``: ``year * 4 + quarter - 1``; the step is ``n``.
    * ``"YE-DEC"``: the year; the step is ``n``.

    Args:
        x: Typed expression for the datetime value being labelled.
        y: Typed expression for the reference (first) datetime.
        op: The operator instance; ``op.freq.rule_code`` and ``op.freq.n``
            are read.

    Returns:
        The SQLGlot CASE expression described above.

    Raises:
        ValueError: If the rule code is not one of the four listed above.
    """
    rule_code = op.freq.rule_code
    n = op.freq.n

    if rule_code == "W-SUN":  # Weekly
        us = n * 7 * 24 * 60 * 60 * 1000000
        return _resample_label_case(
            _week_end_unix_micros(x.expr), _week_end_unix_micros(y.expr), us
        )

    if rule_code == "ME":  # Monthly
        return _resample_label_case(
            _year_scaled_index(x.expr, "MONTH", 12),
            _year_scaled_index(y.expr, "MONTH", 12),
            n,
        )

    if rule_code == "QE-DEC":  # Quarterly
        return _resample_label_case(
            _year_scaled_index(x.expr, "QUARTER", 4),
            _year_scaled_index(y.expr, "QUARTER", 4),
            n,
        )

    if rule_code == "YE-DEC":  # Yearly
        # The yearly position is the bare EXTRACT, without the parentheses
        # used by the month and quarter positions.
        x_year = sge.Extract(this=sge.Identifier(this="YEAR"), expression=x.expr)
        y_year = sge.Extract(this=sge.Identifier(this="YEAR"), expression=y.expr)
        return _resample_label_case(x_year, y_year, n)

    raise ValueError(rule_code)
24290

25291

26292
@register_unary_op(ops.FloorDtOp, pass_op=True)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`datetime_col`,
4+
`timestamp_col`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
CAST(FLOOR(
10+
IEEE_DIVIDE(
11+
UNIX_MICROS(CAST(`datetime_col` AS TIMESTAMP)) - UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)),
12+
86400000000
13+
)
14+
) AS INT64) AS `bfcol_2`,
15+
CASE
16+
WHEN (
17+
EXTRACT(YEAR FROM `datetime_col`) * 12 + EXTRACT(MONTH FROM `datetime_col`) - 1
18+
) = (
19+
EXTRACT(YEAR FROM `timestamp_col`) * 12 + EXTRACT(MONTH FROM `timestamp_col`) - 1
20+
)
21+
THEN 0
22+
ELSE CAST(FLOOR(
23+
IEEE_DIVIDE(
24+
(
25+
EXTRACT(YEAR FROM `datetime_col`) * 12 + EXTRACT(MONTH FROM `datetime_col`) - 1
26+
) - (
27+
EXTRACT(YEAR FROM `timestamp_col`) * 12 + EXTRACT(MONTH FROM `timestamp_col`) - 1
28+
) - 1,
29+
1
30+
)
31+
) AS INT64) + 1
32+
END AS `bfcol_3`,
33+
CASE
34+
WHEN UNIX_MICROS(
35+
CAST(TIMESTAMP_TRUNC(`datetime_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP)
36+
) = UNIX_MICROS(
37+
CAST(TIMESTAMP_TRUNC(`timestamp_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP)
38+
)
39+
THEN 0
40+
ELSE CAST(FLOOR(
41+
IEEE_DIVIDE(
42+
UNIX_MICROS(
43+
CAST(TIMESTAMP_TRUNC(`datetime_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP)
44+
) - UNIX_MICROS(
45+
CAST(TIMESTAMP_TRUNC(`timestamp_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP)
46+
) - 1,
47+
604800000000
48+
)
49+
) AS INT64) + 1
50+
END AS `bfcol_4`,
51+
CASE
52+
WHEN (
53+
EXTRACT(YEAR FROM `datetime_col`) * 4 + EXTRACT(QUARTER FROM `datetime_col`) - 1
54+
) = (
55+
EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1
56+
)
57+
THEN 0
58+
ELSE CAST(FLOOR(
59+
IEEE_DIVIDE(
60+
(
61+
EXTRACT(YEAR FROM `datetime_col`) * 4 + EXTRACT(QUARTER FROM `datetime_col`) - 1
62+
) - (
63+
EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1
64+
) - 1,
65+
1
66+
)
67+
) AS INT64) + 1
68+
END AS `bfcol_5`,
69+
CASE
70+
WHEN EXTRACT(YEAR FROM `datetime_col`) = EXTRACT(YEAR FROM `timestamp_col`)
71+
THEN 0
72+
ELSE CAST(FLOOR(
73+
IEEE_DIVIDE(EXTRACT(YEAR FROM `datetime_col`) - EXTRACT(YEAR FROM `timestamp_col`) - 1, 1)
74+
) AS INT64) + 1
75+
END AS `bfcol_6`
76+
FROM `bfcte_0`
77+
)
78+
SELECT
79+
`bfcol_2` AS `fixed_freq`,
80+
`bfcol_3` AS `non_fixed_freq_monthly`,
81+
`bfcol_4` AS `non_fixed_freq_weekly`,
82+
`bfcol_5` AS `non_fixed_freq_quarterly`,
83+
`bfcol_6` AS `non_fixed_freq_yearly`
84+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,33 @@ def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot):
5757
snapshot.assert_match(sql, "out.sql")
5858

5959

60+
def test_datetime_to_integer_label(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for DatetimeToIntegerLabelOp.

    Covers one fixed frequency (daily) and four non-fixed frequencies
    (monthly, weekly, quarterly, yearly), all with origin="start" and
    closed="left".
    """
    offsets = pd.tseries.offsets
    # Output column label -> frequency under test. The insertion order fixes
    # the column order of the generated SQL.
    freq_by_label = {
        "fixed_freq": offsets.Day(),
        "non_fixed_freq_monthly": offsets.MonthEnd(),
        "non_fixed_freq_weekly": offsets.Week(weekday=6),
        "non_fixed_freq_quarterly": offsets.QuarterEnd(startingMonth=12),
        "non_fixed_freq_yearly": offsets.YearEnd(month=12),
    }
    exprs_by_label = {
        label: ops.DatetimeToIntegerLabelOp(
            freq=freq, origin="start", closed="left"
        ).as_expr("datetime_col", "timestamp_col")
        for label, freq in freq_by_label.items()
    }

    bf_df = scalar_types_df[["datetime_col", "timestamp_col"]]
    sql = utils._apply_ops_to_sql(
        bf_df, list(exprs_by_label.values()), list(exprs_by_label.keys())
    )
    snapshot.assert_match(sql, "out.sql")
85+
86+
6087
def test_floor_dt(scalar_types_df: bpd.DataFrame, snapshot):
6188
col_names = ["datetime_col", "timestamp_col", "date_col"]
6289
bf_df = scalar_types_df[col_names]

0 commit comments

Comments
 (0)