1313import data_algebra .env
1414from data_algebra .data_ops_types import *
1515
16-
1716_have_black = False
1817try :
1918 # noinspection PyUnresolvedReferences
2322except ImportError :
2423 pass
2524
26-
2725_have_sqlparse = False
2826try :
2927 # noinspection PyUnresolvedReferences
@@ -333,6 +331,24 @@ def extend(
333331 ):
334332 if (ops is None ) or (len (ops ) < 1 ):
335333 return self
334+ parsed_ops = data_algebra .expr_rep .parse_assignments_in_context (
335+ ops , self , parse_env = parse_env
336+ )
337+ new_cols_used_in_calc = set (data_algebra .expr_rep .get_columns_used (parsed_ops ))
338+ if partition_by is None :
339+ partition_by = []
340+ if order_by is None :
341+ order_by = []
342+ if reverse is None :
343+ reverse = []
344+ new_cols_produced_in_calc = set ([k for k in parsed_ops .keys ()])
345+ if (partition_by != 1 ) and (len (partition_by ) > 0 ):
346+ if len (new_cols_produced_in_calc .intersection (partition_by )) > 0 :
347+ raise ValueError ("must not change partition_by columns" )
348+ if len (new_cols_produced_in_calc .intersection (order_by )) > 0 :
349+ raise ValueError ("must not change partition_by columns" )
350+ if len (set (reverse ).difference (order_by )) > 0 :
351+ raise ValueError ("all columns in reverse must be in order_by" )
336352 if self .is_trivial_when_intermediate ():
337353 return self .sources [0 ].extend (
338354 ops ,
@@ -341,13 +357,53 @@ def extend(
341357 reverse = reverse ,
342358 parse_env = parse_env ,
343359 )
360+ if isinstance (self , ExtendNode ):
361+ compatible_partition = (partition_by == self .partition_by ) or (
362+ ((partition_by == 1 ) or (len (partition_by ) <= 0 ))
363+ and ((self .partition_by == 1 ) or (len (self .partition_by ) <= 0 ))
364+ )
365+ same_windowing = (
366+ data_algebra .expr_rep .implies_windowed (parsed_ops )
367+ == self .windowed_situation
368+ )
369+ if (
370+ compatible_partition
371+ and same_windowing
372+ and (order_by == self .order_by )
373+ and (reverse == self .reverse )
374+ and (
375+ len (new_cols_used_in_calc .intersection (self .cols_produced_in_calc ))
376+ == 0
377+ )
378+ and (
379+ len (
380+ new_cols_produced_in_calc .intersection (
381+ self .cols_produced_in_calc
382+ )
383+ )
384+ == 0
385+ )
386+ and (
387+ len (new_cols_produced_in_calc .intersection (self .cols_used_in_calc ))
388+ == 0
389+ )
390+ ):
391+ # merge the extends
392+ new_ops = self .ops .copy ()
393+ new_ops .update (parsed_ops )
394+ return ExtendNode (
395+ source = self .sources [0 ],
396+ parsed_ops = new_ops ,
397+ partition_by = partition_by ,
398+ order_by = order_by ,
399+ reverse = reverse ,
400+ )
344401 return ExtendNode (
345402 source = self ,
346- ops = ops ,
403+ parsed_ops = parsed_ops ,
347404 partition_by = partition_by ,
348405 order_by = order_by ,
349406 reverse = reverse ,
350- parse_env = parse_env ,
351407 )
352408
353409 def project (self , ops = None , * , group_by = None , parse_env = None ):
@@ -357,7 +413,10 @@ def project(self, ops=None, *, group_by=None, parse_env=None):
357413 raise ValueError ("must have ops or group_by" )
358414 if self .is_trivial_when_intermediate ():
359415 return self .sources [0 ].project (ops , group_by = group_by , parse_env = parse_env )
360- return ProjectNode (source = self , ops = ops , group_by = group_by , parse_env = parse_env )
416+ parsed_ops = data_algebra .expr_rep .parse_assignments_in_context (
417+ ops , self , parse_env = parse_env
418+ )
419+ return ProjectNode (source = self , parsed_ops = parsed_ops , group_by = group_by )
361420
362421 def natural_join (self , b , * , by = None , jointype = "INNER" ):
363422 if not isinstance (b , ViewRepresentation ):
@@ -793,31 +852,12 @@ def wrap(d, *, table_name="data_frame"):
793852
794853class ExtendNode (ViewRepresentation ):
795854 def __init__ (
796- self ,
797- source ,
798- ops ,
799- * ,
800- partition_by = None ,
801- order_by = None ,
802- reverse = None ,
803- parse_env = None
855+ self , * , source , parsed_ops , partition_by = None , order_by = None , reverse = None ,
804856 ):
805- windowed_situation = False
806- if ops is None :
807- ops = {}
808- ops = data_algebra .expr_rep .parse_assignments_in_context (
809- ops , source , parse_env = parse_env
810- )
811- if len (ops ) < 1 :
812- raise ValueError ("no ops" )
813- for (k , opk ) in ops .items (): # look for aggregation functions
814- if isinstance (opk , data_algebra .expr_rep .Expression ):
815- if (
816- opk .op
817- in data_algebra .expr_rep .fn_names_that_imply_windowed_situation
818- ):
819- windowed_situation = True
820- self .ops = ops
857+ windowed_situation = data_algebra .expr_rep .implies_windowed (parsed_ops )
858+ self .ops = parsed_ops
859+ self .cols_used_in_calc = data_algebra .expr_rep .get_columns_used (parsed_ops )
860+ self .cols_produced_in_calc = [k for k in parsed_ops .keys ()]
821861 if partition_by is None :
822862 partition_by = []
823863 if isinstance (partition_by , numbers .Number ):
@@ -843,13 +883,13 @@ def __init__(
843883 self .reverse = reverse
844884 column_names = source .column_names .copy ()
845885 consumed_cols = set ()
846- for (k , o ) in ops .items ():
886+ for (k , o ) in parsed_ops .items ():
847887 o .get_column_names (consumed_cols )
848888 unknown_cols = consumed_cols - source .column_set
849889 if len (unknown_cols ) > 0 :
850890 raise KeyError ("referred to unknown columns: " + str (unknown_cols ))
851891 known_cols = set (column_names )
852- for ci in ops .keys ():
892+ for ci in parsed_ops .keys ():
853893 if ci not in known_cols :
854894 column_names .append (ci )
855895 if len (partition_by ) != len (set (partition_by )):
@@ -867,14 +907,14 @@ def __init__(
867907 unknown = set (reverse ) - set (order_by )
868908 if len (unknown ) > 0 :
869909 raise ValueError ("reverse columns not in order_by: " + str (unknown ))
870- bad_overwrite = set (ops .keys ()).intersection (
910+ bad_overwrite = set (parsed_ops .keys ()).intersection (
871911 set (partition_by ).union (order_by , reverse )
872912 )
873913 if len (bad_overwrite ) > 0 :
874914 raise ValueError ("tried to change: " + str (bad_overwrite ))
875915 # check op arguments are very simple: all arguments are column names
876916 if windowed_situation :
877- for (k , opk ) in ops .items ():
917+ for (k , opk ) in parsed_ops .items ():
878918 if not isinstance (opk , data_algebra .expr_rep .Expression ):
879919 raise ValueError (
880920 "non-aggregated expression in windowed/partitoned extend: "
@@ -991,13 +1031,8 @@ def eval_implementation(self, *, data_map, eval_env, data_model):
9911031
9921032
9931033class ProjectNode (ViewRepresentation ):
994- def __init__ (self , source , ops = None , * , group_by = None , parse_env = None ):
995- if ops is None :
996- ops = {}
997- ops = data_algebra .expr_rep .parse_assignments_in_context (
998- ops , source , parse_env = parse_env
999- )
1000- self .ops = ops
1034+ def __init__ (self , * , source , parsed_ops , group_by = None ):
1035+ self .ops = parsed_ops
10011036 if group_by is None :
10021037 group_by = []
10031038 if isinstance (group_by , str ):
@@ -1007,13 +1042,13 @@ def __init__(self, source, ops=None, *, group_by=None, parse_env=None):
10071042 consumed_cols = set ()
10081043 for c in group_by :
10091044 consumed_cols .add (c )
1010- for (k , o ) in ops .items ():
1045+ for (k , o ) in parsed_ops .items ():
10111046 o .get_column_names (consumed_cols )
10121047 unknown_cols = consumed_cols - source .column_set
10131048 if len (unknown_cols ) > 0 :
10141049 raise KeyError ("referred to unknown columns: " + str (unknown_cols ))
10151050 known_cols = set (column_names )
1016- for ci in ops .keys ():
1051+ for ci in parsed_ops .keys ():
10171052 if ci not in known_cols :
10181053 column_names .append (ci )
10191054 if len (group_by ) != len (set (group_by )):
0 commit comments