Skip to content

Commit e2d9740

Browse files
committed
fix arrow issue
1 parent c298f3b commit e2d9740

File tree

6 files changed

+108
-150
lines changed

6 files changed

+108
-150
lines changed

build/lib/data_algebra/arrow.py

Lines changed: 51 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,22 @@ def cod(self):
1717
"""return co-domain, object at head of arrow"""
1818
raise NotImplementedError("base class called")
1919

20+
def apply_to(self, b, *, strict=True):
21+
""" apply_to b, compose arrows (right to left) """
22+
raise NotImplementedError("base class called")
23+
2024
# noinspection PyPep8Naming
21-
def transform(self, X, *, strict=True):
22-
""" transform X, compose arrows (right to left) """
25+
def transform(self, X):
26+
""" transform X, act on X """
2327
raise NotImplementedError("base class called")
2428

2529
def __rshift__(self, other): # override self >> other
26-
return other.transform(self, strict=True)
30+
return other.apply_to(self, strict=True)
2731

2832
def __rrshift__(self, other): # override other >> self
29-
return self.transform(other, strict=True)
33+
if isinstance(other, Arrow):
34+
return self.apply_to(other, strict=True)
35+
return self.transform(other)
3036

3137

3238
class DataOpArrow(Arrow):
@@ -60,83 +66,56 @@ def __init__(self, pipeline, *, free_table_key=None):
6066
self.outgoing_types = self.incoming_types.copy()
6167
Arrow.__init__(self)
6268

63-
# noinspection PyPep8Naming
64-
def transform(self, X, *, strict=True):
65-
"""replace self input table with X"""
66-
if isinstance(X, data_algebra.data_ops.ViewRepresentation):
67-
X = DataOpArrow(X)
68-
if isinstance(X, DataOpArrow):
69-
missing = set(self.incoming_columns) - set(X.outgoing_columns)
70-
if len(missing) > 0:
71-
raise ValueError("missing required columns: " + str(missing))
72-
excess = set(X.outgoing_columns) - set(self.incoming_columns)
69+
def apply_to(self, b, *, strict=True):
70+
"""replace self input table with b"""
71+
if isinstance(b, data_algebra.data_ops.ViewRepresentation):
72+
b = DataOpArrow(b)
73+
if not isinstance(b, DataOpArrow):
74+
raise TypeError("unexpected type: " + str(type(b)))
75+
missing = set(self.incoming_columns) - set(b.outgoing_columns)
76+
if len(missing) > 0:
77+
raise ValueError("missing required columns: " + str(missing))
78+
if strict:
79+
excess = set(b.outgoing_columns) - set(self.incoming_columns)
7380
if len(excess) > 0:
74-
if strict:
75-
raise ValueError("extra incoming columns: " + str(excess))
76-
# check categorical arrow composition conditions
77-
if set(self.incoming_columns) != set(X.outgoing_columns):
78-
raise ValueError(
79-
"arrow composition conditions not met (incoming column set doesn't match outgoing)"
80-
)
81-
if (self.incoming_types is not None) and (X.outgoing_types is not None):
82-
for c in self.incoming_columns:
83-
st = self.incoming_types[c]
84-
xt = X.outgoing_types[c]
85-
if st != xt:
86-
raise ValueError(
87-
"column "
88-
+ c
89-
+ " self incoming type is "
90-
+ str(st)
91-
+ ", while X outgoing type is "
92-
+ str(xt)
93-
)
94-
new_pipeline = self.pipeline.apply_to(
95-
X.pipeline, target_table_key=self.free_table_key
81+
raise ValueError("extra incoming columns: " + str(excess))
82+
# check categorical arrow composition conditions
83+
if set(self.incoming_columns) != set(b.outgoing_columns):
84+
raise ValueError(
85+
"arrow composition conditions not met (incoming column set doesn't match outgoing)"
9686
)
97-
res = DataOpArrow(pipeline=new_pipeline, free_table_key=X.free_table_key)
98-
# res = DataOpArrow(
99-
# X.pipeline.stand_in_for_table(
100-
# ops=self.pipeline, table_key=self.free_table_key
101-
# )
102-
# )
103-
res.incoming_types = X.incoming_types
104-
res.outgoing_types = self.outgoing_types
105-
return res
106-
if isinstance(X, list) or isinstance(X, tuple) or isinstance(X, set):
107-
# schema type object
108-
if set(self.incoming_columns) != set(X):
109-
raise ValueError("input did not match arrow dom()")
110-
return self.cod()
111-
if isinstance(X, dict):
112-
# schema type object
113-
if set(self.incoming_columns) != set(X.keys()):
114-
raise ValueError("input did not match arrow dom()")
115-
if self.incoming_types is not None:
116-
for c in self.incoming_columns:
117-
st = self.incoming_types[c]
118-
xt = X[c]
119-
if st != xt:
120-
raise ValueError(
121-
"column "
122-
+ c
123-
+ " self incoming type is "
124-
+ str(st)
125-
+ ", while X outgoing type is "
126-
+ str(xt)
127-
)
128-
return self.cod()
87+
if (self.incoming_types is not None) and (b.outgoing_types is not None):
88+
for c in self.incoming_columns:
89+
st = self.incoming_types[c]
90+
xt = b.outgoing_types[c]
91+
if st != xt:
92+
raise ValueError(
93+
"column "
94+
+ c
95+
+ " self incoming type is "
96+
+ str(st)
97+
+ ", while b outgoing type is "
98+
+ str(xt)
99+
)
100+
new_pipeline = self.pipeline.apply_to(
101+
b.pipeline, target_table_key=self.free_table_key
102+
)
103+
new_pipeline.get_tables() # check tables are compatible
104+
res = DataOpArrow(pipeline=new_pipeline, free_table_key=b.free_table_key)
105+
res.incoming_types = b.incoming_types
106+
res.outgoing_types = self.outgoing_types
107+
return res
108+
109+
# noinspection PyPep8Naming
110+
def transform(self, X):
129111
# assume a pandas.DataFrame compatible object
130112
# noinspection PyUnresolvedReferences
131113
cols = set(X.columns)
132114
missing = set(self.incoming_columns) - cols
133115
if len(missing) > 0:
134116
raise ValueError("missing required columns: " + str(missing))
135117
excess = cols - set(self.incoming_columns)
136-
if len(excess):
137-
if strict:
138-
raise ValueError("extra incoming columns: " + str(excess))
139-
# noinspection PyUnresolvedReferences
118+
if len(excess) > 0:
140119
X = X[self.incoming_columns]
141120
return self.pipeline.transform(X)
142121

coverage.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ data_algebra/PostgreSQL.py 21 4 81%
5454
data_algebra/SQLite.py 90 13 86%
5555
data_algebra/SparkSQL.py 21 4 81%
5656
data_algebra/__init__.py 5 0 100%
57-
data_algebra/arrow.py 144 50 65%
57+
data_algebra/arrow.py 135 39 71%
5858
data_algebra/cdata.py 232 75 68%
5959
data_algebra/cdata_impl.py 10 1 90%
6060
data_algebra/connected_components.py 49 1 98%
@@ -74,7 +74,7 @@ data_algebra/test_util.py 119 17 86%
7474
data_algebra/util.py 44 10 77%
7575
data_algebra/yaml.py 102 13 87%
7676
----------------------------------------------------------
77-
TOTAL 3702 929 75%
77+
TOTAL 3693 918 75%
7878

7979

80-
============================== 77 passed in 9.63s ==============================
80+
============================== 77 passed in 6.58s ==============================

data_algebra/arrow.py

Lines changed: 51 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,22 @@ def cod(self):
1717
"""return co-domain, object at head of arrow"""
1818
raise NotImplementedError("base class called")
1919

20+
def apply_to(self, b, *, strict=True):
21+
""" apply_to b, compose arrows (right to left) """
22+
raise NotImplementedError("base class called")
23+
2024
# noinspection PyPep8Naming
21-
def transform(self, X, *, strict=True):
22-
""" transform X, compose arrows (right to left) """
25+
def transform(self, X):
26+
""" transform X, act on X """
2327
raise NotImplementedError("base class called")
2428

2529
def __rshift__(self, other): # override self >> other
26-
return other.transform(self, strict=True)
30+
return other.apply_to(self, strict=True)
2731

2832
def __rrshift__(self, other): # override other >> self
29-
return self.transform(other, strict=True)
33+
if isinstance(other, Arrow):
34+
return self.apply_to(other, strict=True)
35+
return self.transform(other)
3036

3137

3238
class DataOpArrow(Arrow):
@@ -60,83 +66,56 @@ def __init__(self, pipeline, *, free_table_key=None):
6066
self.outgoing_types = self.incoming_types.copy()
6167
Arrow.__init__(self)
6268

63-
# noinspection PyPep8Naming
64-
def transform(self, X, *, strict=True):
65-
"""replace self input table with X"""
66-
if isinstance(X, data_algebra.data_ops.ViewRepresentation):
67-
X = DataOpArrow(X)
68-
if isinstance(X, DataOpArrow):
69-
missing = set(self.incoming_columns) - set(X.outgoing_columns)
70-
if len(missing) > 0:
71-
raise ValueError("missing required columns: " + str(missing))
72-
excess = set(X.outgoing_columns) - set(self.incoming_columns)
69+
def apply_to(self, b, *, strict=True):
70+
"""replace self input table with b"""
71+
if isinstance(b, data_algebra.data_ops.ViewRepresentation):
72+
b = DataOpArrow(b)
73+
if not isinstance(b, DataOpArrow):
74+
raise TypeError("unexpected type: " + str(type(b)))
75+
missing = set(self.incoming_columns) - set(b.outgoing_columns)
76+
if len(missing) > 0:
77+
raise ValueError("missing required columns: " + str(missing))
78+
if strict:
79+
excess = set(b.outgoing_columns) - set(self.incoming_columns)
7380
if len(excess) > 0:
74-
if strict:
75-
raise ValueError("extra incoming columns: " + str(excess))
76-
# check categorical arrow composition conditions
77-
if set(self.incoming_columns) != set(X.outgoing_columns):
78-
raise ValueError(
79-
"arrow composition conditions not met (incoming column set doesn't match outgoing)"
80-
)
81-
if (self.incoming_types is not None) and (X.outgoing_types is not None):
82-
for c in self.incoming_columns:
83-
st = self.incoming_types[c]
84-
xt = X.outgoing_types[c]
85-
if st != xt:
86-
raise ValueError(
87-
"column "
88-
+ c
89-
+ " self incoming type is "
90-
+ str(st)
91-
+ ", while X outgoing type is "
92-
+ str(xt)
93-
)
94-
new_pipeline = self.pipeline.apply_to(
95-
X.pipeline, target_table_key=self.free_table_key
81+
raise ValueError("extra incoming columns: " + str(excess))
82+
# check categorical arrow composition conditions
83+
if set(self.incoming_columns) != set(b.outgoing_columns):
84+
raise ValueError(
85+
"arrow composition conditions not met (incoming column set doesn't match outgoing)"
9686
)
97-
res = DataOpArrow(pipeline=new_pipeline, free_table_key=X.free_table_key)
98-
# res = DataOpArrow(
99-
# X.pipeline.stand_in_for_table(
100-
# ops=self.pipeline, table_key=self.free_table_key
101-
# )
102-
# )
103-
res.incoming_types = X.incoming_types
104-
res.outgoing_types = self.outgoing_types
105-
return res
106-
if isinstance(X, list) or isinstance(X, tuple) or isinstance(X, set):
107-
# schema type object
108-
if set(self.incoming_columns) != set(X):
109-
raise ValueError("input did not match arrow dom()")
110-
return self.cod()
111-
if isinstance(X, dict):
112-
# schema type object
113-
if set(self.incoming_columns) != set(X.keys()):
114-
raise ValueError("input did not match arrow dom()")
115-
if self.incoming_types is not None:
116-
for c in self.incoming_columns:
117-
st = self.incoming_types[c]
118-
xt = X[c]
119-
if st != xt:
120-
raise ValueError(
121-
"column "
122-
+ c
123-
+ " self incoming type is "
124-
+ str(st)
125-
+ ", while X outgoing type is "
126-
+ str(xt)
127-
)
128-
return self.cod()
87+
if (self.incoming_types is not None) and (b.outgoing_types is not None):
88+
for c in self.incoming_columns:
89+
st = self.incoming_types[c]
90+
xt = b.outgoing_types[c]
91+
if st != xt:
92+
raise ValueError(
93+
"column "
94+
+ c
95+
+ " self incoming type is "
96+
+ str(st)
97+
+ ", while b outgoing type is "
98+
+ str(xt)
99+
)
100+
new_pipeline = self.pipeline.apply_to(
101+
b.pipeline, target_table_key=self.free_table_key
102+
)
103+
new_pipeline.get_tables() # check tables are compatible
104+
res = DataOpArrow(pipeline=new_pipeline, free_table_key=b.free_table_key)
105+
res.incoming_types = b.incoming_types
106+
res.outgoing_types = self.outgoing_types
107+
return res
108+
109+
# noinspection PyPep8Naming
110+
def transform(self, X):
129111
# assume a pandas.DataFrame compatible object
130112
# noinspection PyUnresolvedReferences
131113
cols = set(X.columns)
132114
missing = set(self.incoming_columns) - cols
133115
if len(missing) > 0:
134116
raise ValueError("missing required columns: " + str(missing))
135117
excess = cols - set(self.incoming_columns)
136-
if len(excess):
137-
if strict:
138-
raise ValueError("extra incoming columns: " + str(excess))
139-
# noinspection PyUnresolvedReferences
118+
if len(excess) > 0:
140119
X = X[self.incoming_columns]
141120
return self.pipeline.transform(X)
142121

-94 Bytes
Binary file not shown.

dist/data_algebra-0.3.8.tar.gz

-78 Bytes
Binary file not shown.

tests/test_arrow1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_arrow1():
5252
assert data_algebra.test_util.equivalent_frames(d >> a1, d >> li >> a1)
5353
assert data_algebra.test_util.equivalent_frames(d >> a1, d >> a1 >> ri)
5454
a1.dom() >> a1
55-
a1.transform(a1.dom())
55+
a1.apply_to(a1.dom())
5656

5757
# print(a1)
5858

@@ -151,7 +151,7 @@ def test_arrow1():
151151

152152
# %%
153153

154-
f0 = (a3.transform(a2.transform(a1))).pipeline.__repr__()
154+
f0 = (a3.apply_to(a2.apply_to(a1))).pipeline.__repr__()
155155
f1 = (a1 >> a2 >> a3).pipeline.__repr__()
156156

157157
assert f1 == f0
@@ -210,7 +210,7 @@ def test_arrow1():
210210
assert data_algebra.test_util.equivalent_frames(d >> a1, d >> li >> a1)
211211
assert data_algebra.test_util.equivalent_frames(d >> a1, d >> a1 >> ri)
212212
a1.dom() >> a1
213-
a1.transform(a1.dom())
213+
a1.apply_to(a1.dom())
214214

215215

216216
def test_arrow_cod_dom():

0 commit comments

Comments
 (0)