@@ -17,16 +17,22 @@ def cod(self):
1717 """return co-domain, object at head of arrow"""
1818 raise NotImplementedError ("base class called" )
1919
def apply_to(self, b, *, strict=True):
    """Apply this arrow to b, composing arrows (right to left).

    Abstract hook: this base implementation always raises.

    :param b: the arrow to compose with.
    :param strict: keyword-only flag; unused here.
    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError("base class called")
23+
# noinspection PyPep8Naming
def transform(self, X):
    """Transform X, i.e. act on X.

    Abstract hook: this base implementation always raises.

    :param X: the object to act on.
    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError("base class called")
2428
2529 def __rshift__ (self , other ): # override self >> other
26- return other .transform (self , strict = True )
30+ return other .apply_to (self , strict = True )
2731
def __rrshift__(self, other):  # override other >> self
    """Implement ``other >> self``.

    When ``other`` is an ``Arrow`` the two are composed through
    ``self.apply_to(other, strict=True)``; any other object is handed to
    ``self.transform(other)``.

    :param other: the left-hand operand of ``>>``.
    :return: whatever ``apply_to`` or ``transform`` returns.
    """
    if isinstance(other, Arrow):
        return self.apply_to(other, strict=True)
    return self.transform(other)
3036
3137
3238class DataOpArrow (Arrow ):
@@ -60,83 +66,56 @@ def __init__(self, pipeline, *, free_table_key=None):
6066 self .outgoing_types = self .incoming_types .copy ()
6167 Arrow .__init__ (self )
6268
def apply_to(self, b, *, strict=True):
    """Replace self input table with b, returning the composed arrow.

    :param b: a ``DataOpArrow``, or a ``data_algebra.data_ops.ViewRepresentation``
        (which is first wrapped in a ``DataOpArrow``).
    :param strict: keyword-only; when true, columns produced by ``b`` that
        this arrow does not take in raise "extra incoming columns".
    :return: a new ``DataOpArrow`` built from the combined pipeline, with
        ``incoming_types`` taken from ``b`` and ``outgoing_types`` from self.
    :raises TypeError: if ``b`` is neither of the accepted types.
    :raises ValueError: on missing columns, extra columns, mismatched column
        sets, or mismatched column types.
    """
    if isinstance(b, data_algebra.data_ops.ViewRepresentation):
        b = DataOpArrow(b)
    if not isinstance(b, DataOpArrow):
        raise TypeError("unexpected type: " + str(type(b)))
    # build each column set once; every check below compares these two
    required = set(self.incoming_columns)
    supplied = set(b.outgoing_columns)
    missing = required - supplied
    if len(missing) > 0:
        raise ValueError("missing required columns: " + str(missing))
    if strict:
        excess = supplied - required
        if len(excess) > 0:
            raise ValueError("extra incoming columns: " + str(excess))
    # check categorical arrow composition conditions
    # NOTE(review): as reconstructed this check is not guarded by `strict`, so
    # extra columns are still rejected when strict is False (only the message
    # differs) - confirm that is intended.
    if required != supplied:
        raise ValueError(
            "arrow composition conditions not met (incoming column set doesn't match outgoing)"
        )
    if (self.incoming_types is not None) and (b.outgoing_types is not None):
        for c in self.incoming_columns:
            st = self.incoming_types[c]
            xt = b.outgoing_types[c]
            if st != xt:
                raise ValueError(
                    "column "
                    + c
                    + " self incoming type is "
                    + str(st)
                    + ", while b outgoing type is "
                    + str(xt)
                )
    new_pipeline = self.pipeline.apply_to(
        b.pipeline, target_table_key=self.free_table_key
    )
    new_pipeline.get_tables()  # check tables are compatible
    res = DataOpArrow(pipeline=new_pipeline, free_table_key=b.free_table_key)
    res.incoming_types = b.incoming_types
    res.outgoing_types = self.outgoing_types
    return res
108+
# noinspection PyPep8Naming
def transform(self, X):
    """Apply this arrow's pipeline to X.

    :param X: object exposing ``.columns`` and column-list indexing
        (assumed to be pandas.DataFrame compatible).
    :return: the result of ``self.pipeline.transform`` on X, after X has
        been narrowed to ``self.incoming_columns`` if it carried extras.
    :raises ValueError: if X lacks any of ``self.incoming_columns``.
    """
    # assume a pandas.DataFrame compatible object
    # noinspection PyUnresolvedReferences
    cols = set(X.columns)
    missing = set(self.incoming_columns) - cols
    if len(missing) > 0:
        raise ValueError("missing required columns: " + str(missing))
    excess = cols - set(self.incoming_columns)
    if len(excess) > 0:
        # drop the extra columns so the pipeline sees only its declared inputs
        X = X[self.incoming_columns]
    return self.pipeline.transform(X)
142121
0 commit comments