@@ -160,50 +160,102 @@ def _normalize_branch_ids(branch_ids):
160160def _count_branch_memberships (cb , bid_cols , ids ):
161161 if len (bid_cols ) == 0 :
162162 return np .zeros (shape = (cb .shape [0 ],), dtype = np .int64 )
163- id_list = _normalize_branch_ids (ids ).tolist ()
164- if len (id_list ) == 0 :
165- return np .zeros (shape = (cb .shape [0 ],), dtype = np .int64 )
166163 bid_matrix = cb .loc [:, bid_cols ].to_numpy (copy = False )
167- return np .isin (bid_matrix , id_list ).sum (axis = 1 ).astype (np .int64 )
168-
169-
170- def _mark_dependent_foreground_rows (cb , bid_cols , trait_name , dependent_id_combinations ):
164+ return _count_branch_memberships_from_bid_matrix (bid_matrix = bid_matrix , ids = ids )
165+
166+
167+ def _count_branch_memberships_from_bid_matrix (bid_matrix , ids ):
168+ bid_matrix = np .asarray (bid_matrix , dtype = np .int64 )
169+ if bid_matrix .ndim != 2 :
170+ raise ValueError ('bid_matrix should be a 2D array.' )
171+ if bid_matrix .shape [1 ] == 0 :
172+ return np .zeros (shape = (bid_matrix .shape [0 ],), dtype = np .int64 )
173+ id_array = _normalize_branch_ids (ids )
174+ if id_array .shape [0 ] == 0 :
175+ return np .zeros (shape = (bid_matrix .shape [0 ],), dtype = np .int64 )
176+ id_array = np .unique (id_array )
177+ return np .isin (bid_matrix , id_array ).sum (axis = 1 ).astype (np .int64 )
178+
179+
180+ def _build_order_invariant_row_keys (matrix , assume_sorted = False ):
181+ matrix = np .asarray (matrix , dtype = np .int64 )
182+ if matrix .ndim != 2 :
183+ raise ValueError ('matrix should be a 2D array.' )
184+ if not assume_sorted :
185+ matrix = np .sort (matrix , axis = 1 )
186+ key_dtype = np .dtype ((np .void , matrix .dtype .itemsize * matrix .shape [1 ]))
187+ if matrix .shape [0 ] == 0 :
188+ return np .zeros (shape = (0 ,), dtype = key_dtype )
189+ return np .ascontiguousarray (matrix ).view (key_dtype ).reshape (- 1 )
190+
191+
192+ def _compute_dependent_foreground_mask (cb , bid_cols , dependent_id_combinations , precomputed_bid_key = None ):
171193 if len (bid_cols ) == 0 :
172- return cb
194+ return np . zeros ( shape = ( cb . shape [ 0 ],), dtype = bool )
173195 dep = np .asarray (dependent_id_combinations , dtype = np .int64 )
174196 if dep .size == 0 :
175- return cb
197+ return np . zeros ( shape = ( cb . shape [ 0 ],), dtype = bool )
176198 if dep .size % len (bid_cols ) != 0 :
177199 raise ValueError ('dependent_id_combinations had an unexpected shape.' )
178- col_is_fg = 'is_fg_' + trait_name
179- # Branch-combination semantics are order-invariant; compare sorted row tuples.
180200 dep = dep .reshape (- 1 , len (bid_cols ))
181201 dep_sorted = np .sort (dep , axis = 1 )
182202 dep_sorted = np .unique (dep_sorted , axis = 0 )
183- bid_matrix = cb .loc [:, bid_cols ].to_numpy (copy = False )
184- if bid_matrix .shape [0 ] == 0 :
185- return cb
186- bid_sorted = np .sort (np .asarray (bid_matrix , dtype = np .int64 ), axis = 1 )
187- dep_key = np .ascontiguousarray (dep_sorted ).view (np .dtype ((np .void , dep_sorted .dtype .itemsize * dep_sorted .shape [1 ]))).reshape (- 1 )
188- bid_key = np .ascontiguousarray (bid_sorted ).view (np .dtype ((np .void , bid_sorted .dtype .itemsize * bid_sorted .shape [1 ]))).reshape (- 1 )
189- is_dep = np .isin (bid_key , dep_key )
203+ dep_key = _build_order_invariant_row_keys (dep_sorted , assume_sorted = True )
204+ if precomputed_bid_key is None :
205+ bid_matrix = cb .loc [:, bid_cols ].to_numpy (copy = False )
206+ bid_key = _build_order_invariant_row_keys (bid_matrix , assume_sorted = False )
207+ else :
208+ bid_key = np .asarray (precomputed_bid_key ).reshape (- 1 )
209+ if bid_key .shape [0 ] != cb .shape [0 ]:
210+ txt = 'precomputed_bid_key length ({}) did not match cb rows ({}).'
211+ raise ValueError (txt .format (bid_key .shape [0 ], cb .shape [0 ]))
212+ return np .isin (bid_key , dep_key )
213+
214+
215+ def _mark_dependent_foreground_rows (cb , bid_cols , trait_name , dependent_id_combinations ):
216+ col_is_fg = 'is_fg_' + trait_name
217+ is_dep = _compute_dependent_foreground_mask (
218+ cb = cb ,
219+ bid_cols = bid_cols ,
220+ dependent_id_combinations = dependent_id_combinations ,
221+ )
190222 cb .loc [is_dep , col_is_fg ] = 'N'
191223 return cb
192224
193225
194- def _assign_trait_labels (cb , trait_name , arity ):
226+ def _assign_trait_labels (cb , trait_name , arity , is_fg_dependent = None , num_fg = None , num_mg = None ):
195227 col_num_fg = 'branch_num_fg_' + trait_name
196228 col_num_mg = 'branch_num_mg_' + trait_name
197229 col_is_fg = 'is_fg_' + trait_name
198230 col_is_mg = 'is_mg_' + trait_name
199231 col_is_mf = 'is_mf_' + trait_name
200- cb .loc [:, col_is_fg ] = 'N'
201- cb .loc [cb .loc [:, col_num_fg ] == arity , col_is_fg ] = 'Y'
202- cb .loc [:, col_is_mg ] = 'N'
203- cb .loc [cb .loc [:, col_num_mg ] == arity , col_is_mg ] = 'Y'
232+ if num_fg is None :
233+ num_fg = cb .loc [:, col_num_fg ].to_numpy (copy = False )
234+ else :
235+ num_fg = np .asarray (num_fg , dtype = np .int64 ).reshape (- 1 )
236+ if num_fg .shape [0 ] != cb .shape [0 ]:
237+ txt = 'num_fg length ({}) did not match cb rows ({}).'
238+ raise ValueError (txt .format (num_fg .shape [0 ], cb .shape [0 ]))
239+ if num_mg is None :
240+ num_mg = cb .loc [:, col_num_mg ].to_numpy (copy = False )
241+ else :
242+ num_mg = np .asarray (num_mg , dtype = np .int64 ).reshape (- 1 )
243+ if num_mg .shape [0 ] != cb .shape [0 ]:
244+ txt = 'num_mg length ({}) did not match cb rows ({}).'
245+ raise ValueError (txt .format (num_mg .shape [0 ], cb .shape [0 ]))
246+ is_fg = (num_fg == arity )
247+ if is_fg_dependent is not None :
248+ is_fg_dependent = np .asarray (is_fg_dependent , dtype = bool ).reshape (- 1 )
249+ if is_fg_dependent .shape [0 ] != cb .shape [0 ]:
250+ txt = 'is_fg_dependent length ({}) did not match cb rows ({}).'
251+ raise ValueError (txt .format (is_fg_dependent .shape [0 ], cb .shape [0 ]))
252+ is_fg &= (~ is_fg_dependent )
253+ is_mg = (num_mg == arity )
254+ is_mf = (num_fg > 0 ) & (num_mg > 0 )
255+ is_mf = is_mf & ((num_fg + num_mg ) == arity )
256+ cb .loc [:, col_is_fg ] = np .where (is_fg , 'Y' , 'N' )
257+ cb .loc [:, col_is_mg ] = np .where (is_mg , 'Y' , 'N' )
204258 cb .loc [:, col_is_mf ] = 'N'
205- is_mf = (cb .loc [:, col_num_fg ] > 0 ) & (cb .loc [:, col_num_mg ] > 0 )
206- is_mf = is_mf & ((cb .loc [:, col_num_fg ] + cb .loc [:, col_num_mg ]) == arity )
207259 cb .loc [is_mf , col_is_mf ] = 'Y'
208260 return cb
209261
@@ -812,29 +864,43 @@ def get_foreground_branch_num(cb, g):
812864 start_time = time .time ()
813865 bid_cols = cb .columns [cb .columns .str .startswith ('branch_id_' )]
814866 arity = len (bid_cols )
867+ bid_matrix = np .asarray (cb .loc [:, bid_cols ].to_numpy (copy = False ), dtype = np .int64 )
868+ precomputed_bid_key = _build_order_invariant_row_keys (bid_matrix , assume_sorted = False )
815869 trait_names = _get_trait_names (g )
816870 for trait_name in trait_names :
817871 col_num_fg = 'branch_num_fg_' + trait_name
818872 col_num_mg = 'branch_num_mg_' + trait_name
819873 col_num_fg_stem = 'branch_num_fg_stem_' + trait_name
820- col_is_fg = 'is_fg_' + trait_name
821- cb . loc [:, col_num_fg ] = _count_branch_memberships ( cb = cb , bid_cols = bid_cols , ids = g ['fg_ids ' ][trait_name ])
822- cb .loc [:, col_num_mg ] = _count_branch_memberships ( cb = cb , bid_cols = bid_cols , ids = g [ 'mg_ids' ][ trait_name ])
823- cb = _assign_trait_labels ( cb = cb , trait_name = trait_name , arity = arity )
824- cb = _mark_dependent_foreground_rows (
874+ num_fg_array = _count_branch_memberships_from_bid_matrix ( bid_matrix = bid_matrix , ids = g [ 'fg_ids' ][ trait_name ])
875+ num_mg_array = _count_branch_memberships_from_bid_matrix ( bid_matrix = bid_matrix , ids = g ['mg_ids ' ][trait_name ])
876+ cb .loc [:, col_num_fg ] = num_fg_array
877+ cb . loc [:, col_num_mg ] = num_mg_array
878+ is_fg_dependent = _compute_dependent_foreground_mask (
825879 cb = cb ,
826880 bid_cols = bid_cols ,
827- trait_name = trait_name ,
828881 dependent_id_combinations = g ['fg_dependent_id_combinations' ][trait_name ],
882+ precomputed_bid_key = precomputed_bid_key ,
883+ )
884+ cb = _assign_trait_labels (
885+ cb = cb ,
886+ trait_name = trait_name ,
887+ arity = arity ,
888+ is_fg_dependent = is_fg_dependent ,
889+ num_fg = num_fg_array ,
890+ num_mg = num_mg_array ,
829891 )
830892 df_clade_size = get_df_clade_size (g , trait_name )
831893 fg_stem_bids = df_clade_size .loc [df_clade_size .loc [:,'is_fg_stem_' + trait_name ],'branch_id' ].values
832- cb .loc [:, col_num_fg_stem ] = _count_branch_memberships ( cb = cb , bid_cols = bid_cols , ids = fg_stem_bids )
833- is_fg = (cb [ col_is_fg ] == 'Y' )
894+ cb .loc [:, col_num_fg_stem ] = _count_branch_memberships_from_bid_matrix ( bid_matrix = bid_matrix , ids = fg_stem_bids )
895+ is_fg = (num_fg_array == arity ) & ( ~ is_fg_dependent )
834896 is_enough_stat = table .get_cutoff_stat_bool_array (cb = cb , cutoff_stat_str = g ['cutoff_stat' ])
835- num_enough = is_enough_stat .sum ()
836- num_fg = is_fg .sum ()
837- num_fg_enough = (is_enough_stat & is_fg ).sum ()
897+ if isinstance (is_enough_stat , (bool , np .bool_ )):
898+ is_enough_stat = np .full (shape = (cb .shape [0 ],), fill_value = bool (is_enough_stat ), dtype = bool )
899+ else :
900+ is_enough_stat = np .asarray (is_enough_stat , dtype = bool ).reshape (- 1 )
901+ num_enough = int (is_enough_stat .sum ())
902+ num_fg = int (is_fg .sum ())
903+ num_fg_enough = int ((is_enough_stat & is_fg ).sum ())
838904 num_all = cb .shape [0 ]
839905 percent_fg_enough , enrichment_factor = _calculate_fg_enrichment (
840906 num_enough = num_enough ,
0 commit comments