Skip to content

Commit b2861f2

Browse files
committed
add some extend folding
1 parent aaf7dae commit b2861f2

File tree

8 files changed

+266
-108
lines changed

8 files changed

+266
-108
lines changed

build/lib/data_algebra/data_ops.py

Lines changed: 77 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import data_algebra.env
1414
from data_algebra.data_ops_types import *
1515

16-
1716
_have_black = False
1817
try:
1918
# noinspection PyUnresolvedReferences
@@ -23,7 +22,6 @@
2322
except ImportError:
2423
pass
2524

26-
2725
_have_sqlparse = False
2826
try:
2927
# noinspection PyUnresolvedReferences
@@ -333,6 +331,24 @@ def extend(
333331
):
334332
if (ops is None) or (len(ops) < 1):
335333
return self
334+
parsed_ops = data_algebra.expr_rep.parse_assignments_in_context(
335+
ops, self, parse_env=parse_env
336+
)
337+
new_cols_used_in_calc = set(data_algebra.expr_rep.get_columns_used(parsed_ops))
338+
if partition_by is None:
339+
partition_by = []
340+
if order_by is None:
341+
order_by = []
342+
if reverse is None:
343+
reverse = []
344+
new_cols_produced_in_calc = set([k for k in parsed_ops.keys()])
345+
if (partition_by != 1) and (len(partition_by) > 0):
346+
if len(new_cols_produced_in_calc.intersection(partition_by)) > 0:
347+
raise ValueError("must not change partition_by columns")
348+
if len(new_cols_produced_in_calc.intersection(order_by)) > 0:
349+
raise ValueError("must not change partition_by columns")
350+
if len(set(reverse).difference(order_by)) > 0:
351+
raise ValueError("all columns in reverse must be in order_by")
336352
if self.is_trivial_when_intermediate():
337353
return self.sources[0].extend(
338354
ops,
@@ -341,13 +357,53 @@ def extend(
341357
reverse=reverse,
342358
parse_env=parse_env,
343359
)
360+
if isinstance(self, ExtendNode):
361+
compatible_partition = (partition_by == self.partition_by) or (
362+
((partition_by == 1) or (len(partition_by) <= 0))
363+
and ((self.partition_by == 1) or (len(self.partition_by) <= 0))
364+
)
365+
same_windowing = (
366+
data_algebra.expr_rep.implies_windowed(parsed_ops)
367+
== self.windowed_situation
368+
)
369+
if (
370+
compatible_partition
371+
and same_windowing
372+
and (order_by == self.order_by)
373+
and (reverse == self.reverse)
374+
and (
375+
len(new_cols_used_in_calc.intersection(self.cols_produced_in_calc))
376+
== 0
377+
)
378+
and (
379+
len(
380+
new_cols_produced_in_calc.intersection(
381+
self.cols_produced_in_calc
382+
)
383+
)
384+
== 0
385+
)
386+
and (
387+
len(new_cols_produced_in_calc.intersection(self.cols_used_in_calc))
388+
== 0
389+
)
390+
):
391+
# merge the extends
392+
new_ops = self.ops.copy()
393+
new_ops.update(parsed_ops)
394+
return ExtendNode(
395+
source=self.sources[0],
396+
parsed_ops=new_ops,
397+
partition_by=partition_by,
398+
order_by=order_by,
399+
reverse=reverse,
400+
)
344401
return ExtendNode(
345402
source=self,
346-
ops=ops,
403+
parsed_ops=parsed_ops,
347404
partition_by=partition_by,
348405
order_by=order_by,
349406
reverse=reverse,
350-
parse_env=parse_env,
351407
)
352408

353409
def project(self, ops=None, *, group_by=None, parse_env=None):
@@ -357,7 +413,10 @@ def project(self, ops=None, *, group_by=None, parse_env=None):
357413
raise ValueError("must have ops or group_by")
358414
if self.is_trivial_when_intermediate():
359415
return self.sources[0].project(ops, group_by=group_by, parse_env=parse_env)
360-
return ProjectNode(source=self, ops=ops, group_by=group_by, parse_env=parse_env)
416+
parsed_ops = data_algebra.expr_rep.parse_assignments_in_context(
417+
ops, self, parse_env=parse_env
418+
)
419+
return ProjectNode(source=self, parsed_ops=parsed_ops, group_by=group_by)
361420

362421
def natural_join(self, b, *, by=None, jointype="INNER"):
363422
if not isinstance(b, ViewRepresentation):
@@ -793,31 +852,12 @@ def wrap(d, *, table_name="data_frame"):
793852

794853
class ExtendNode(ViewRepresentation):
795854
def __init__(
796-
self,
797-
source,
798-
ops,
799-
*,
800-
partition_by=None,
801-
order_by=None,
802-
reverse=None,
803-
parse_env=None
855+
self, *, source, parsed_ops, partition_by=None, order_by=None, reverse=None,
804856
):
805-
windowed_situation = False
806-
if ops is None:
807-
ops = {}
808-
ops = data_algebra.expr_rep.parse_assignments_in_context(
809-
ops, source, parse_env=parse_env
810-
)
811-
if len(ops) < 1:
812-
raise ValueError("no ops")
813-
for (k, opk) in ops.items(): # look for aggregation functions
814-
if isinstance(opk, data_algebra.expr_rep.Expression):
815-
if (
816-
opk.op
817-
in data_algebra.expr_rep.fn_names_that_imply_windowed_situation
818-
):
819-
windowed_situation = True
820-
self.ops = ops
857+
windowed_situation = data_algebra.expr_rep.implies_windowed(parsed_ops)
858+
self.ops = parsed_ops
859+
self.cols_used_in_calc = data_algebra.expr_rep.get_columns_used(parsed_ops)
860+
self.cols_produced_in_calc = [k for k in parsed_ops.keys()]
821861
if partition_by is None:
822862
partition_by = []
823863
if isinstance(partition_by, numbers.Number):
@@ -843,13 +883,13 @@ def __init__(
843883
self.reverse = reverse
844884
column_names = source.column_names.copy()
845885
consumed_cols = set()
846-
for (k, o) in ops.items():
886+
for (k, o) in parsed_ops.items():
847887
o.get_column_names(consumed_cols)
848888
unknown_cols = consumed_cols - source.column_set
849889
if len(unknown_cols) > 0:
850890
raise KeyError("referred to unknown columns: " + str(unknown_cols))
851891
known_cols = set(column_names)
852-
for ci in ops.keys():
892+
for ci in parsed_ops.keys():
853893
if ci not in known_cols:
854894
column_names.append(ci)
855895
if len(partition_by) != len(set(partition_by)):
@@ -867,14 +907,14 @@ def __init__(
867907
unknown = set(reverse) - set(order_by)
868908
if len(unknown) > 0:
869909
raise ValueError("reverse columns not in order_by: " + str(unknown))
870-
bad_overwrite = set(ops.keys()).intersection(
910+
bad_overwrite = set(parsed_ops.keys()).intersection(
871911
set(partition_by).union(order_by, reverse)
872912
)
873913
if len(bad_overwrite) > 0:
874914
raise ValueError("tried to change: " + str(bad_overwrite))
875915
# check op arguments are very simple: all arguments are column names
876916
if windowed_situation:
877-
for (k, opk) in ops.items():
917+
for (k, opk) in parsed_ops.items():
878918
if not isinstance(opk, data_algebra.expr_rep.Expression):
879919
raise ValueError(
880920
"non-aggregated expression in windowed/partitoned extend: "
@@ -991,13 +1031,8 @@ def eval_implementation(self, *, data_map, eval_env, data_model):
9911031

9921032

9931033
class ProjectNode(ViewRepresentation):
994-
def __init__(self, source, ops=None, *, group_by=None, parse_env=None):
995-
if ops is None:
996-
ops = {}
997-
ops = data_algebra.expr_rep.parse_assignments_in_context(
998-
ops, source, parse_env=parse_env
999-
)
1000-
self.ops = ops
1034+
def __init__(self, *, source, parsed_ops, group_by=None):
1035+
self.ops = parsed_ops
10011036
if group_by is None:
10021037
group_by = []
10031038
if isinstance(group_by, str):
@@ -1007,13 +1042,13 @@ def __init__(self, source, ops=None, *, group_by=None, parse_env=None):
10071042
consumed_cols = set()
10081043
for c in group_by:
10091044
consumed_cols.add(c)
1010-
for (k, o) in ops.items():
1045+
for (k, o) in parsed_ops.items():
10111046
o.get_column_names(consumed_cols)
10121047
unknown_cols = consumed_cols - source.column_set
10131048
if len(unknown_cols) > 0:
10141049
raise KeyError("referred to unknown columns: " + str(unknown_cols))
10151050
known_cols = set(column_names)
1016-
for ci in ops.keys():
1051+
for ci in parsed_ops.keys():
10171052
if ci not in known_cols:
10181053
column_names.append(ci)
10191054
if len(group_by) != len(set(group_by)):

build/lib/data_algebra/expr_rep.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import data_algebra.util
55
import data_algebra.env
66

7-
87
# for some ideas in capturing expressions in Python see:
98
# scipy
109
# pipe-like idea
@@ -1104,3 +1103,26 @@ def standardize_join_type(join_str):
11041103
except KeyError:
11051104
pass
11061105
return join_str
1106+
1107+
1108+
def get_columns_used(parsed_exprs):
    """Gather the column names referenced by a dictionary of parsed expressions.

    :param parsed_exprs: dict of parsed expression nodes; every value must
        offer ``get_column_names(accumulator_set)``.
    :return: the set handed to each value's ``get_column_names`` call
        (presumably filled with the column names each node reads).
    :raises TypeError: if ``parsed_exprs`` is not a dict.
    """
    if isinstance(parsed_exprs, dict):
        referenced = set()
        for term in parsed_exprs.values():
            # every node reports into the same shared accumulator
            term.get_column_names(referenced)
        return referenced
    raise TypeError(
        "expected parsed_exprs to be a dictionary of data_algebra.expr_rep.Term(s)"
    )
1117+
1118+
1119+
def implies_windowed(parsed_exprs):
    """Report whether any parsed expression calls for a windowed calculation.

    An entry counts when its value is an ``Expression`` whose ``op`` is listed
    in ``fn_names_that_imply_windowed_situation``; only the dictionary's
    values themselves are inspected.

    :param parsed_exprs: dict mapping result column names to parsed terms.
    :return: True if at least one value is such an expression, else False
        (including for an empty dict).
    :raises TypeError: if ``parsed_exprs`` is not a dict.
    """
    if not isinstance(parsed_exprs, dict):
        raise TypeError(
            "expected parsed_exprs to be a dictionary of data_algebra.expr_rep.Term(s)"
        )
    # Expression and fn_names_that_imply_windowed_situation are globals of this
    # very module; naming them directly avoids depending on the
    # ``data_algebra.expr_rep`` attribute already being bound on the package.
    return any(
        isinstance(opk, Expression)
        and opk.op in fn_names_that_imply_windowed_situation
        for opk in parsed_exprs.values()
    )

coverage.txt

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,45 @@
22
platform darwin -- Python 3.6.9, pytest-5.2.2, py-1.8.0, pluggy-0.13.0
33
rootdir: /Users/johnmount/Documents/work/data_algebra
44
plugins: cov-2.8.1
5-
collected 73 items
5+
collected 74 items
66

77
tests/test_R_yaml.py . [ 1%]
88
tests/test_apply.py . [ 2%]
99
tests/test_arrow1.py . [ 4%]
1010
tests/test_calc_warnings_errors.py . [ 5%]
1111
tests/test_cc.py ...... [ 13%]
12-
tests/test_cdata1.py . [ 15%]
12+
tests/test_cdata1.py . [ 14%]
1313
tests/test_cdata_example.py .... [ 20%]
1414
tests/test_cols_used.py . [ 21%]
15-
tests/test_concat_rows.py . [ 23%]
15+
tests/test_concat_rows.py . [ 22%]
1616
tests/test_degenerate_project.py . [ 24%]
17-
tests/test_drop_columns.py . [ 26%]
17+
tests/test_drop_columns.py . [ 25%]
1818
tests/test_exampe1.py .... [ 31%]
1919
tests/test_example_data_ops.py . [ 32%]
20-
tests/test_exp.py . [ 34%]
20+
tests/test_exp.py . [ 33%]
2121
tests/test_export_neg.py . [ 35%]
2222
tests/test_expr_parse.py . [ 36%]
23-
tests/test_extend.py ... [ 41%]
24-
tests/test_flow_text.py . [ 42%]
25-
tests/test_free_expr.py . [ 43%]
23+
tests/test_extend.py .... [ 41%]
24+
tests/test_flow_text.py . [ 43%]
25+
tests/test_free_expr.py . [ 44%]
2626
tests/test_ghost_col_issue.py . [ 45%]
27-
tests/test_if_else.py . [ 46%]
28-
tests/test_join_check.py . [ 47%]
29-
tests/test_join_effects.py .. [ 50%]
27+
tests/test_if_else.py . [ 47%]
28+
tests/test_join_check.py . [ 48%]
29+
tests/test_join_effects.py .. [ 51%]
3030
tests/test_math.py . [ 52%]
31-
tests/test_natural_join.py . [ 53%]
32-
tests/test_neg.py . [ 54%]
31+
tests/test_natural_join.py . [ 54%]
32+
tests/test_neg.py . [ 55%]
3333
tests/test_null_bad.py . [ 56%]
34-
tests/test_parse.py . [ 57%]
34+
tests/test_parse.py . [ 58%]
3535
tests/test_project.py ..... [ 64%]
36-
tests/test_scatter_example.py . [ 65%]
36+
tests/test_scatter_example.py . [ 66%]
3737
tests/test_scoring_example.py . [ 67%]
3838
tests/test_select_stacking.py . [ 68%]
39-
tests/test_shorten.py . [ 69%]
40-
tests/test_simple.py ..... [ 76%]
39+
tests/test_shorten.py . [ 70%]
40+
tests/test_simple.py ..... [ 77%]
4141
tests/test_spark_sql.py . [ 78%]
4242
tests/test_sqlite.py . [ 79%]
43-
tests/test_strat_example.py . [ 80%]
43+
tests/test_strat_example.py . [ 81%]
4444
tests/test_table_is_key_by_columns.py . [ 82%]
4545
tests/test_transform_examples.py ........... [ 97%]
4646
tests/test_window2.py . [ 98%]
@@ -58,20 +58,20 @@ data_algebra/cdata.py 232 75 68%
5858
data_algebra/cdata_impl.py 10 1 90%
5959
data_algebra/connected_components.py 49 1 98%
6060
data_algebra/data_model.py 29 13 55%
61-
data_algebra/data_ops.py 1230 236 81%
61+
data_algebra/data_ops.py 1244 240 81%
6262
data_algebra/data_ops_types.py 42 16 62%
6363
data_algebra/db_model.py 389 72 81%
6464
data_algebra/diagram.py 56 43 23%
6565
data_algebra/env.py 31 3 90%
6666
data_algebra/expr.py 20 4 80%
67-
data_algebra/expr_rep.py 638 209 67%
67+
data_algebra/expr_rep.py 653 210 68%
6868
data_algebra/flow_text.py 17 0 100%
6969
data_algebra/pandas_model.py 182 20 89%
7070
data_algebra/test_util.py 119 17 86%
7171
data_algebra/util.py 44 10 77%
7272
data_algebra/yaml.py 101 13 87%
7373
----------------------------------------------------------
74-
TOTAL 3464 795 77%
74+
TOTAL 3493 800 77%
7575

7676

77-
============================== 73 passed in 7.55s ==============================
77+
============================== 74 passed in 6.94s ==============================

0 commit comments

Comments
 (0)