1+
12from typing import Union
23import collections
34
@@ -27,15 +28,15 @@ def __op_expr__(self, op, other, *, inline=True, method=False):
2728 if not isinstance (op , str ):
2829 raise TypeError ("op is supposed to be a string" )
2930 if not isinstance (other , Term ):
30- other = Value (other )
31+ other = _enc_value (other )
3132 return Expression (op , (self , other ), inline = inline , method = method )
3233
3334 def __rop_expr__ (self , op , other ):
3435 """reversed binary expression"""
3536 if not isinstance (op , str ):
3637 raise TypeError ("op is supposed to be a string" )
3738 if not isinstance (other , Term ):
38- other = Value (other )
39+ other = _enc_value (other )
3940 return Expression (op , (other , self ), inline = True )
4041
4142 def __uop_expr__ (self , op , * , params = None ):
@@ -49,9 +50,9 @@ def __triop_expr__(self, op, x, y, inline=False, method=False):
4950 if not isinstance (op , str ):
5051 raise TypeError ("op is supposed to be a string" )
5152 if not isinstance (x , Term ):
52- x = Value (x )
53+ x = _enc_value (x )
5354 if not isinstance (y , Term ):
54- y = Value (y )
55+ y = _enc_value (y )
5556 return Expression (op , (self , x , y ), inline = inline , method = method )
5657
5758 # tree re-write
@@ -690,6 +691,48 @@ def to_python(self, want_inline_parens=False):
690691 return self .value .__repr__ ()
691692
692693
694+ class FnTerm (Term ):
695+ def __init__ (self , value ):
696+ if not callable (value ):
697+ raise TypeError ("value type must be callable" )
698+ self .value = value
699+ Term .__init__ (self )
700+
701+ def replace_view (self , view ):
702+ return self
703+
704+ def to_python (self , want_inline_parens = False ):
705+ return self .value .__name__
706+
707+
708+ class ListTerm (Term ):
709+ def __init__ (self , value ):
710+ if not isinstance (value , list ):
711+ raise TypeError ("value type must be a list" )
712+ self .value = value
713+ Term .__init__ (self )
714+
715+ def replace_view (self , view ):
716+ return self
717+
718+ def to_python (self , want_inline_parens = False ):
719+ return self .value .__name__
720+
721+ def get_column_names (self , columns_seen ):
722+ for ti in self .value :
723+ ti .get_column_names (ti )
724+
725+
726+ def _enc_value (value ):
727+ if isinstance (value , Term ):
728+ return value
729+ if callable (value ):
730+ return FnTerm (value )
731+ if isinstance (value , list ):
732+ return ListTerm (value )
733+ return Value (value )
734+
735+
693736class ColumnReference (Term ):
694737 """class to represent referring to a column"""
695738
@@ -727,10 +770,10 @@ def get_column_names(self, columns_seen):
727770
728771def connected_components (expr ):
729772 return ("@connected_components("
730- + expr .args [0 ].to_pandas ()
731- + ", "
732- + expr .args [1 ].to_pandas ()
733- + ")" )
773+ + expr .args [0 ].to_pandas ()
774+ + ", "
775+ + expr .args [1 ].to_pandas ()
776+ + ")" )
734777
735778
736779pd_formatters = {
@@ -756,45 +799,22 @@ def connected_components(expr):
756799 "connected_components" : connected_components ,
757800 "partitioned_eval" : lambda expr : (
758801 "@partitioned_eval("
759- # expr.args[0] is a function
760- + '@' + expr .args [0 ].__name__
802+ # expr.args[0] is a FnTerm
803+ + '@' + expr .args [0 ].to_pandas ()
761804 + ", "
762- # expr.args[1] is a list of args
763- + '[' + ', ' .join ([ei .to_pandas () for ei in expr .args [1 ]]) + ']'
805+ # expr.args[1] is a ListTerm
806+ + '[' + ', ' .join ([ei .to_pandas () for ei in expr .args [1 ]. value ]) + ']'
764807 + ", "
765- # expr.args[2] is a list of args
766- + '[' + ', ' .join ([ei .to_pandas () for ei in expr .args [2 ]]) + ']'
808+ # expr.args[2] is a ListTerm
809+ + '[' + ', ' .join ([ei .to_pandas () for ei in expr .args [2 ]. value ]) + ']'
767810 + ")"
768811 ),
769-
770812}
771813
772814
773815r_formatters = {"neg" : lambda expr : "-" + expr .args [0 ].to_R (want_inline_parens = True )}
774816
775817
776- # obj may not be of type Expression
777- def _get_column_names (obj , columns_seen ):
778- if isinstance (obj , Term ):
779- # back to object methods path
780- obj .get_column_names (columns_seen )
781- return
782- if isinstance (obj , list ):
783- for b in obj :
784- _get_column_names (b , columns_seen )
785- return
786-
787-
788- # obj may not be of type Expression
789- def _to_python (obj , * , want_inline_parens = False ):
790- if callable (obj ):
791- return obj .__name__
792- if isinstance (obj , Term ):
793- # back to object methods path
794- return obj .to_python (want_inline_parens = want_inline_parens )
795- return str (obj )
796-
797-
798818class Expression (Term ):
799819 def __init__ (self , op , args , * , params = None , inline = False , method = False ):
800820 if not isinstance (op , str ):
@@ -805,8 +825,7 @@ def __init__(self, op, args, *, params=None, inline=False, method=False):
805825 if len (args ) != 2 :
806826 raise ValueError ("must have two arguments if inline is True" )
807827 self .op = op
808- # TODO: deal with lists, functions and values here (test through test_cc.py)
809- self .args = args
828+ self .args = [_enc_value (ai ) for ai in args ]
810829 self .params = params
811830 self .inline = inline
812831 self .method = method
@@ -820,12 +839,12 @@ def replace_view(self, view):
820839
821840 def get_column_names (self , columns_seen ):
822841 for a in self .args :
823- _get_column_names ( a , columns_seen )
842+ a . get_column_names ( columns_seen )
824843
825844 def to_python (self , * , want_inline_parens = False ):
826845 if self .op in py_formatters .keys ():
827846 return py_formatters [self .op ](self )
828- subs = [_to_python ( ai , want_inline_parens = True ) for ai in self .args ]
847+ subs = [ai . to_python ( want_inline_parens = True ) for ai in self .args ]
829848 if len (subs ) <= 0 :
830849 return "_" + self .op + "()"
831850 if len (subs ) == 1 :
@@ -897,7 +916,7 @@ def _parse_by_eval(source_str, *, data_def, outter_environemnt=None):
897916 source_str , outter_environemnt , data_def
898917 ) # eval is eval(source, globals, locals)- so mp is first
899918 if not isinstance (v , Term ):
900- v = Value (v )
919+ v = _enc_value (v )
901920 v .source_string = source_str
902921 return v
903922
0 commit comments