diff --git a/bigframes/core/compile/sqlglot/expressions/datetime_ops.py b/bigframes/core/compile/sqlglot/expressions/datetime_ops.py index e20d2da567..9717f0fb11 100644 --- a/bigframes/core/compile/sqlglot/expressions/datetime_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/datetime_ops.py @@ -19,6 +19,7 @@ from bigframes import dtypes from bigframes import operations as ops from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS +from bigframes.core.compile.sqlglot import sqlglot_types from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler @@ -26,28 +27,6 @@ register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op -def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression: - if origin == "epoch": - return sge.convert(0) - elif origin == "start_day": - return sge.func( - "UNIX_MICROS", - sge.Cast( - this=sge.Cast( - this=y.expr, to=sge.DataType(this=sge.DataType.Type.DATE) - ), - to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ), - ), - ) - elif origin == "start": - return sge.func( - "UNIX_MICROS", - sge.Cast(this=y.expr, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)), - ) - else: - raise ValueError(f"Origin {origin} not supported") - - @register_binary_op(ops.DatetimeToIntegerLabelOp, pass_op=True) def datetime_to_integer_label_op( x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp @@ -317,6 +296,20 @@ def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression: return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq)) +def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression: + if origin == "epoch": + return sge.convert(0) + elif origin == "start_day": + return sge.func( + "UNIX_MICROS", + sge.Cast(this=sge.Cast(this=y.expr, to="DATE"), to="TIMESTAMP"), + ) + elif origin == "start": + return sge.func("UNIX_MICROS", sge.Cast(this=y.expr, to="TIMESTAMP")) + else: + raise ValueError(f"Origin {origin} not supported") + + @register_unary_op(ops.hour_op) def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="HOUR"), expression=expr.expr) @@ -436,3 +429,221 @@ def _(expr: TypedExpr, op: ops.UnixSeconds) -> sge.Expression: @register_unary_op(ops.year_op) def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr) + + +@register_binary_op(ops.IntegerLabelToDatetimeOp, pass_op=True) +def integer_label_to_datetime_op( + x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp +) -> sge.Expression: + # Determine if the frequency is fixed by checking if 'op.freq.nanos' is defined. + try: + return _integer_label_to_datetime_op_fixed_frequency(x, y, op) + except ValueError: + return _integer_label_to_datetime_op_non_fixed_frequency(x, y, op) + + +def _integer_label_to_datetime_op_fixed_frequency( + x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp +) -> sge.Expression: + """ + This function handles fixed frequency conversions where the unit can range + from microseconds (us) to days. + """ + us = op.freq.nanos / 1000 + first = _calculate_resample_first(y, op.origin) # type: ignore + x_label = sge.Cast( + this=sge.func( + "TIMESTAMP_MICROS", + sge.Cast( + this=sge.Add( + this=sge.Mul( + this=sge.Cast(this=x.expr, to="BIGNUMERIC"), + expression=sge.convert(int(us)), + ), + expression=sge.Cast(this=first, to="BIGNUMERIC"), + ), + to="INT64", + ), + ), + to=sqlglot_types.from_bigframes_dtype(y.dtype), + ) + return x_label + + +def _integer_label_to_datetime_op_non_fixed_frequency( + x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp +) -> sge.Expression: + """ + This function handles non-fixed frequency conversions for units ranging + from weeks to years. + """ + rule_code = op.freq.rule_code + n = op.freq.n + if rule_code == "W-SUN": # Weekly + us = n * 7 * 24 * 60 * 60 * 1000000 + first = sge.func( + "UNIX_MICROS", + sge.Add( + this=sge.TimestampTrunc( + this=sge.Cast(this=y.expr, to="TIMESTAMP"), + unit=sge.Var(this="WEEK(MONDAY)"), + ), + expression=sge.Interval( + this=sge.convert(6), unit=sge.Identifier(this="DAY") + ), + ), + ) + x_label = sge.Cast( + this=sge.func( + "TIMESTAMP_MICROS", + sge.Cast( + this=sge.Add( + this=sge.Mul( + this=sge.Cast(this=x.expr, to="BIGNUMERIC"), + expression=sge.convert(us), + ), + expression=sge.Cast(this=first, to="BIGNUMERIC"), + ), + to="INT64", + ), + ), + to=sqlglot_types.from_bigframes_dtype(y.dtype), + ) + elif rule_code in ("ME", "M"): # Monthly + one = sge.convert(1) + twelve = sge.convert(12) + first = sge.Sub( # type: ignore + this=sge.Add( + this=sge.Mul( + this=sge.Extract(this="YEAR", expression=y.expr), + expression=twelve, + ), + expression=sge.Extract(this="MONTH", expression=y.expr), + ), + expression=one, + ) + x_val = sge.Add( + this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first + ) + year = sge.Cast( + this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, twelve)), + to="INT64", + ) + month = sge.Add(this=sge.Mod(this=x_val, expression=twelve), expression=one) + next_year = sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=month, expression=twelve), + true=sge.Add(this=year, expression=one), + ) + ], + default=year, + ) + next_month = sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=month, expression=twelve), + true=one, + ) + ], + default=sge.Add(this=month, expression=one), + ) + next_month_date = sge.func( + "TIMESTAMP", + sge.Anonymous( + this="DATETIME", + expressions=[ + next_year, + next_month, + one, + sge.convert(0), + sge.convert(0), + sge.convert(0), + ], + ), + ) + x_label = sge.Sub( # type: ignore + this=next_month_date, expression=sge.Interval(this=one, unit="DAY") + ) + elif rule_code in ("QE-DEC", "Q-DEC"): # Quarterly + one = sge.convert(1) + three = sge.convert(3) + four = sge.convert(4) + twelve = sge.convert(12) + first = sge.Sub( # type: ignore + this=sge.Add( + this=sge.Mul( + this=sge.Extract(this="YEAR", expression=y.expr), + expression=four, + ), + expression=sge.Extract(this="QUARTER", expression=y.expr), + ), + expression=one, + ) + x_val = sge.Add( + this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first + ) + year = sge.Cast( + this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, four)), + to="INT64", + ) + month = sge.Mul( # type: ignore + this=sge.Paren( + this=sge.Add(this=sge.Mod(this=x_val, expression=four), expression=one) + ), + expression=three, + ) + next_year = sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=month, expression=twelve), + true=sge.Add(this=year, expression=one), + ) + ], + default=year, + ) + next_month = sge.Case( + ifs=[sge.If(this=sge.EQ(this=month, expression=twelve), true=one)], + default=sge.Add(this=month, expression=one), + ) + next_month_date = sge.Anonymous( + this="DATETIME", + expressions=[ + next_year, + next_month, + one, + sge.convert(0), + sge.convert(0), + sge.convert(0), + ], + ) + x_label = sge.Sub( # type: ignore + this=next_month_date, expression=sge.Interval(this=one, unit="DAY") + ) + elif rule_code in ("YE-DEC", "A-DEC", "Y-DEC"): # Yearly + one = sge.convert(1) + first = sge.Extract(this="YEAR", expression=y.expr) + x_val = sge.Add( + this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first + ) + next_year = sge.Add(this=x_val, expression=one) # type: ignore + next_month_date = sge.func( + "TIMESTAMP", + sge.Anonymous( + this="DATETIME", + expressions=[ + next_year, + one, + one, + sge.convert(0), + sge.convert(0), + sge.convert(0), + ], + ), + ) + x_label = sge.Sub( # type: ignore + this=next_month_date, expression=sge.Interval(this=one, unit="DAY") + ) + else: + raise ValueError(rule_code) + return sge.Cast(this=x_label, to=sqlglot_types.from_bigframes_dtype(y.dtype)) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime/out.sql new file mode 100644 index 0000000000..2a1bd0e2e2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime/out.sql @@ -0,0 +1,58 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex`, + `timestamp_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(TIMESTAMP_MICROS( + CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64) + ) AS TIMESTAMP) AS `bfcol_2`, + CAST(DATETIME( + CASE + WHEN ( + MOD( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + 1 + ) * 3 = 12 + THEN CAST(FLOOR( + IEEE_DIVIDE( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + ) AS INT64) + 1 + ELSE CAST(FLOOR( + IEEE_DIVIDE( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + ) AS INT64) + END, + CASE + WHEN ( + MOD( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + 1 + ) * 3 = 12 + THEN 1 + ELSE ( + MOD( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + 1 + ) * 3 + 1 + END, + 1, + 0, + 0, + 0 + ) - INTERVAL 1 DAY AS TIMESTAMP) AS `bfcol_3` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `fixed_freq`, + `bfcol_3` AS `non_fixed_freq` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_fixed/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_fixed/out.sql new file mode 100644 index 0000000000..8a759e85f9 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_fixed/out.sql @@ -0,0 +1,16 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex`, + `timestamp_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(TIMESTAMP_MICROS( + CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64) + ) AS TIMESTAMP) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `fixed_freq` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_month/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_month/out.sql new file mode 100644 index 0000000000..a9e64fead6 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_month/out.sql @@ -0,0 +1,50 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex`, + `timestamp_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(TIMESTAMP( + DATETIME( + CASE + WHEN MOD( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 12 + EXTRACT(MONTH FROM `timestamp_col`) - 1, + 12 + ) + 1 = 12 + THEN CAST(FLOOR( + IEEE_DIVIDE( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 12 + EXTRACT(MONTH FROM `timestamp_col`) - 1, + 12 + ) + ) AS INT64) + 1 + ELSE CAST(FLOOR( + IEEE_DIVIDE( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 12 + EXTRACT(MONTH FROM `timestamp_col`) - 1, + 12 + ) + ) AS INT64) + END, + CASE + WHEN MOD( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 12 + EXTRACT(MONTH FROM `timestamp_col`) - 1, + 12 + ) + 1 = 12 + THEN 1 + ELSE MOD( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 12 + EXTRACT(MONTH FROM `timestamp_col`) - 1, + 12 + ) + 1 + 1 + END, + 1, + 0, + 0, + 0 + ) + ) - INTERVAL 1 DAY AS TIMESTAMP) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `non_fixed_freq_monthly` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_quarter/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_quarter/out.sql new file mode 100644 index 0000000000..58064855a9 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_quarter/out.sql @@ -0,0 +1,54 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex`, + `timestamp_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(DATETIME( + CASE + WHEN ( + MOD( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + 1 + ) * 3 = 12 + THEN CAST(FLOOR( + IEEE_DIVIDE( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + ) AS INT64) + 1 + ELSE CAST(FLOOR( + IEEE_DIVIDE( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + ) AS INT64) + END, + CASE + WHEN ( + MOD( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + 1 + ) * 3 = 12 + THEN 1 + ELSE ( + MOD( + `rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1, + 4 + ) + 1 + ) * 3 + 1 + END, + 1, + 0, + 0, + 0 + ) - INTERVAL 1 DAY AS TIMESTAMP) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `non_fixed_freq` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_week/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_week/out.sql new file mode 100644 index 0000000000..1b1e2a119a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_week/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex`, + `timestamp_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(CAST(TIMESTAMP_MICROS( + CAST(CAST(`rowindex` AS BIGNUMERIC) * 604800000000 + CAST(UNIX_MICROS( + TIMESTAMP_TRUNC(CAST(`timestamp_col` AS TIMESTAMP), WEEK(MONDAY)) + INTERVAL 6 DAY + ) AS BIGNUMERIC) AS INT64) + ) AS TIMESTAMP) AS TIMESTAMP) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `non_fixed_freq_weekly` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_year/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_year/out.sql new file mode 100644 index 0000000000..ab77a9d190 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_integer_label_to_datetime_year/out.sql @@ -0,0 +1,14 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex`, + `timestamp_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(TIMESTAMP(DATETIME(`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) + 1, 1, 1, 0, 0, 0)) - INTERVAL 1 DAY AS TIMESTAMP) AS `bfcol_2` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `non_fixed_freq_yearly` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py index c4acb37e51..95156748e9 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py @@ -293,3 +293,74 @@ def test_sub_timedelta(scalar_types_df: bpd.DataFrame, snapshot): bf_df["timedelta_sub_timedelta"] = bf_df["duration_col"] - bf_df["duration_col"] snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_integer_label_to_datetime_fixed(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["rowindex", "timestamp_col"] + bf_df = scalar_types_df[col_names] + ops_map = { + "fixed_freq": ops.IntegerLabelToDatetimeOp( + freq=pd.tseries.offsets.Day(), origin="start", label="left" # type: ignore + ).as_expr("rowindex", "timestamp_col"), + } + + sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_integer_label_to_datetime_week(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["rowindex", "timestamp_col"] + bf_df = scalar_types_df[col_names] + ops_map = { + "non_fixed_freq_weekly": ops.IntegerLabelToDatetimeOp( + freq=pd.tseries.offsets.Week(weekday=6), origin="start", label="left" # type: ignore + ).as_expr("rowindex", "timestamp_col"), + } + + sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_integer_label_to_datetime_month(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["rowindex", "timestamp_col"] + bf_df = scalar_types_df[col_names] + ops_map = { + "non_fixed_freq_monthly": ops.IntegerLabelToDatetimeOp( + freq=pd.tseries.offsets.MonthEnd(), # type: ignore + origin="start", + label="left", + ).as_expr("rowindex", "timestamp_col"), + } + + sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_integer_label_to_datetime_quarter(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["rowindex", "timestamp_col"] + bf_df = scalar_types_df[col_names] + ops_map = { + "non_fixed_freq": ops.IntegerLabelToDatetimeOp( + freq=pd.tseries.offsets.QuarterEnd(startingMonth=12), # type: ignore + origin="start", + label="left", + ).as_expr("rowindex", "timestamp_col"), + } + + sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_integer_label_to_datetime_year(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["rowindex", "timestamp_col"] + bf_df = scalar_types_df[col_names] + ops_map = { + "non_fixed_freq_yearly": ops.IntegerLabelToDatetimeOp( + freq=pd.tseries.offsets.YearEnd(month=12), # type: ignore + origin="start", + label="left", + ).as_expr("rowindex", "timestamp_col"), + } + + sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql")