diff --git a/bigframes/core/compile/sqlglot/expressions/datetime_ops.py b/bigframes/core/compile/sqlglot/expressions/datetime_ops.py index 0f1e9dadf3..78e17ae33b 100644 --- a/bigframes/core/compile/sqlglot/expressions/datetime_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/datetime_ops.py @@ -23,6 +23,272 @@ import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op +register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op + + +def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression: + if origin == "epoch": + return sge.convert(0) + elif origin == "start_day": + return sge.func( + "UNIX_MICROS", + sge.Cast( + this=sge.Cast( + this=y.expr, to=sge.DataType(this=sge.DataType.Type.DATE) + ), + to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ), + ), + ) + elif origin == "start": + return sge.func( + "UNIX_MICROS", + sge.Cast(this=y.expr, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)), + ) + else: + raise ValueError(f"Origin {origin} not supported") + + +@register_binary_op(ops.DatetimeToIntegerLabelOp, pass_op=True) +def datetime_to_integer_label_op( + x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp +) -> sge.Expression: + # Determine if the frequency is fixed by checking if 'op.freq.nanos' is defined. + try: + return _datetime_to_integer_label_fixed_frequency(x, y, op) + except ValueError: + return _datetime_to_integer_label_non_fixed_frequency(x, y, op) + + +def _datetime_to_integer_label_fixed_frequency( + x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp +) -> sge.Expression: + """ + This function handles fixed frequency conversions where the unit can range + from microseconds (us) to days. + """ + us = op.freq.nanos / 1000 + x_int = sge.func( + "UNIX_MICROS", + sge.Cast(this=x.expr, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)), + ) + first = _calculate_resample_first(y, op.origin) # type: ignore + x_int_label = sge.Cast( + this=sge.Floor( + this=sge.func( + "IEEE_DIVIDE", + sge.Sub(this=x_int, expression=first), + sge.convert(int(us)), + ) + ), + to=sge.DataType.build("INT64"), + ) + return x_int_label + + +def _datetime_to_integer_label_non_fixed_frequency( + x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp +) -> sge.Expression: + """ + This function handles non-fixed frequency conversions for units ranging + from weeks to years. + """ + rule_code = op.freq.rule_code + n = op.freq.n + if rule_code == "W-SUN": # Weekly + us = n * 7 * 24 * 60 * 60 * 1000000 + x_trunc = sge.TimestampTrunc(this=x.expr, unit=sge.Var(this="WEEK(MONDAY)")) + y_trunc = sge.TimestampTrunc(this=y.expr, unit=sge.Var(this="WEEK(MONDAY)")) + x_plus_6 = sge.Add( + this=x_trunc, + expression=sge.Interval( + this=sge.convert(6), unit=sge.Identifier(this="DAY") + ), + ) + y_plus_6 = sge.Add( + this=y_trunc, + expression=sge.Interval( + this=sge.convert(6), unit=sge.Identifier(this="DAY") + ), + ) + x_int = sge.func( + "UNIX_MICROS", + sge.Cast( + this=x_plus_6, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ) + ), + ) + first = sge.func( + "UNIX_MICROS", + sge.Cast( + this=y_plus_6, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ) + ), + ) + return sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=x_int, expression=first), + true=sge.convert(0), + ) + ], + default=sge.Add( + this=sge.Cast( + this=sge.Floor( + this=sge.func( + "IEEE_DIVIDE", + sge.Sub( + this=sge.Sub(this=x_int, expression=first), + expression=sge.convert(1), + ), + sge.convert(us), + ) + ), + to=sge.DataType.build("INT64"), + ), + expression=sge.convert(1), + ), + ) + elif rule_code == "ME": # Monthly + x_int = sge.Paren( # type: ignore + this=sge.Add( + this=sge.Mul( + this=sge.Extract( + this=sge.Identifier(this="YEAR"), expression=x.expr + ), + expression=sge.convert(12), + ), + expression=sge.Sub( + this=sge.Extract( + this=sge.Identifier(this="MONTH"), expression=x.expr + ), + expression=sge.convert(1), + ), + ) + ) + first = sge.Paren( # type: ignore + this=sge.Add( + this=sge.Mul( + this=sge.Extract( + this=sge.Identifier(this="YEAR"), expression=y.expr + ), + expression=sge.convert(12), + ), + expression=sge.Sub( + this=sge.Extract( + this=sge.Identifier(this="MONTH"), expression=y.expr + ), + expression=sge.convert(1), + ), + ) + ) + return sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=x_int, expression=first), + true=sge.convert(0), + ) + ], + default=sge.Add( + this=sge.Cast( + this=sge.Floor( + this=sge.func( + "IEEE_DIVIDE", + sge.Sub( + this=sge.Sub(this=x_int, expression=first), + expression=sge.convert(1), + ), + sge.convert(n), + ) + ), + to=sge.DataType.build("INT64"), + ), + expression=sge.convert(1), + ), + ) + elif rule_code == "QE-DEC": # Quarterly + x_int = sge.Paren( # type: ignore + this=sge.Add( + this=sge.Mul( + this=sge.Extract( + this=sge.Identifier(this="YEAR"), expression=x.expr + ), + expression=sge.convert(4), + ), + expression=sge.Sub( + this=sge.Extract( + this=sge.Identifier(this="QUARTER"), expression=x.expr + ), + expression=sge.convert(1), + ), + ) + ) + first = sge.Paren( # type: ignore + this=sge.Add( + this=sge.Mul( + this=sge.Extract( + this=sge.Identifier(this="YEAR"), expression=y.expr + ), + expression=sge.convert(4), + ), + expression=sge.Sub( + this=sge.Extract( + this=sge.Identifier(this="QUARTER"), expression=y.expr + ), + expression=sge.convert(1), + ), + ) + ) + return sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=x_int, expression=first), + true=sge.convert(0), + ) + ], + default=sge.Add( + this=sge.Cast( + this=sge.Floor( + this=sge.func( + "IEEE_DIVIDE", + sge.Sub( + this=sge.Sub(this=x_int, expression=first), + expression=sge.convert(1), + ), + sge.convert(n), + ) + ), + to=sge.DataType.build("INT64"), + ), + expression=sge.convert(1), + ), + ) + elif rule_code == "YE-DEC": # Yearly + x_int = sge.Extract(this=sge.Identifier(this="YEAR"), expression=x.expr) + first = sge.Extract(this=sge.Identifier(this="YEAR"), expression=y.expr) + return sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=x_int, expression=first), + true=sge.convert(0), + ) + ], + default=sge.Add( + this=sge.Cast( + this=sge.Floor( + this=sge.func( + "IEEE_DIVIDE", + sge.Sub( + this=sge.Sub(this=x_int, expression=first), + expression=sge.convert(1), + ), + sge.convert(n), + ) + ), + to=sge.DataType.build("INT64"), + ), + expression=sge.convert(1), + ), + ) + else: + raise ValueError(rule_code) @register_unary_op(ops.FloorDtOp, pass_op=True) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_datetime_to_integer_label/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_datetime_to_integer_label/out.sql new file mode 100644 index 0000000000..5260dd680a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_datetime_to_integer_label/out.sql @@ -0,0 +1,38 @@ +WITH `bfcte_0` AS ( + SELECT + `datetime_col`, + `timestamp_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(FLOOR( + IEEE_DIVIDE( + UNIX_MICROS(CAST(`datetime_col` AS TIMESTAMP)) - UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)), + 86400000000 + ) + ) AS INT64) AS `bfcol_2`, + CASE + WHEN UNIX_MICROS( + CAST(TIMESTAMP_TRUNC(`datetime_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP) + ) = UNIX_MICROS( + CAST(TIMESTAMP_TRUNC(`timestamp_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP) + ) + THEN 0 + ELSE CAST(FLOOR( + IEEE_DIVIDE( + UNIX_MICROS( + CAST(TIMESTAMP_TRUNC(`datetime_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP) + ) - UNIX_MICROS( + CAST(TIMESTAMP_TRUNC(`timestamp_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP) + ) - 1, + 604800000000 + ) + ) AS INT64) + 1 + END AS `bfcol_3` + FROM `bfcte_0` +) +SELECT + `bfcol_2` AS `fixed_freq`, + `bfcol_3` AS `non_fixed_freq_weekly` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py index 9d93b9019f..c4acb37e51 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py @@ -57,6 +57,22 @@ def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_datetime_to_integer_label(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["datetime_col", "timestamp_col"] + bf_df = scalar_types_df[col_names] + ops_map = { + "fixed_freq": ops.DatetimeToIntegerLabelOp( + freq=pd.tseries.offsets.Day(), origin="start", closed="left" # type: ignore + ).as_expr("datetime_col", "timestamp_col"), + "non_fixed_freq_weekly": ops.DatetimeToIntegerLabelOp( + freq=pd.tseries.offsets.Week(weekday=6), origin="start", closed="left" # type: ignore + ).as_expr("datetime_col", "timestamp_col"), + } + + sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + def test_floor_dt(scalar_types_df: bpd.DataFrame, snapshot): col_names = ["datetime_col", "timestamp_col", "date_col"] bf_df = scalar_types_df[col_names]