Skip to content

Commit 1eb3b66

Browse files
committed
check project a bit earlier
1 parent b2861f2 commit 1eb3b66

File tree

5 files changed

+33
-21
lines changed

5 files changed

+33
-21
lines changed

build/lib/data_algebra/data_ops.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -331,16 +331,16 @@ def extend(
331331
):
332332
if (ops is None) or (len(ops) < 1):
333333
return self
334-
parsed_ops = data_algebra.expr_rep.parse_assignments_in_context(
335-
ops, self, parse_env=parse_env
336-
)
337-
new_cols_used_in_calc = set(data_algebra.expr_rep.get_columns_used(parsed_ops))
338334
if partition_by is None:
339335
partition_by = []
340336
if order_by is None:
341337
order_by = []
342338
if reverse is None:
343339
reverse = []
340+
parsed_ops = data_algebra.expr_rep.parse_assignments_in_context(
341+
ops, self, parse_env=parse_env
342+
)
343+
new_cols_used_in_calc = set(data_algebra.expr_rep.get_columns_used(parsed_ops))
344344
new_cols_produced_in_calc = set([k for k in parsed_ops.keys()])
345345
if (partition_by != 1) and (len(partition_by) > 0):
346346
if len(new_cols_produced_in_calc.intersection(partition_by)) > 0:
@@ -357,6 +357,7 @@ def extend(
357357
reverse=reverse,
358358
parse_env=parse_env,
359359
)
360+
# see if we can combine nodes
360361
if isinstance(self, ExtendNode):
361362
compatible_partition = (partition_by == self.partition_by) or (
362363
((partition_by == 1) or (len(partition_by) <= 0))
@@ -398,6 +399,7 @@ def extend(
398399
order_by=order_by,
399400
reverse=reverse,
400401
)
402+
# new node
401403
return ExtendNode(
402404
source=self,
403405
parsed_ops=parsed_ops,
@@ -407,15 +409,19 @@ def extend(
407409
)
408410

409411
def project(self, ops=None, *, group_by=None, parse_env=None):
410-
if ((ops is None) or (len(ops) < 1)) and (
411-
(group_by is None) or (len(group_by) < 1)
412-
):
412+
if group_by is None:
413+
group_by = []
414+
if ((ops is None) or (len(ops) < 1)) and (len(group_by) < 1):
413415
raise ValueError("must have ops or group_by")
414-
if self.is_trivial_when_intermediate():
415-
return self.sources[0].project(ops, group_by=group_by, parse_env=parse_env)
416416
parsed_ops = data_algebra.expr_rep.parse_assignments_in_context(
417417
ops, self, parse_env=parse_env
418418
)
419+
new_cols_used_in_calc = set(data_algebra.expr_rep.get_columns_used(parsed_ops))
420+
new_cols_produced_in_calc = set([k for k in parsed_ops.keys()])
421+
if len(new_cols_used_in_calc.intersection(group_by)):
422+
raise ValueError("can not alter grouping columns")
423+
if self.is_trivial_when_intermediate():
424+
return self.sources[0].project(ops, group_by=group_by, parse_env=parse_env)
419425
return ProjectNode(source=self, parsed_ops=parsed_ops, group_by=group_by)
420426

421427
def natural_join(self, b, *, by=None, jointype="INNER"):

coverage.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ data_algebra/cdata.py 232 75 68%
5858
data_algebra/cdata_impl.py 10 1 90%
5959
data_algebra/connected_components.py 49 1 98%
6060
data_algebra/data_model.py 29 13 55%
61-
data_algebra/data_ops.py 1244 240 81%
61+
data_algebra/data_ops.py 1250 242 81%
6262
data_algebra/data_ops_types.py 42 16 62%
6363
data_algebra/db_model.py 389 72 81%
6464
data_algebra/diagram.py 56 43 23%
@@ -71,7 +71,7 @@ data_algebra/test_util.py 119 17 86%
7171
data_algebra/util.py 44 10 77%
7272
data_algebra/yaml.py 101 13 87%
7373
----------------------------------------------------------
74-
TOTAL 3493 800 77%
74+
TOTAL 3499 802 77%
7575

7676

77-
============================== 74 passed in 6.94s ==============================
77+
============================== 74 passed in 7.19s ==============================

data_algebra/data_ops.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -331,16 +331,16 @@ def extend(
331331
):
332332
if (ops is None) or (len(ops) < 1):
333333
return self
334-
parsed_ops = data_algebra.expr_rep.parse_assignments_in_context(
335-
ops, self, parse_env=parse_env
336-
)
337-
new_cols_used_in_calc = set(data_algebra.expr_rep.get_columns_used(parsed_ops))
338334
if partition_by is None:
339335
partition_by = []
340336
if order_by is None:
341337
order_by = []
342338
if reverse is None:
343339
reverse = []
340+
parsed_ops = data_algebra.expr_rep.parse_assignments_in_context(
341+
ops, self, parse_env=parse_env
342+
)
343+
new_cols_used_in_calc = set(data_algebra.expr_rep.get_columns_used(parsed_ops))
344344
new_cols_produced_in_calc = set([k for k in parsed_ops.keys()])
345345
if (partition_by != 1) and (len(partition_by) > 0):
346346
if len(new_cols_produced_in_calc.intersection(partition_by)) > 0:
@@ -357,6 +357,7 @@ def extend(
357357
reverse=reverse,
358358
parse_env=parse_env,
359359
)
360+
# see if we can combine nodes
360361
if isinstance(self, ExtendNode):
361362
compatible_partition = (partition_by == self.partition_by) or (
362363
((partition_by == 1) or (len(partition_by) <= 0))
@@ -398,6 +399,7 @@ def extend(
398399
order_by=order_by,
399400
reverse=reverse,
400401
)
402+
# new node
401403
return ExtendNode(
402404
source=self,
403405
parsed_ops=parsed_ops,
@@ -407,15 +409,19 @@ def extend(
407409
)
408410

409411
def project(self, ops=None, *, group_by=None, parse_env=None):
410-
if ((ops is None) or (len(ops) < 1)) and (
411-
(group_by is None) or (len(group_by) < 1)
412-
):
412+
if group_by is None:
413+
group_by = []
414+
if ((ops is None) or (len(ops) < 1)) and (len(group_by) < 1):
413415
raise ValueError("must have ops or group_by")
414-
if self.is_trivial_when_intermediate():
415-
return self.sources[0].project(ops, group_by=group_by, parse_env=parse_env)
416416
parsed_ops = data_algebra.expr_rep.parse_assignments_in_context(
417417
ops, self, parse_env=parse_env
418418
)
419+
new_cols_used_in_calc = set(data_algebra.expr_rep.get_columns_used(parsed_ops))
420+
new_cols_produced_in_calc = set([k for k in parsed_ops.keys()])
421+
if len(new_cols_used_in_calc.intersection(group_by)):
422+
raise ValueError("can not alter grouping columns")
423+
if self.is_trivial_when_intermediate():
424+
return self.sources[0].project(ops, group_by=group_by, parse_env=parse_env)
419425
return ProjectNode(source=self, parsed_ops=parsed_ops, group_by=group_by)
420426

421427
def natural_join(self, b, *, by=None, jointype="INNER"):
36 Bytes
Binary file not shown.

dist/data_algebra-0.3.6.tar.gz

39 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)