Skip to content

Commit a67ba5b

Browse files
Merge pull request #189 from szhan/filter_by_hmm_cost_without_group
Filter samples by HMM cost without match grouping
2 parents 1ec032e + 12e362a commit a67ba5b

File tree

3 files changed

+25
-32
lines changed

3 files changed

+25
-32
lines changed

sc2ts/inference.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -447,10 +447,12 @@ def add_matching_results(
447447
# By default, arbitraily high.
448448
max_hmm_cost = 1e6
449449

450-
# Group matches by path and set of reversion mutations
450+
# Group matches by path and set of immediate reversions.
451451
grouped_matches = collections.defaultdict(list)
452452
site_masked_samples = np.zeros(int(ts.sequence_length), dtype=int)
453453
for sample in samples:
454+
if sample.get_hmm_cost(num_mismatches) > max_hmm_cost:
455+
continue
454456
site_masked_samples[sample.masked_sites] += 1
455457
path = tuple(sample.path)
456458
reversions = tuple(
@@ -460,17 +462,6 @@ def add_matching_results(
460462
)
461463
grouped_matches[(path, reversions)].append(sample)
462464

463-
# Exclude single samples with "high-HMM cost" attachment paths.
464-
tmp = {}
465-
for k, v in grouped_matches.items():
466-
if len(v) == 1:
467-
# Exclude sample if it's HMM cost exceeds a maximum.
468-
sample = v[0]
469-
if sample.get_hmm_cost(num_mismatches) > max_hmm_cost:
470-
continue
471-
tmp[k] = v
472-
grouped_matches = tmp
473-
474465
tables = ts.dump_tables()
475466
logger.info(f"Got {len(grouped_matches)} distinct paths")
476467

tests/test_inference.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,24 +83,25 @@ def test_one_sample_recombinant_filtered(self):
8383
assert ts2.num_nodes == ts.num_nodes
8484
assert ts2.num_samples == ts.num_samples
8585

86-
def test_two_samples_recombinant_not_filtered(self):
87-
"""
88-
Test case where two identical recombinant samples get added
89-
but not excluded despite HMM costs above the max threshold.
90-
"""
86+
def test_two_samples_recombinant_one_filtered(self):
9187
ts = util.example_binary(2)
9288
L = ts.sequence_length
9389
x = L / 2
94-
new_sample_paths = [
95-
[(0, x, 2), (x, L, 3)],
96-
[(0, x, 2), (x, L, 3)],
90+
new_paths = [
91+
[(0, x, 2), (x, L, 3)], # Added
92+
[
93+
(0, L / 4, 2),
94+
(L / 4, L / 2, 3),
95+
(L / 2, 3 / 4 * L, 4),
96+
(3 / 4 * L, L, 2),
97+
], # Filtered
9798
]
98-
samples = util.get_samples(ts, new_sample_paths)
99+
samples = util.get_samples(ts, new_paths)
99100
ts2 = sc2ts.add_matching_results(
100-
samples, ts, "2021", num_mismatches=1e3, max_hmm_cost=1e3 - 1
101+
samples, ts, "2021", num_mismatches=3, max_hmm_cost=4
101102
)
102103
assert ts2.num_trees == 2
103-
assert ts2.num_samples == ts.num_samples + len(new_sample_paths)
104+
assert ts2.num_samples == ts.num_samples + 1
104105

105106
def test_one_sample_one_mutation(self):
106107
ts = sc2ts.initial_ts()
@@ -130,24 +131,25 @@ def test_one_sample_one_mutation_filtered(self):
130131
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state
131132
assert ts2.num_mutations == 0
132133

133-
def test_two_samples_one_mutation_not_filtered(self):
134+
def test_two_samples_one_mutation_one_filtered(self):
134135
ts = sc2ts.initial_ts()
135136
ts = sc2ts.increment_time("2020-01-01", ts)
136-
new_sample_paths = [
137+
x = int(ts.sequence_length / 2)
138+
new_paths = [
137139
[(0, ts.sequence_length, 1)],
138140
[(0, ts.sequence_length, 1)],
139141
]
140-
new_sample_mutations = [
141-
[(0, "X")],
142-
[(0, "X")],
142+
new_mutations = [
143+
[(0, "X")], # Added
144+
[(0, "X"), (x, "X")], # Filtered
143145
]
144146
samples = util.get_samples(
145147
ts,
146-
paths=new_sample_paths,
147-
mutations=new_sample_mutations,
148+
paths=new_paths,
149+
mutations=new_mutations,
148150
)
149151
ts2 = sc2ts.add_matching_results(
150-
samples, ts, "2021", num_mismatches=0.0, max_hmm_cost=0.0
152+
samples, ts, "2021", num_mismatches=3, max_hmm_cost=1
151153
)
152154
assert ts2.num_trees == ts.num_trees
153155
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def make_recombinant_tree(self, num_samples=1):
3535
x = L / 2
3636
samples = util.get_samples(ts, [[(0, x, 2), (x, L, 3)]] * num_samples)
3737
ts_rec = sc2ts.add_matching_results(
38-
samples, ts, "2021", num_mismatches=None, num_samples=None
38+
samples, ts, "2021", num_mismatches=None, max_hmm_cost=None
3939
)
4040
assert ts_rec.num_trees == 2
4141
return ts_rec

0 commit comments

Comments
 (0)