|
1 | 1 | use crate::queryplanner::rolling::RollingWindowAggregate; |
2 | | -use datafusion::arrow::array::{Array, AsArray}; |
3 | | -use datafusion::arrow::compute::{date_part, DatePart}; |
4 | | -use datafusion::common::tree_node::{ |
5 | | - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, |
6 | | -}; |
| 2 | +use datafusion::arrow::array::Array; |
| 3 | +use datafusion::arrow::datatypes::DataType; |
| 4 | +use datafusion::common::tree_node::Transformed; |
7 | 5 | use datafusion::common::{Column, DataFusionError, JoinType, ScalarValue, TableReference}; |
8 | 6 | use datafusion::functions::datetime::date_part::DatePartFunc; |
9 | 7 | use datafusion::functions::datetime::date_trunc::DateTruncFunc; |
10 | 8 | use datafusion::logical_expr::expr::{AggregateFunction, AggregateFunctionParams, Alias, ScalarFunction}; |
11 | 9 | use datafusion::logical_expr::{ |
12 | | - Aggregate, BinaryExpr, Cast, ColumnarValue, Expr, Extension, Join, LogicalPlan, Operator, |
13 | | - Projection, ScalarUDFImpl, SubqueryAlias, Union, Unnest, |
| 10 | + Aggregate, BinaryExpr, Cast, ColumnarValue, Expr, Extension, Join, LogicalPlan, Operator, Projection, ScalarFunctionArgs, ScalarUDFImpl, SubqueryAlias, Union, Unnest |
14 | 11 | }; |
15 | 12 | use datafusion::optimizer::optimizer::ApplyOrder; |
16 | 13 | use datafusion::optimizer::{OptimizerConfig, OptimizerRule}; |
17 | 14 | use itertools::Itertools; |
18 | | -use mockall::predicate::le; |
19 | | -use std::collections::HashMap; |
20 | 15 | use std::sync::Arc; |
21 | 16 |
|
22 | 17 | /// Rewrites following logical plan: |
@@ -194,6 +189,7 @@ impl RollingOptimizerRule { |
194 | 189 | _ => None, |
195 | 190 | }) |
196 | 191 | .collect::<Option<Vec<_>>>()?; |
| 192 | + |
197 | 193 | let RollingWindowJoinExtractorResult { |
198 | 194 | input, |
199 | 195 | dimension, |
@@ -261,6 +257,7 @@ impl RollingOptimizerRule { |
261 | 257 | }) => { |
262 | 258 | let left_series = Self::extract_series_projection(left) |
263 | 259 | .or_else(|| Self::extract_series_union(left))?; |
| 260 | + |
264 | 261 | let RollingWindowBoundsExtractorResult { |
265 | 262 | lower_bound, |
266 | 263 | upper_bound, |
@@ -596,10 +593,17 @@ impl RollingOptimizerRule { |
596 | 593 | LogicalPlan::Unnest(Unnest { |
597 | 594 | input, |
598 | 595 | exec_columns, |
| 596 | + schema, |
599 | 597 | .. |
600 | 598 | }) => { |
601 | 599 | let series_column = exec_columns.iter().next().cloned()?; |
602 | | - Self::extract_series_from_unnest(input, series_column) |
| 600 | + let series = Self::extract_series_from_unnest(input, series_column); |
| 601 | + let col = schema.field(0).name(); |
| 602 | + series.map(|mut series| { |
| 603 | + series.from_col = Column::from_name(col); |
| 604 | + series.to_col = series.from_col.clone(); |
| 605 | + series |
| 606 | + }) |
603 | 607 | } |
604 | 608 | _ => None, |
605 | 609 | } |
@@ -633,15 +637,17 @@ impl RollingOptimizerRule { |
633 | 637 | }); |
634 | 638 | } |
635 | 639 | Expr::Literal(ScalarValue::List(list)) => { |
| 640 | + |
636 | 641 | // TODO why does first element holds the array? Is it always the case? |
637 | 642 | let array = list.iter().next().as_ref().cloned()??; |
638 | 643 | let from = ScalarValue::try_from_array(&array, 0).ok()?; |
639 | 644 | let to = |
640 | 645 | ScalarValue::try_from_array(&array, array.len() - 1).ok()?; |
641 | 646 |
|
| 647 | + let index_1 = ScalarValue::try_from_array(&array, 1).ok()?; |
642 | 648 | let every = month_aware_sub( |
643 | 649 | &from, |
644 | | - &ScalarValue::try_from_array(&array, 1).ok()?, |
| 650 | + &index_1, |
645 | 651 | )?; |
646 | 652 |
|
647 | 653 | return Some(RollingWindowSeriesExtractorResult { |
@@ -700,58 +706,99 @@ pub fn month_aware_sub(from: &ScalarValue, to: &ScalarValue) -> Option<ScalarVal |
700 | 706 | | ScalarValue::TimestampMicrosecond(_, None) |
701 | 707 | | ScalarValue::TimestampNanosecond(_, None), |
702 | 708 | ) => { |
| 709 | + let from_type = from.data_type(); |
| 710 | + let to_type = to.data_type(); |
703 | 711 | // TODO lookup from registry? |
704 | 712 | let date_trunc = DateTruncFunc::new(); |
705 | | - let date_part = DatePartFunc::new(); |
706 | 713 | let from_trunc = date_trunc |
707 | | - .invoke(&[ |
708 | | - ColumnarValue::Scalar(ScalarValue::Utf8(Some("month".to_string()))), |
709 | | - ColumnarValue::Scalar(from.clone()), |
710 | | - ]) |
| 714 | + .invoke_with_args( |
| 715 | + ScalarFunctionArgs { |
| 716 | + args: vec![ |
| 717 | + ColumnarValue::Scalar(ScalarValue::Utf8(Some("month".to_string()))), |
| 718 | + ColumnarValue::Scalar(from.clone()), |
| 719 | + ], |
| 720 | + number_rows: 1, |
| 721 | + return_type: &from_type, |
| 722 | + }, |
| 723 | + ) |
711 | 724 | .ok()?; |
712 | 725 | let to_trunc = date_trunc |
713 | | - .invoke(&[ |
714 | | - ColumnarValue::Scalar(ScalarValue::Utf8(Some("month".to_string()))), |
715 | | - ColumnarValue::Scalar(to.clone()), |
716 | | - ]) |
| 726 | + .invoke_with_args( |
| 727 | + ScalarFunctionArgs { |
| 728 | + args: vec![ |
| 729 | + ColumnarValue::Scalar(ScalarValue::Utf8(Some("month".to_string()))), |
| 730 | + ColumnarValue::Scalar(to.clone()), |
| 731 | + ], |
| 732 | + number_rows: 1, |
| 733 | + return_type: &to_type, |
| 734 | + }, |
| 735 | + ) |
717 | 736 | .ok()?; |
718 | 737 | match (from_trunc, to_trunc) { |
719 | 738 | (ColumnarValue::Scalar(from_trunc), ColumnarValue::Scalar(to_trunc)) => { |
| 739 | + // TODO as with date_trunc above, lookup from registry? |
| 740 | + let date_part = DatePartFunc::new(); |
| 741 | + |
720 | 742 | if from.sub(from_trunc.clone()).ok() == to.sub(to_trunc.clone()).ok() { |
721 | 743 | let from_month = date_part |
722 | | - .invoke(&[ |
723 | | - ColumnarValue::Scalar(ScalarValue::Utf8(Some("month".to_string()))), |
724 | | - ColumnarValue::Scalar(from_trunc.clone()), |
725 | | - ]) |
| 744 | + .invoke_with_args( |
| 745 | + ScalarFunctionArgs { |
| 746 | + args: vec![ |
| 747 | + ColumnarValue::Scalar(ScalarValue::Utf8(Some("month".to_string()))), |
| 748 | + ColumnarValue::Scalar(from_trunc.clone()), |
| 749 | + ], |
| 750 | + number_rows: 1, |
| 751 | + return_type: &DataType::Int32, |
| 752 | + }, |
| 753 | + ) |
726 | 754 | .ok()?; |
727 | 755 | let from_year = date_part |
728 | | - .invoke(&[ |
729 | | - ColumnarValue::Scalar(ScalarValue::Utf8(Some("year".to_string()))), |
730 | | - ColumnarValue::Scalar(from_trunc.clone()), |
731 | | - ]) |
| 756 | + .invoke_with_args( |
| 757 | + ScalarFunctionArgs { |
| 758 | + args: vec![ |
| 759 | + ColumnarValue::Scalar(ScalarValue::Utf8(Some("year".to_string()))), |
| 760 | + ColumnarValue::Scalar(from_trunc.clone()), |
| 761 | + ], |
| 762 | + number_rows: 1, |
| 763 | + return_type: &DataType::Int32, |
| 764 | + }, |
| 765 | + ) |
732 | 766 | .ok()?; |
733 | 767 | let to_month = date_part |
734 | | - .invoke(&[ |
735 | | - ColumnarValue::Scalar(ScalarValue::Utf8(Some("month".to_string()))), |
736 | | - ColumnarValue::Scalar(to_trunc.clone()), |
737 | | - ]) |
| 768 | + .invoke_with_args( |
| 769 | + ScalarFunctionArgs { |
| 770 | + args: vec![ |
| 771 | + ColumnarValue::Scalar(ScalarValue::Utf8(Some("month".to_string()))), |
| 772 | + ColumnarValue::Scalar(to_trunc.clone()), |
| 773 | + ], |
| 774 | + number_rows: 1, |
| 775 | + return_type: &DataType::Int32, |
| 776 | + }, |
| 777 | + ) |
738 | 778 | .ok()?; |
739 | 779 | let to_year = date_part |
740 | | - .invoke(&[ |
741 | | - ColumnarValue::Scalar(ScalarValue::Utf8(Some("year".to_string()))), |
742 | | - ColumnarValue::Scalar(to_trunc.clone()), |
743 | | - ]) |
| 780 | + .invoke_with_args( |
| 781 | + ScalarFunctionArgs { |
| 782 | + args: vec![ |
| 783 | + ColumnarValue::Scalar(ScalarValue::Utf8(Some("year".to_string()))), |
| 784 | + ColumnarValue::Scalar(to_trunc.clone()), |
| 785 | + ], |
| 786 | + number_rows: 1, |
| 787 | + return_type: &DataType::Int32, |
| 788 | + }, |
| 789 | + ) |
744 | 790 | .ok()?; |
| 791 | + |
745 | 792 | match (from_month, from_year, to_month, to_year) { |
746 | 793 | ( |
747 | | - ColumnarValue::Scalar(ScalarValue::Float64(Some(from_month))), |
748 | | - ColumnarValue::Scalar(ScalarValue::Float64(Some(from_year))), |
749 | | - ColumnarValue::Scalar(ScalarValue::Float64(Some(to_month))), |
750 | | - ColumnarValue::Scalar(ScalarValue::Float64(Some(to_year))), |
| 794 | + ColumnarValue::Scalar(ScalarValue::Int32(Some(from_month))), |
| 795 | + ColumnarValue::Scalar(ScalarValue::Int32(Some(from_year))), |
| 796 | + ColumnarValue::Scalar(ScalarValue::Int32(Some(to_month))), |
| 797 | + ColumnarValue::Scalar(ScalarValue::Int32(Some(to_year))), |
751 | 798 | ) => { |
752 | 799 | return Some(ScalarValue::IntervalYearMonth(Some( |
753 | | - (to_year - from_year) as i32 * 12 |
754 | | - + (to_month - from_month) as i32, |
| 800 | + (to_year - from_year) * 12 |
| 801 | + + (to_month - from_month), |
755 | 802 | ))) |
756 | 803 | } |
757 | 804 | _ => {} |
|
0 commit comments