@@ -60,73 +60,84 @@ def _compatibility_check(self, x, y):
6060 return x , y
6161
6262 def _prep_index (self , x , y ):
63- """ Preps index and column axes for arithmetic """
6463 if x .kdims .shape [0 ] == 1 and y .kdims .shape [0 ] > 1 :
65- # Broadcast x to y
6664 x .kdims = y .kdims
6765 x .key_labels = y .key_labels
6866 return x , y
6967 if x .kdims .shape [0 ] > 1 and y .kdims .shape [0 ] == 1 :
70- # Broadcast y to x
7168 y .kdims = x .kdims
7269 y .key_labels = x .key_labels
7370 return x , y
7471 if x .kdims .shape [0 ] == y .kdims .shape [0 ] == 1 and x .key_labels != y .key_labels :
75- # Broadcast to the triangle with a larger multi-index
7672 kdims = x .kdims if len (x .key_labels ) > len (y .key_labels ) else y .kdims
77- y .kdims = x .kdims = kdims
7873 key_labels = x .key_labels if len (x .key_labels ) > len (y .key_labels ) else y .key_labels
79- y .key_labels = x .key_labels = key_labels
74+ x .kdims = y .kdims = kdims
75+ x .key_labels = y .key_labels = key_labels
8076 return x , y
81- a , b = set (x .key_labels ), set (y .key_labels )
82- common = a .intersection (b )
83- if common in [a , b ] and (a != b or (a == b and x .kdims .shape [0 ] != y .kdims .shape [0 ])):
84- # If index labels are subset of other triangle index labels
85- x = x .groupby (list (common ))
86- y = y .groupby (list (common ))
87- return x , y
88- if common not in [a , b ]:
89- raise ValueError ('Index broadcasting is ambiguous between' , str (a ), 'and' , str (b ))
90- if (
91- x .key_labels == y .key_labels
92- and x .kdims .shape [0 ] == y .kdims .shape [0 ]
93- and y .kdims .shape [0 ] > 1
94- and not x .kdims is y .kdims
95- and not x .index .equals (y .index )
96- ):
97- # Make sure exact but unsorted index labels works
98- x = x .sort_index ()
99- try :
100- y = y .loc [x .index ]
101- except :
77+
78+ # Use sets for faster operations
79+ x_labels = set (x .key_labels )
80+ y_labels = set (y .key_labels )
81+ common = x_labels .intersection (y_labels )
82+
83+ if common == x_labels or common == y_labels :
84+ if x_labels != y_labels or x .kdims .shape [0 ] != y .kdims .shape [0 ]:
10285 x = x .groupby (list (common ))
10386 y = y .groupby (list (common ))
87+ elif x .kdims .shape [0 ] > 1 and not np .array_equal (x .kdims , y .kdims ) and not x .index .equals (y .index ):
88+ x = x .sort_index ()
89+ try :
90+ y = y .loc [x .index ]
91+ except :
92+ x = x .groupby (list (common ))
93+ y = y .groupby (list (common ))
94+ return x , y
95+
96+ if common != x_labels and common != y_labels :
97+ raise ValueError ('Index broadcasting is ambiguous between ' + str (x_labels ) + ' and ' + str (y_labels ))
98+
10499 return x , y
105100
106101 def _prep_columns (self , x , y ):
107102 x_backend , y_backend = x .array_backend , y .array_backend
103+
108104 if len (x .columns ) == 1 and len (y .columns ) > 1 :
109105 x .vdims = y .vdims
110106 elif len (y .columns ) == 1 and len (x .columns ) > 1 :
111107 y .vdims = x .vdims
112- elif len (y .columns ) == 1 and len (x .columns ) == 1 and x .columns != y .columns :
108+ elif len (y .columns ) == len (x .columns ) == 1 and x .columns != y .columns :
113109 y .vdims = x .vdims
114- elif x .shape [1 ] == y .shape [1 ] and np .all (x .columns == y .columns ):
115- pass
110+ elif x .shape [1 ] == y .shape [1 ] and np .array_equal (x .columns , y .columns ):
111+ return x , y
116112 else :
117- col_union = list (x .columns ) + [
118- item for item in y .columns if item not in x .columns
119- ]
120- for item in [item for item in col_union if item not in x .columns ]:
121- x [item ] = 0
122- x = x [col_union ]
123- for item in [item for item in col_union if item not in y .columns ]:
124- y [item ] = 0
125- y = y [col_union ]
126- x , y = (
127- x .set_backend (x_backend , inplace = True ),
128- y .set_backend (y_backend , inplace = True ),
129- )
113+ # Use sets for faster operations
114+ x_cols = set (x .columns )
115+ y_cols = set (y .columns )
116+
117+ # Find columns to add to each triangle
118+ cols_to_add_to_x = y_cols - x_cols
119+ cols_to_add_to_y = x_cols - y_cols
120+
121+ # Create new columns only if necessary
122+ if cols_to_add_to_x :
123+ new_x_cols = list (x .columns ) + list (cols_to_add_to_x )
124+ x = x .reindex (columns = new_x_cols , fill_value = 0 )
125+
126+ if cols_to_add_to_y :
127+ new_y_cols = list (y .columns ) + list (cols_to_add_to_y )
128+ y = y .reindex (columns = new_y_cols , fill_value = 0 )
129+
130+ # Ensure both triangles have the same column order
131+ final_cols = list (x_cols | y_cols )
132+ x = x [final_cols ]
133+ y = y [final_cols ]
134+
135+ # Reset backends only if they've changed
136+ if x .array_backend != x_backend :
137+ x = x .set_backend (x_backend , inplace = True )
138+ if y .array_backend != y_backend :
139+ y = y .set_backend (y_backend , inplace = True )
140+
130141 return x , y
131142
132143 def _prep_origin_development (self , obj , other ):
0 commit comments