1616// under the License.
1717
1818use datafusion_common:: DataFusionError ;
19+ use datafusion_expr:: expr:: { AggregateFunction , AggregateUDF , Alias } ;
1920use datafusion_expr:: logical_plan:: Aggregate ;
21+ use datafusion_expr:: Expr ;
2022use pyo3:: prelude:: * ;
2123use std:: fmt:: { self , Display , Formatter } ;
2224
2325use super :: logical_node:: LogicalNode ;
2426use crate :: common:: df_schema:: PyDFSchema ;
27+ use crate :: errors:: py_type_err;
2528use crate :: expr:: PyExpr ;
2629use crate :: sql:: logical:: PyLogicalPlan ;
2730
@@ -84,6 +87,24 @@ impl PyAggregate {
8487 . collect ( ) )
8588 }
8689
90+ /// Returns the inner Aggregate Expr(s)
91+ pub fn agg_expressions ( & self ) -> PyResult < Vec < PyExpr > > {
92+ Ok ( self
93+ . aggregate
94+ . aggr_expr
95+ . iter ( )
96+ . map ( |e| PyExpr :: from ( e. clone ( ) ) )
97+ . collect ( ) )
98+ }
99+
100+ pub fn agg_func_name ( & self , expr : PyExpr ) -> PyResult < String > {
101+ Self :: _agg_func_name ( & expr. expr )
102+ }
103+
104+ pub fn aggregation_arguments ( & self , expr : PyExpr ) -> PyResult < Vec < PyExpr > > {
105+ self . _aggregation_arguments ( & expr. expr )
106+ }
107+
87108 // Retrieves the input `LogicalPlan` to this `Aggregate` node
88109 fn input ( & self ) -> PyResult < Vec < PyLogicalPlan > > {
89110 Ok ( Self :: inputs ( self ) )
@@ -99,6 +120,34 @@ impl PyAggregate {
99120 }
100121}
101122
123+ impl PyAggregate {
124+ #[ allow( clippy:: only_used_in_recursion) ]
125+ fn _aggregation_arguments ( & self , expr : & Expr ) -> PyResult < Vec < PyExpr > > {
126+ match expr {
127+ // TODO: This Alias logic seems to be returning some strange results that we should investigate
128+ Expr :: Alias ( Alias { expr, .. } ) => self . _aggregation_arguments ( expr. as_ref ( ) ) ,
129+ Expr :: AggregateFunction ( AggregateFunction { fun : _, args, .. } )
130+ | Expr :: AggregateUDF ( AggregateUDF { fun : _, args, .. } ) => {
131+ Ok ( args. iter ( ) . map ( |e| PyExpr :: from ( e. clone ( ) ) ) . collect ( ) )
132+ }
133+ _ => Err ( py_type_err (
134+ "Encountered a non Aggregate type in aggregation_arguments" ,
135+ ) ) ,
136+ }
137+ }
138+
139+ fn _agg_func_name ( expr : & Expr ) -> PyResult < String > {
140+ match expr {
141+ Expr :: Alias ( Alias { expr, .. } ) => Self :: _agg_func_name ( expr. as_ref ( ) ) ,
142+ Expr :: AggregateFunction ( AggregateFunction { fun, .. } ) => Ok ( fun. to_string ( ) ) ,
143+ Expr :: AggregateUDF ( AggregateUDF { fun, .. } ) => Ok ( fun. name . clone ( ) ) ,
144+ _ => Err ( py_type_err (
145+ "Encountered a non Aggregate type in agg_func_name" ,
146+ ) ) ,
147+ }
148+ }
149+ }
150+
102151impl LogicalNode for PyAggregate {
103152 fn inputs ( & self ) -> Vec < PyLogicalPlan > {
104153 vec ! [ PyLogicalPlan :: from( ( * self . aggregate. input) . clone( ) ) ]
0 commit comments