Skip to content

Commit 4562063

Browse files
committed
Bump version to 1.10.2 and optimize foreground labeling
1 parent dade65d commit 4562063

File tree

3 files changed

+203
-37
lines changed

3 files changed

+203
-37
lines changed

csubst/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.10.1'
1+
__version__ = '1.10.2'

csubst/foreground.py

Lines changed: 102 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -160,50 +160,102 @@ def _normalize_branch_ids(branch_ids):
160160
def _count_branch_memberships(cb, bid_cols, ids):
    """Count, for every row of ``cb``, how many branch-ID cells belong to ``ids``.

    Parameters
    ----------
    cb : pandas.DataFrame
        Table whose ``bid_cols`` columns hold branch IDs.
    bid_cols : sequence of column labels
        Columns of ``cb`` to inspect. When empty, every row counts as zero.
    ids : scalar or array-like
        Branch IDs to look for; forwarded unchanged to
        ``_count_branch_memberships_from_bid_matrix``.

    Returns
    -------
    numpy.ndarray
        One ``int64`` count per row of ``cb``.
    """
    num_rows = cb.shape[0]
    if len(bid_cols) == 0:
        return np.zeros(shape=(num_rows,), dtype=np.int64)
    # Pull the branch-ID columns out once; the matrix-based helper does the rest.
    branch_id_values = cb.loc[:, bid_cols].to_numpy(copy=False)
    return _count_branch_memberships_from_bid_matrix(bid_matrix=branch_id_values, ids=ids)
165+
166+
167+
def _count_branch_memberships_from_bid_matrix(bid_matrix, ids):
    """Count, per row of a branch-ID matrix, the cells whose value is in ``ids``.

    Parameters
    ----------
    bid_matrix : array-like
        Converted to a 2D ``int64`` array; one row per branch combination.
    ids : scalar or array-like
        Branch IDs to match. Normalised through ``_normalize_branch_ids``
        (defined elsewhere in this module; presumably returns a 1D integer
        array, since its ``shape[0]`` is read here) and de-duplicated.

    Returns
    -------
    numpy.ndarray
        ``int64`` vector with one membership count per matrix row.

    Raises
    ------
    ValueError
        If ``bid_matrix`` is not two-dimensional.
    """
    bid_matrix = np.asarray(bid_matrix, dtype=np.int64)
    if bid_matrix.ndim != 2:
        raise ValueError('bid_matrix should be a 2D array.')
    num_rows, num_cols = bid_matrix.shape
    if num_cols == 0:
        return np.zeros(shape=(num_rows,), dtype=np.int64)
    wanted_ids = _normalize_branch_ids(ids)
    if wanted_ids.shape[0] == 0:
        return np.zeros(shape=(num_rows,), dtype=np.int64)
    wanted_ids = np.unique(wanted_ids)
    is_member = np.isin(bid_matrix, wanted_ids)
    return is_member.sum(axis=1).astype(np.int64)
178+
179+
180+
def _build_order_invariant_row_keys(matrix, assume_sorted=False):
    """Collapse each row of a 2D integer matrix into one opaque comparison key.

    Every row is (optionally) sorted and its ``int64`` values are reinterpreted
    as a single ``numpy.void`` item, so two rows holding the same values in any
    order produce byte-identical keys that can be fed to ``numpy.isin``.

    Parameters
    ----------
    matrix : array-like
        Converted to a 2D ``int64`` array; one row per combination.
    assume_sorted : bool, default False
        When True the rows are used as given (the caller guarantees each row
        is already sorted); when False each row is sorted first.

    Returns
    -------
    numpy.ndarray
        1D array of ``numpy.void`` keys, exactly one per input row.

    Raises
    ------
    ValueError
        If ``matrix`` is not two-dimensional.
    """
    matrix = np.asarray(matrix, dtype=np.int64)
    if matrix.ndim != 2:
        raise ValueError('matrix should be a 2D array.')
    num_rows, num_cols = matrix.shape
    if num_cols == 0:
        # A zero-byte void dtype cannot be used to view the data, so the
        # generic path would not return one key per row. All zero-width rows
        # are equal; give each the same 1-byte placeholder key so that
        # len(keys) == number of rows still holds.
        return np.zeros(shape=(num_rows,), dtype=np.dtype((np.void, 1)))
    if not assume_sorted:
        matrix = np.sort(matrix, axis=1)
    key_dtype = np.dtype((np.void, matrix.dtype.itemsize * num_cols))
    if num_rows == 0:
        return np.zeros(shape=(0,), dtype=key_dtype)
    # The view needs a C-contiguous buffer so each row occupies one key item.
    return np.ascontiguousarray(matrix).view(key_dtype).reshape(-1)
190+
191+
192+
def _compute_dependent_foreground_mask(cb, bid_cols, dependent_id_combinations, precomputed_bid_key=None):
    """Flag the rows of ``cb`` whose branch combination is listed as dependent.

    Matching ignores the order of branch IDs within a combination: both the
    listed combinations and the table rows are reduced to order-invariant keys
    before comparison.

    Parameters
    ----------
    cb : pandas.DataFrame
        Table containing the ``bid_cols`` branch-ID columns.
    bid_cols : sequence of column labels
        Branch-ID columns; their count is the combination size.
    dependent_id_combinations : array-like
        Branch IDs of the dependent combinations; reshaped to
        ``(-1, len(bid_cols))``.
    precomputed_bid_key : array-like, optional
        Row keys for ``cb`` already produced by
        ``_build_order_invariant_row_keys``; used instead of recomputing them.

    Returns
    -------
    numpy.ndarray
        Boolean vector, one entry per row of ``cb``. All False when there are
        no branch-ID columns or no dependent combinations.

    Raises
    ------
    ValueError
        If the number of supplied IDs is not a multiple of ``len(bid_cols)``,
        or if ``precomputed_bid_key`` does not have one entry per row.
    """
    arity = len(bid_cols)
    num_rows = cb.shape[0]
    if arity == 0:
        return np.zeros(shape=(num_rows,), dtype=bool)
    combos = np.asarray(dependent_id_combinations, dtype=np.int64)
    if combos.size == 0:
        return np.zeros(shape=(num_rows,), dtype=bool)
    if combos.size % arity != 0:
        raise ValueError('dependent_id_combinations had an unexpected shape.')
    # Sort within each combination, then drop duplicate combinations.
    combos = np.sort(combos.reshape(-1, arity), axis=1)
    combos = np.unique(combos, axis=0)
    combo_keys = _build_order_invariant_row_keys(combos, assume_sorted=True)
    if precomputed_bid_key is not None:
        row_keys = np.asarray(precomputed_bid_key).reshape(-1)
        if row_keys.shape[0] != num_rows:
            txt = 'precomputed_bid_key length ({}) did not match cb rows ({}).'
            raise ValueError(txt.format(row_keys.shape[0], num_rows))
    else:
        branch_id_values = cb.loc[:, bid_cols].to_numpy(copy=False)
        row_keys = _build_order_invariant_row_keys(branch_id_values, assume_sorted=False)
    return np.isin(row_keys, combo_keys)
213+
214+
215+
def _mark_dependent_foreground_rows(cb, bid_cols, trait_name, dependent_id_combinations):
    """Set ``is_fg_<trait_name>`` to ``'N'`` on rows whose combination is dependent.

    The rows are selected with ``_compute_dependent_foreground_mask``; ``cb`` is
    modified in place and also returned.
    """
    fg_flag_col = 'is_fg_' + trait_name
    dependent_rows = _compute_dependent_foreground_mask(
        cb=cb,
        bid_cols=bid_cols,
        dependent_id_combinations=dependent_id_combinations,
    )
    cb.loc[dependent_rows, fg_flag_col] = 'N'
    return cb
192224

193225

194-
def _assign_trait_labels(cb, trait_name, arity):
226+
def _coerce_row_vector(values, name, num_rows, dtype):
    """Flatten ``values`` to a 1D ``dtype`` array and require ``num_rows`` entries."""
    vector = np.asarray(values, dtype=dtype).reshape(-1)
    if vector.shape[0] != num_rows:
        txt = '{} length ({}) did not match cb rows ({}).'
        raise ValueError(txt.format(name, vector.shape[0], num_rows))
    return vector


def _assign_trait_labels(cb, trait_name, arity, is_fg_dependent=None, num_fg=None, num_mg=None):
    """Write the ``'Y'``/``'N'`` label columns of one trait into ``cb``.

    Three columns are (over)written, each holding ``'Y'`` or ``'N'`` per row:

    * ``is_fg_<trait>``: ``num_fg == arity`` and the row is not flagged in
      ``is_fg_dependent``.
    * ``is_mg_<trait>``: ``num_mg == arity``.
    * ``is_mf_<trait>``: ``num_fg > 0``, ``num_mg > 0`` and
      ``num_fg + num_mg == arity``.

    Parameters
    ----------
    cb : pandas.DataFrame
        Table to label; modified in place.
    trait_name : str
        Suffix used to build the column names.
    arity : int
        Number of branches per combination.
    is_fg_dependent : array-like of bool, optional
        Rows that must not be labelled foreground even if ``num_fg == arity``.
    num_fg, num_mg : array-like of int, optional
        Per-row counts. When omitted they are read from the
        ``branch_num_fg_<trait>`` / ``branch_num_mg_<trait>`` columns of ``cb``.

    Returns
    -------
    pandas.DataFrame
        The same ``cb`` object that was passed in.

    Raises
    ------
    ValueError
        If a supplied ``num_fg``, ``num_mg`` or ``is_fg_dependent`` does not
        have exactly one entry per row of ``cb``.
    """
    col_num_fg = 'branch_num_fg_' + trait_name
    col_num_mg = 'branch_num_mg_' + trait_name
    col_is_fg = 'is_fg_' + trait_name
    col_is_mg = 'is_mg_' + trait_name
    col_is_mf = 'is_mf_' + trait_name
    num_rows = cb.shape[0]
    if num_fg is None:
        num_fg = cb.loc[:, col_num_fg].to_numpy(copy=False)
    else:
        num_fg = _coerce_row_vector(num_fg, 'num_fg', num_rows, np.int64)
    if num_mg is None:
        num_mg = cb.loc[:, col_num_mg].to_numpy(copy=False)
    else:
        num_mg = _coerce_row_vector(num_mg, 'num_mg', num_rows, np.int64)
    is_fg = (num_fg == arity)
    if is_fg_dependent is not None:
        is_fg_dependent = _coerce_row_vector(is_fg_dependent, 'is_fg_dependent', num_rows, bool)
        # Dependent combinations are excluded from the foreground label only.
        is_fg = is_fg & (~is_fg_dependent)
    is_mg = (num_mg == arity)
    is_mf = (num_fg > 0) & (num_mg > 0) & ((num_fg + num_mg) == arity)
    # Each label column is written in a single vectorised pass.
    cb.loc[:, col_is_fg] = np.where(is_fg, 'Y', 'N')
    cb.loc[:, col_is_mg] = np.where(is_mg, 'Y', 'N')
    cb.loc[:, col_is_mf] = np.where(is_mf, 'Y', 'N')
    return cb
209261

@@ -812,29 +864,43 @@ def get_foreground_branch_num(cb, g):
812864
start_time = time.time()
813865
bid_cols = cb.columns[cb.columns.str.startswith('branch_id_')]
814866
arity = len(bid_cols)
867+
bid_matrix = np.asarray(cb.loc[:, bid_cols].to_numpy(copy=False), dtype=np.int64)
868+
precomputed_bid_key = _build_order_invariant_row_keys(bid_matrix, assume_sorted=False)
815869
trait_names = _get_trait_names(g)
816870
for trait_name in trait_names:
817871
col_num_fg = 'branch_num_fg_' + trait_name
818872
col_num_mg = 'branch_num_mg_' + trait_name
819873
col_num_fg_stem = 'branch_num_fg_stem_' + trait_name
820-
col_is_fg = 'is_fg_' + trait_name
821-
cb.loc[:, col_num_fg] = _count_branch_memberships(cb=cb, bid_cols=bid_cols, ids=g['fg_ids'][trait_name])
822-
cb.loc[:, col_num_mg] = _count_branch_memberships(cb=cb, bid_cols=bid_cols, ids=g['mg_ids'][trait_name])
823-
cb = _assign_trait_labels(cb=cb, trait_name=trait_name, arity=arity)
824-
cb = _mark_dependent_foreground_rows(
874+
num_fg_array = _count_branch_memberships_from_bid_matrix(bid_matrix=bid_matrix, ids=g['fg_ids'][trait_name])
875+
num_mg_array = _count_branch_memberships_from_bid_matrix(bid_matrix=bid_matrix, ids=g['mg_ids'][trait_name])
876+
cb.loc[:, col_num_fg] = num_fg_array
877+
cb.loc[:, col_num_mg] = num_mg_array
878+
is_fg_dependent = _compute_dependent_foreground_mask(
825879
cb=cb,
826880
bid_cols=bid_cols,
827-
trait_name=trait_name,
828881
dependent_id_combinations=g['fg_dependent_id_combinations'][trait_name],
882+
precomputed_bid_key=precomputed_bid_key,
883+
)
884+
cb = _assign_trait_labels(
885+
cb=cb,
886+
trait_name=trait_name,
887+
arity=arity,
888+
is_fg_dependent=is_fg_dependent,
889+
num_fg=num_fg_array,
890+
num_mg=num_mg_array,
829891
)
830892
df_clade_size = get_df_clade_size(g, trait_name)
831893
fg_stem_bids = df_clade_size.loc[df_clade_size.loc[:,'is_fg_stem_'+trait_name],'branch_id'].values
832-
cb.loc[:, col_num_fg_stem] = _count_branch_memberships(cb=cb, bid_cols=bid_cols, ids=fg_stem_bids)
833-
is_fg = (cb[col_is_fg] == 'Y')
894+
cb.loc[:, col_num_fg_stem] = _count_branch_memberships_from_bid_matrix(bid_matrix=bid_matrix, ids=fg_stem_bids)
895+
is_fg = (num_fg_array == arity) & (~is_fg_dependent)
834896
is_enough_stat = table.get_cutoff_stat_bool_array(cb=cb, cutoff_stat_str=g['cutoff_stat'])
835-
num_enough = is_enough_stat.sum()
836-
num_fg = is_fg.sum()
837-
num_fg_enough = (is_enough_stat&is_fg).sum()
897+
if isinstance(is_enough_stat, (bool, np.bool_)):
898+
is_enough_stat = np.full(shape=(cb.shape[0],), fill_value=bool(is_enough_stat), dtype=bool)
899+
else:
900+
is_enough_stat = np.asarray(is_enough_stat, dtype=bool).reshape(-1)
901+
num_enough = int(is_enough_stat.sum())
902+
num_fg = int(is_fg.sum())
903+
num_fg_enough = int((is_enough_stat & is_fg).sum())
838904
num_all = cb.shape[0]
839905
percent_fg_enough, enrichment_factor = _calculate_fg_enrichment(
840906
num_enough=num_enough,

tests/test_foreground_clade_permutation.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ def test_count_branch_memberships_accepts_scalar_ids():
8888
assert out.tolist() == [1, 0]
8989

9090

91+
def test_count_branch_memberships_from_bid_matrix_accepts_scalar_ids():
    """A bare NumPy scalar is accepted as ``ids`` and counted row by row."""
    rows = np.array([[1, 3], [2, 4]], dtype=np.int64)
    counts = foreground._count_branch_memberships_from_bid_matrix(
        bid_matrix=rows,
        ids=np.int64(3),
    )
    # Only the first row contains branch ID 3.
    assert counts.tolist() == [1, 0]
95+
96+
9197
def test_mark_dependent_foreground_rows_is_order_invariant_for_pairs():
9298
cb = pd.DataFrame(
9399
{
@@ -125,6 +131,100 @@ def test_mark_dependent_foreground_rows_is_order_invariant_for_higher_arity():
125131
assert out.loc[:, "is_fg_traitA"].tolist() == ["N", "Y", "N"]
126132

127133

134+
def test_compute_dependent_foreground_mask_is_order_invariant_for_pairs():
    """Dependent pairs match table rows regardless of the ID order within a pair."""
    bid_cols = ["branch_id_1", "branch_id_2"]
    cb = pd.DataFrame(dict(zip(bid_cols, ([1, 1, 2, 3], [5, 3, 5, 4]))))
    # Listed as (5, 1) and (5, 2); the table stores them as (1, 5) and (2, 5).
    dependent_pairs = np.array([[5, 1], [5, 2]], dtype=np.int64)
    mask = foreground._compute_dependent_foreground_mask(
        cb=cb,
        bid_cols=bid_cols,
        dependent_id_combinations=dependent_pairs,
    )
    assert mask.tolist() == [True, False, True, False]
148+
149+
150+
def test_compute_dependent_foreground_mask_accepts_precomputed_bid_key():
    """Supplying row keys built up front gives the same mask as computing them inside."""
    bid_cols = ["branch_id_1", "branch_id_2"]
    cb = pd.DataFrame(dict(zip(bid_cols, ([1, 1, 2, 3], [5, 3, 5, 4]))))
    dependent_pairs = np.array([[5, 1], [5, 2]], dtype=np.int64)
    row_keys = foreground._build_order_invariant_row_keys(
        cb.loc[:, bid_cols].to_numpy(copy=False),
        assume_sorted=False,
    )
    mask = foreground._compute_dependent_foreground_mask(
        cb=cb,
        bid_cols=bid_cols,
        dependent_id_combinations=dependent_pairs,
        precomputed_bid_key=row_keys,
    )
    assert mask.tolist() == [True, False, True, False]
167+
168+
169+
def test_assign_trait_labels_applies_dependent_mask_to_foreground_only():
    """The dependent mask clears ``is_fg`` but leaves ``is_mf`` and ``is_mg`` alone."""
    cb = pd.DataFrame(
        {
            "branch_num_fg_traitA": [2, 2, 1, 0],
            "branch_num_mg_traitA": [0, 0, 1, 2],
        }
    )
    dependent_rows = np.array([True, False, False, False], dtype=bool)
    labelled = foreground._assign_trait_labels(
        cb=cb.copy(deep=True),
        trait_name="traitA",
        arity=2,
        is_fg_dependent=dependent_rows,
    )
    expected_by_column = {
        "is_fg_traitA": ["N", "Y", "N", "N"],
        "is_mf_traitA": ["N", "N", "Y", "N"],
        "is_mg_traitA": ["N", "N", "N", "Y"],
    }
    for column, expected in expected_by_column.items():
        assert labelled.loc[:, column].tolist() == expected
185+
186+
187+
def test_assign_trait_labels_rejects_mismatched_dependent_mask_length():
    """A dependent mask shorter than the table is rejected with ``ValueError``."""
    cb = pd.DataFrame(
        {
            "branch_num_fg_traitA": [2, 2],
            "branch_num_mg_traitA": [0, 0],
        }
    )
    too_short_mask = np.array([True], dtype=bool)
    with pytest.raises(ValueError, match="did not match cb rows"):
        foreground._assign_trait_labels(
            cb=cb.copy(deep=True),
            trait_name="traitA",
            arity=2,
            is_fg_dependent=too_short_mask,
        )
201+
202+
203+
def test_assign_trait_labels_rejects_mismatched_num_fg_num_mg_length():
    """Explicit ``num_fg`` / ``num_mg`` vectors must have one entry per table row."""
    cb = pd.DataFrame(
        {
            "branch_num_fg_traitA": [2, 2],
            "branch_num_mg_traitA": [0, 0],
        }
    )
    # (expected message fragment, num_fg, num_mg): first num_fg is short, then num_mg.
    cases = [
        ("num_fg length", np.array([2], dtype=np.int64), np.array([0, 0], dtype=np.int64)),
        ("num_mg length", np.array([2, 2], dtype=np.int64), np.array([0], dtype=np.int64)),
    ]
    for message, fg_counts, mg_counts in cases:
        with pytest.raises(ValueError, match=message):
            foreground._assign_trait_labels(
                cb=cb.copy(deep=True),
                trait_name="traitA",
                arity=2,
                num_fg=fg_counts,
                num_mg=mg_counts,
            )
226+
227+
128228
def test_annotate_foreground_fg_stem_only_keeps_lineage_specific_stem_colors():
129229
tr = tree.add_numerical_node_labels(ete.PhyloNode("((Nep1:1,Nep2:1)N:1,Ceph:1)R;", format=1))
130230
g = {

0 commit comments

Comments
 (0)