|
19 | 19 | from bigframes import dtypes |
20 | 20 | from bigframes import operations as ops |
21 | 21 | from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS |
| 22 | +from bigframes.core.compile.sqlglot import sqlglot_types |
22 | 23 | from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr |
23 | 24 | import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler |
24 | 25 |
|
25 | 26 | register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op |
26 | 27 | register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op |
27 | 28 |
|
28 | 29 |
|
29 | | -def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression: |
30 | | - if origin == "epoch": |
31 | | - return sge.convert(0) |
32 | | - elif origin == "start_day": |
33 | | - return sge.func( |
34 | | - "UNIX_MICROS", |
35 | | - sge.Cast( |
36 | | - this=sge.Cast( |
37 | | - this=y.expr, to=sge.DataType(this=sge.DataType.Type.DATE) |
38 | | - ), |
39 | | - to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ), |
40 | | - ), |
41 | | - ) |
42 | | - elif origin == "start": |
43 | | - return sge.func( |
44 | | - "UNIX_MICROS", |
45 | | - sge.Cast(this=y.expr, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)), |
46 | | - ) |
47 | | - else: |
48 | | - raise ValueError(f"Origin {origin} not supported") |
49 | | - |
50 | | - |
51 | 30 | @register_binary_op(ops.DatetimeToIntegerLabelOp, pass_op=True) |
52 | 31 | def datetime_to_integer_label_op( |
53 | 32 | x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp |
@@ -317,6 +296,20 @@ def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression: |
317 | 296 | return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq)) |
318 | 297 |
|
319 | 298 |
|
| 299 | +def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression: |
| 300 | + if origin == "epoch": |
| 301 | + return sge.convert(0) |
| 302 | + elif origin == "start_day": |
| 303 | + return sge.func( |
| 304 | + "UNIX_MICROS", |
| 305 | + sge.Cast(this=sge.Cast(this=y.expr, to="DATE"), to="TIMESTAMP"), |
| 306 | + ) |
| 307 | + elif origin == "start": |
| 308 | + return sge.func("UNIX_MICROS", sge.Cast(this=y.expr, to="TIMESTAMP")) |
| 309 | + else: |
| 310 | + raise ValueError(f"Origin {origin} not supported") |
| 311 | + |
| 312 | + |
320 | 313 | @register_unary_op(ops.hour_op) |
321 | 314 | def _(expr: TypedExpr) -> sge.Expression: |
322 | 315 | return sge.Extract(this=sge.Identifier(this="HOUR"), expression=expr.expr) |
@@ -436,3 +429,245 @@ def _(expr: TypedExpr, op: ops.UnixSeconds) -> sge.Expression: |
436 | 429 | @register_unary_op(ops.year_op) |
437 | 430 | def _(expr: TypedExpr) -> sge.Expression: |
438 | 431 | return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr) |
| 432 | + |
| 433 | + |
| 434 | +@register_binary_op(ops.IntegerLabelToDatetimeOp, pass_op=True) |
| 435 | +def integer_label_to_datetime_op( |
| 436 | + x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp |
| 437 | +) -> sge.Expression: |
| 438 | + # Determine if the frequency is fixed by checking if 'op.freq.nanos' is defined. |
| 439 | + try: |
| 440 | + return _integer_label_to_datetime_op_fixed_frequency(x, y, op) |
| 441 | + |
| 442 | + except ValueError: |
| 443 | + # Non-fixed frequency conversions for units ranging from weeks to years. |
| 444 | + rule_code = op.freq.rule_code |
| 445 | + |
| 446 | + if rule_code == "W-SUN": |
| 447 | + return _integer_label_to_datetime_op_weekly_freq(x, y, op) |
| 448 | + |
| 449 | + if rule_code in ("ME", "M"): |
| 450 | + return _integer_label_to_datetime_op_monthly_freq(x, y, op) |
| 451 | + |
| 452 | + if rule_code in ("QE-DEC", "Q-DEC"): |
| 453 | + return _integer_label_to_datetime_op_quarterly_freq(x, y, op) |
| 454 | + |
| 455 | + if rule_code in ("YE-DEC", "A-DEC", "Y-DEC"): |
| 456 | + return _integer_label_to_datetime_op_yearly_freq(x, y, op) |
| 457 | + |
| 458 | + # If the rule_code is not recognized, raise an error here. |
| 459 | + raise ValueError(f"Unsupported frequency rule code: {rule_code}") |
| 460 | + |
| 461 | + |
| 462 | +def _integer_label_to_datetime_op_fixed_frequency( |
| 463 | + x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp |
| 464 | +) -> sge.Expression: |
| 465 | + """ |
| 466 | + This function handles fixed frequency conversions where the unit can range |
| 467 | + from microseconds (us) to days. |
| 468 | + """ |
| 469 | + us = op.freq.nanos / 1000 |
| 470 | + first = _calculate_resample_first(y, op.origin) # type: ignore |
| 471 | + x_label = sge.Cast( |
| 472 | + this=sge.func( |
| 473 | + "TIMESTAMP_MICROS", |
| 474 | + sge.Cast( |
| 475 | + this=sge.Add( |
| 476 | + this=sge.Mul( |
| 477 | + this=sge.Cast(this=x.expr, to="BIGNUMERIC"), |
| 478 | + expression=sge.convert(int(us)), |
| 479 | + ), |
| 480 | + expression=sge.Cast(this=first, to="BIGNUMERIC"), |
| 481 | + ), |
| 482 | + to="INT64", |
| 483 | + ), |
| 484 | + ), |
| 485 | + to=sqlglot_types.from_bigframes_dtype(y.dtype), |
| 486 | + ) |
| 487 | + return x_label |
| 488 | + |
| 489 | + |
| 490 | +def _integer_label_to_datetime_op_weekly_freq( |
| 491 | + x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp |
| 492 | +) -> sge.Expression: |
| 493 | + n = op.freq.n |
| 494 | + # Calculate microseconds for the weekly interval. |
| 495 | + us = n * 7 * 24 * 60 * 60 * 1000000 |
| 496 | + first = sge.func( |
| 497 | + "UNIX_MICROS", |
| 498 | + sge.Add( |
| 499 | + this=sge.TimestampTrunc( |
| 500 | + this=sge.Cast(this=y.expr, to="TIMESTAMP"), |
| 501 | + unit=sge.Var(this="WEEK(MONDAY)"), |
| 502 | + ), |
| 503 | + expression=sge.Interval( |
| 504 | + this=sge.convert(6), unit=sge.Identifier(this="DAY") |
| 505 | + ), |
| 506 | + ), |
| 507 | + ) |
| 508 | + return sge.Cast( |
| 509 | + this=sge.func( |
| 510 | + "TIMESTAMP_MICROS", |
| 511 | + sge.Cast( |
| 512 | + this=sge.Add( |
| 513 | + this=sge.Mul( |
| 514 | + this=sge.Cast(this=x.expr, to="BIGNUMERIC"), |
| 515 | + expression=sge.convert(us), |
| 516 | + ), |
| 517 | + expression=sge.Cast(this=first, to="BIGNUMERIC"), |
| 518 | + ), |
| 519 | + to="INT64", |
| 520 | + ), |
| 521 | + ), |
| 522 | + to=sqlglot_types.from_bigframes_dtype(y.dtype), |
| 523 | + ) |
| 524 | + |
| 525 | + |
| 526 | +def _integer_label_to_datetime_op_monthly_freq( |
| 527 | + x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp |
| 528 | +) -> sge.Expression: |
| 529 | + n = op.freq.n |
| 530 | + one = sge.convert(1) |
| 531 | + twelve = sge.convert(12) |
| 532 | + first = sge.Sub( # type: ignore |
| 533 | + this=sge.Add( |
| 534 | + this=sge.Mul( |
| 535 | + this=sge.Extract(this="YEAR", expression=y.expr), |
| 536 | + expression=twelve, |
| 537 | + ), |
| 538 | + expression=sge.Extract(this="MONTH", expression=y.expr), |
| 539 | + ), |
| 540 | + expression=one, |
| 541 | + ) |
| 542 | + x_val = sge.Add( |
| 543 | + this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first |
| 544 | + ) |
| 545 | + year = sge.Cast( |
| 546 | + this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, twelve)), |
| 547 | + to="INT64", |
| 548 | + ) |
| 549 | + month = sge.Add(this=sge.Mod(this=x_val, expression=twelve), expression=one) |
| 550 | + |
| 551 | + next_year = sge.Case( |
| 552 | + ifs=[ |
| 553 | + sge.If( |
| 554 | + this=sge.EQ(this=month, expression=twelve), |
| 555 | + true=sge.Add(this=year, expression=one), |
| 556 | + ) |
| 557 | + ], |
| 558 | + default=year, |
| 559 | + ) |
| 560 | + next_month = sge.Case( |
| 561 | + ifs=[sge.If(this=sge.EQ(this=month, expression=twelve), true=one)], |
| 562 | + default=sge.Add(this=month, expression=one), |
| 563 | + ) |
| 564 | + next_month_date = sge.func( |
| 565 | + "TIMESTAMP", |
| 566 | + sge.Anonymous( |
| 567 | + this="DATETIME", |
| 568 | + expressions=[ |
| 569 | + next_year, |
| 570 | + next_month, |
| 571 | + one, |
| 572 | + sge.convert(0), |
| 573 | + sge.convert(0), |
| 574 | + sge.convert(0), |
| 575 | + ], |
| 576 | + ), |
| 577 | + ) |
| 578 | + x_label = sge.Sub( # type: ignore |
| 579 | + this=next_month_date, expression=sge.Interval(this=one, unit="DAY") |
| 580 | + ) |
| 581 | + return sge.Cast(this=x_label, to=sqlglot_types.from_bigframes_dtype(y.dtype)) |
| 582 | + |
| 583 | + |
| 584 | +def _integer_label_to_datetime_op_quarterly_freq( |
| 585 | + x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp |
| 586 | +) -> sge.Expression: |
| 587 | + n = op.freq.n |
| 588 | + one = sge.convert(1) |
| 589 | + three = sge.convert(3) |
| 590 | + four = sge.convert(4) |
| 591 | + twelve = sge.convert(12) |
| 592 | + first = sge.Sub( # type: ignore |
| 593 | + this=sge.Add( |
| 594 | + this=sge.Mul( |
| 595 | + this=sge.Extract(this="YEAR", expression=y.expr), |
| 596 | + expression=four, |
| 597 | + ), |
| 598 | + expression=sge.Extract(this="QUARTER", expression=y.expr), |
| 599 | + ), |
| 600 | + expression=one, |
| 601 | + ) |
| 602 | + x_val = sge.Add( |
| 603 | + this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first |
| 604 | + ) |
| 605 | + year = sge.Cast( |
| 606 | + this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, four)), |
| 607 | + to="INT64", |
| 608 | + ) |
| 609 | + month = sge.Mul( # type: ignore |
| 610 | + this=sge.Paren( |
| 611 | + this=sge.Add(this=sge.Mod(this=x_val, expression=four), expression=one) |
| 612 | + ), |
| 613 | + expression=three, |
| 614 | + ) |
| 615 | + |
| 616 | + next_year = sge.Case( |
| 617 | + ifs=[ |
| 618 | + sge.If( |
| 619 | + this=sge.EQ(this=month, expression=twelve), |
| 620 | + true=sge.Add(this=year, expression=one), |
| 621 | + ) |
| 622 | + ], |
| 623 | + default=year, |
| 624 | + ) |
| 625 | + next_month = sge.Case( |
| 626 | + ifs=[sge.If(this=sge.EQ(this=month, expression=twelve), true=one)], |
| 627 | + default=sge.Add(this=month, expression=one), |
| 628 | + ) |
| 629 | + next_month_date = sge.Anonymous( |
| 630 | + this="DATETIME", |
| 631 | + expressions=[ |
| 632 | + next_year, |
| 633 | + next_month, |
| 634 | + one, |
| 635 | + sge.convert(0), |
| 636 | + sge.convert(0), |
| 637 | + sge.convert(0), |
| 638 | + ], |
| 639 | + ) |
| 640 | + x_label = sge.Sub( # type: ignore |
| 641 | + this=next_month_date, expression=sge.Interval(this=one, unit="DAY") |
| 642 | + ) |
| 643 | + return sge.Cast(this=x_label, to=sqlglot_types.from_bigframes_dtype(y.dtype)) |
| 644 | + |
| 645 | + |
| 646 | +def _integer_label_to_datetime_op_yearly_freq( |
| 647 | + x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp |
| 648 | +) -> sge.Expression: |
| 649 | + n = op.freq.n |
| 650 | + one = sge.convert(1) |
| 651 | + first = sge.Extract(this="YEAR", expression=y.expr) |
| 652 | + x_val = sge.Add( |
| 653 | + this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first |
| 654 | + ) |
| 655 | + next_year = sge.Add(this=x_val, expression=one) # type: ignore |
| 656 | + next_month_date = sge.func( |
| 657 | + "TIMESTAMP", |
| 658 | + sge.Anonymous( |
| 659 | + this="DATETIME", |
| 660 | + expressions=[ |
| 661 | + next_year, |
| 662 | + one, |
| 663 | + one, |
| 664 | + sge.convert(0), |
| 665 | + sge.convert(0), |
| 666 | + sge.convert(0), |
| 667 | + ], |
| 668 | + ), |
| 669 | + ) |
| 670 | + x_label = sge.Sub( # type: ignore |
| 671 | + this=next_month_date, expression=sge.Interval(this=one, unit="DAY") |
| 672 | + ) |
| 673 | + return sge.Cast(this=x_label, to=sqlglot_types.from_bigframes_dtype(y.dtype)) |
0 commit comments