22
33import data_algebra .data_ops
44import data_algebra .flow_text
5+ from data_algebra .shift_pipe_action import ShiftPipeAction
56
67
7- class Arrow (abc . ABC ):
8+ class Arrow (ShiftPipeAction ):
89 """
910 Arrow from category theory: see Steve Awody,
1011 "Category Theory, 2nd Edition", Oxford Univ. Press, 2010 pg. 4.
@@ -13,7 +14,7 @@ class Arrow(abc.ABC):
1314 """
1415
1516 def __init__ (self ):
16- pass
17+ ShiftPipeAction . __init__ ( self )
1718
1819 @abc .abstractmethod
1920 def dom (self ):
@@ -23,28 +24,16 @@ def dom(self):
2324 def cod (self ):
2425 """return co-domain, object at head of arrow"""
2526
26- @abc .abstractmethod
27- def apply_to (self , b ):
28- """apply_to b, compose arrows (right to left)"""
29-
3027 # noinspection PyPep8Naming
3128 @abc .abstractmethod
32- def act_on (self , X ):
33- """act on X , must associate with composition"""
29+ def act_on (self , b ):
30+ """act on b , must associate with composition"""
3431
3532 # noinspection PyPep8Naming
3633 def transform (self , X ):
3734 """transform X, may or may not associate with composition"""
3835 return self .act_on (X )
3936
40- def __rshift__ (self , other ): # override self >> other
41- return other .apply_to (self )
42-
43- def __rrshift__ (self , other ): # override other >> self
44- if isinstance (other , Arrow ):
45- return self .apply_to (other )
46- return self .act_on (other )
47-
4837
4938class DataOpArrow (Arrow ):
5039 """
@@ -79,39 +68,44 @@ def get_feature_names(self):
7968 cp = self .outgoing_columns .copy ()
8069 return cp
8170
82- def apply_to (self , b ):
83- """replace self input table with b"""
71+ def act_on (self , b , * , correct_ordered_first_call : bool = False ):
72+ """
73+ Apply self onto b.
74+
75+ :param b: item to act on, or item that has been sent to self.
76+ :param correct_ordered_first_call: if True indicates this call is from __rshift__ or __rrshift__ and not the fallback paths.
77+ """
78+ assert isinstance (correct_ordered_first_call , bool )
8479 if isinstance (b , data_algebra .data_ops .ViewRepresentation ):
8580 b = DataOpArrow (b )
86- assert isinstance (b , DataOpArrow )
87- # check categorical arrow composition conditions
88- missing = set (self .incoming_columns ) - set (b .outgoing_columns )
89- if len (missing ) > 0 :
90- raise ValueError ("missing required columns: " + str (missing ))
91- excess = set (b .outgoing_columns ) - set (self .incoming_columns )
92- if len (excess ) > 0 :
93- raise ValueError ("extra incoming columns: " + str (excess ))
94- new_pipeline = self .pipeline .replace_leaves ({self .free_table_key : b .pipeline })
95- new_pipeline .get_tables () # check tables are compatible
96- res = DataOpArrow (
97- pipeline = new_pipeline ,
98- free_table_key = b .free_table_key ,
99- )
100- return res
101-
102- # noinspection PyPep8Naming
103- def act_on (self , X ):
81+ if isinstance (b , DataOpArrow ):
82+ # check categorical arrow composition conditions
83+ missing = set (self .incoming_columns ) - set (b .outgoing_columns )
84+ if len (missing ) > 0 :
85+ raise ValueError ("missing required columns: " + str (missing ))
86+ excess = set (b .outgoing_columns ) - set (self .incoming_columns )
87+ if len (excess ) > 0 :
88+ raise ValueError ("extra incoming columns: " + str (excess ))
89+ new_pipeline = self .pipeline .replace_leaves ({self .free_table_key : b .pipeline })
90+ new_pipeline .get_tables () # check tables are compatible
91+ res = DataOpArrow (
92+ pipeline = new_pipeline ,
93+ free_table_key = b .free_table_key ,
94+ )
95+ return res
96+ if correct_ordered_first_call and isinstance (b , ShiftPipeAction ):
97+ return b .act_on (self , correct_ordered_first_call = False ) # fall back
10498 # assume a pandas.DataFrame compatible object
10599 # noinspection PyUnresolvedReferences
106- cols = set (X .columns )
100+ cols = set (b .columns )
107101 missing = set (self .incoming_columns ) - cols
108102 if len (missing ) > 0 :
109103 raise ValueError ("missing required columns: " + str (missing ))
110104 excess = cols - set (self .incoming_columns )
111105 assert len (excess ) == 0
112106 if len (excess ) > 0 :
113- X = X [self .incoming_columns ]
114- return self .pipeline .act_on (X )
107+ b = b [self .incoming_columns ]
108+ return self .pipeline .act_on (b )
115109
116110 def dom (self ):
117111 return DataOpArrow (
0 commit comments