Skip to content

Commit 1ec032e

Browse files
Merge pull request #187 from szhan/hmm_cost_filter
Implement HMM cost to filter samples
2 parents 29c2bd0 + bbfd8c2 commit 1ec032e

File tree

5 files changed

+141
-15
lines changed

5 files changed

+141
-15
lines changed

sc2ts/cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def add_provenance(ts, output_file):
162162
),
163163
)
164164
@click.option("--num-mismatches", default=None, type=float, help="num-mismatches")
165+
@click.option("--max-hmm-cost", default=None, type=float, help="max-hmm-cost")
165166
@click.option(
166167
"--max-submission-delay",
167168
default=None,
@@ -192,6 +193,7 @@ def daily_extend(
192193
output_prefix,
193194
base,
194195
num_mismatches,
196+
max_hmm_cost,
195197
max_submission_delay,
196198
max_daily_samples,
197199
num_threads,
@@ -219,6 +221,7 @@ def daily_extend(
219221
metadata_db=metadata_db,
220222
base_ts=base_ts,
221223
num_mismatches=num_mismatches,
224+
max_hmm_cost=max_hmm_cost,
222225
max_submission_delay=max_submission_delay,
223226
max_daily_samples=max_daily_samples,
224227
rng=rng,

sc2ts/inference.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ def submission_date(self):
222222
def submission_delay(self):
223223
return (self.submission_date - self.date).days
224224

225+
def get_hmm_cost(self, num_mismatches):
226+
# Note that Recombinant objects have total_cost.
227+
# This bit of code is sort of repeated.
228+
return num_mismatches * (len(self.path) - 1) + len(self.mutations)
229+
225230
def asdict(self):
226231
return {
227232
"strain": self.strain,
@@ -247,6 +252,7 @@ def daily_extend(
247252
metadata_db,
248253
base_ts,
249254
num_mismatches=None,
255+
max_hmm_cost=None,
250256
show_progress=False,
251257
max_submission_delay=None,
252258
max_daily_samples=None,
@@ -263,6 +269,7 @@ def daily_extend(
263269
date=date,
264270
base_ts=last_ts,
265271
num_mismatches=num_mismatches,
272+
max_hmm_cost=max_hmm_cost,
266273
show_progress=show_progress,
267274
max_submission_delay=max_submission_delay,
268275
max_daily_samples=max_daily_samples,
@@ -340,14 +347,14 @@ def extend(
340347
date,
341348
base_ts,
342349
num_mismatches=None,
350+
max_hmm_cost=None,
343351
show_progress=False,
344352
max_submission_delay=None,
345353
max_daily_samples=None,
346354
num_threads=None,
347355
precision=None,
348356
rng=None,
349357
):
350-
351358
date_samples = [Sample(md) for md in metadata_db.get(date)]
352359
samples = filter_samples(date_samples, alignment_store, max_submission_delay)
353360

@@ -361,6 +368,7 @@ def extend(
361368

362369
logger.info(f"Got {len(samples)} samples")
363370

371+
# Note num_mismatches is assigned a default value in match_tsinfer.
364372
samples = match(
365373
samples,
366374
alignment_store=alignment_store,
@@ -371,7 +379,15 @@ def extend(
371379
precision=precision,
372380
)
373381
ts = increment_time(date, base_ts)
374-
return add_matching_results(samples, ts, date, show_progress)
382+
383+
return add_matching_results(
384+
samples=samples,
385+
ts=ts,
386+
date=date,
387+
num_mismatches=num_mismatches,
388+
max_hmm_cost=max_hmm_cost,
389+
show_progress=show_progress,
390+
)
375391

376392

377393
def match_path_ts(samples, ts, path, reversions):
@@ -394,7 +410,7 @@ def match_path_ts(samples, ts, path, reversions):
394410
"qc": sample.alignment_qc,
395411
"path": [x.asdict() for x in sample.path],
396412
"mutations": [x.asdict() for x in sample.mutations],
397-
}
413+
},
398414
}
399415
node_id = tables.nodes.add_row(
400416
flags=tskit.NODE_IS_SAMPLE, time=0, metadata=metadata
@@ -407,7 +423,7 @@ def match_path_ts(samples, ts, path, reversions):
407423

408424
# Now add the mutations
409425
for node_id, sample in enumerate(samples, first_sample):
410-
#metadata = {**sample.metadata, "sc2ts_qc": sample.alignment_qc}
426+
# metadata = {**sample.metadata, "sc2ts_qc": sample.alignment_qc}
411427
for mut in sample.mutations:
412428
tables.mutations.add_row(
413429
site=site_id_map[mut.site_id],
@@ -420,7 +436,16 @@ def match_path_ts(samples, ts, path, reversions):
420436
# print(tables)
421437

422438

423-
def add_matching_results(samples, ts, date, show_progress=False):
439+
def add_matching_results(
440+
samples, ts, date, num_mismatches, max_hmm_cost, show_progress=False
441+
):
442+
if num_mismatches is None:
443+
# Note that this is the default assigned in match_tsinfer.
444+
num_mismatches = 1e3
445+
446+
if max_hmm_cost is None:
447+
# By default, arbitraily high.
448+
max_hmm_cost = 1e6
424449

425450
# Group matches by path and set of reversion mutations
426451
grouped_matches = collections.defaultdict(list)
@@ -435,6 +460,17 @@ def add_matching_results(samples, ts, date, show_progress=False):
435460
)
436461
grouped_matches[(path, reversions)].append(sample)
437462

463+
# Exclude single samples with "high-HMM cost" attachment paths.
464+
tmp = {}
465+
for k, v in grouped_matches.items():
466+
if len(v) == 1:
467+
# Exclude sample if it's HMM cost exceeds a maximum.
468+
sample = v[0]
469+
if sample.get_hmm_cost(num_mismatches) > max_hmm_cost:
470+
continue
471+
tmp[k] = v
472+
grouped_matches = tmp
473+
438474
tables = ts.dump_tables()
439475
logger.info(f"Got {len(grouped_matches)} distinct paths")
440476

@@ -981,6 +1017,7 @@ def match_tsinfer(
9811017
show_progress=False,
9821018
mirror_coordinates=False,
9831019
):
1020+
# TODO: Should this default be assigned elsewhere?
9841021
if num_mismatches is None:
9851022
# Default to no recombination
9861023
num_mismatches = 1000

sc2ts/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,14 +179,14 @@ def data_summary(self):
179179
return d
180180

181181
@property
182-
def total_cost(self):
182+
def total_cost(self, num_mismatches):
183183
"""
184184
How different is the causal sequence from the rest, roughly?
185185
"""
186186
fwd = self.hmm_runs["forward"]
187187
bck = self.hmm_runs["backward"]
188-
cost_fwd = 3 * (len(fwd.parents) - 1) + len(fwd.mutations)
189-
cost_bck = 3 * (len(bck.parents) - 1) + len(bck.mutations)
188+
cost_fwd = num_mismatches * (len(fwd.parents) - 1) + len(fwd.mutations)
189+
cost_bck = num_mismatches * (len(bck.parents) - 1) + len(bck.mutations)
190190
assert cost_fwd == cost_bck
191191
return cost_fwd
192192

tests/test_inference.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,16 @@
99

1010

1111
class TestAddMatchingResults:
12-
def add_matching_results(self, samples, ts, date="2020-01-01"):
13-
ts2 = sc2ts.add_matching_results(samples, ts, date)
12+
def add_matching_results(
13+
self, samples, ts, date="2020-01-01", num_mismatches=None, max_hmm_cost=None
14+
):
15+
ts2 = sc2ts.add_matching_results(
16+
samples,
17+
ts,
18+
date,
19+
num_mismatches,
20+
max_hmm_cost,
21+
)
1422
assert ts2.num_samples == len(samples) + ts.num_samples
1523
for u, sample in zip(ts2.samples()[-len(samples) :], samples):
1624
node = ts2.node(u)
@@ -54,6 +62,46 @@ def test_one_sample_recombinant(self):
5462
assert ts2.node(6).flags == sc2ts.NODE_IS_RECOMBINANT
5563
assert ts2.node(6).metadata == {"date_added": "2021"}
5664

65+
def test_one_sample_recombinant_filtered(self):
66+
# 4.00┊ 0 ┊
67+
# ┊ ┃ ┊
68+
# 3.00┊ 1 ┊
69+
# ┊ ┃ ┊
70+
# 2.00┊ 4 ┊
71+
# ┊ ┏┻┓ ┊
72+
# 1.00┊ 2 3 ┊
73+
# 0 29904
74+
ts = util.example_binary(2)
75+
L = ts.sequence_length
76+
x = L / 2
77+
samples = util.get_samples(ts, [[(0, x, 2), (x, L, 3)]])
78+
# Note that it is calling the function in the main module.
79+
ts2 = sc2ts.add_matching_results(
80+
samples, ts, "2021", num_mismatches=1e3, max_hmm_cost=1e3 - 1
81+
)
82+
assert ts2.num_trees == 1
83+
assert ts2.num_nodes == ts.num_nodes
84+
assert ts2.num_samples == ts.num_samples
85+
86+
def test_two_samples_recombinant_not_filtered(self):
87+
"""
88+
Test case where two identical recombinant samples get added
89+
but not excluded despite HMM costs above the max threshold.
90+
"""
91+
ts = util.example_binary(2)
92+
L = ts.sequence_length
93+
x = L / 2
94+
new_sample_paths = [
95+
[(0, x, 2), (x, L, 3)],
96+
[(0, x, 2), (x, L, 3)],
97+
]
98+
samples = util.get_samples(ts, new_sample_paths)
99+
ts2 = sc2ts.add_matching_results(
100+
samples, ts, "2021", num_mismatches=1e3, max_hmm_cost=1e3 - 1
101+
)
102+
assert ts2.num_trees == 2
103+
assert ts2.num_samples == ts.num_samples + len(new_sample_paths)
104+
57105
def test_one_sample_one_mutation(self):
58106
ts = sc2ts.initial_ts()
59107
ts = sc2ts.increment_time("2020-01-01", ts)
@@ -69,6 +117,44 @@ def test_one_sample_one_mutation(self):
69117
var = next(ts2.variants())
70118
assert var.alleles[var.genotypes[0]] == "X"
71119

120+
def test_one_sample_one_mutation_filtered(self):
121+
ts = sc2ts.initial_ts()
122+
ts = sc2ts.increment_time("2020-01-01", ts)
123+
samples = util.get_samples(
124+
ts, [[(0, ts.sequence_length, 1)]], mutations=[[(0, "X")]]
125+
)
126+
ts2 = sc2ts.add_matching_results(
127+
samples, ts, "2021", num_mismatches=0.0, max_hmm_cost=0.0
128+
)
129+
assert ts2.num_trees == ts.num_trees
130+
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state
131+
assert ts2.num_mutations == 0
132+
133+
def test_two_samples_one_mutation_not_filtered(self):
134+
ts = sc2ts.initial_ts()
135+
ts = sc2ts.increment_time("2020-01-01", ts)
136+
new_sample_paths = [
137+
[(0, ts.sequence_length, 1)],
138+
[(0, ts.sequence_length, 1)],
139+
]
140+
new_sample_mutations = [
141+
[(0, "X")],
142+
[(0, "X")],
143+
]
144+
samples = util.get_samples(
145+
ts,
146+
paths=new_sample_paths,
147+
mutations=new_sample_mutations,
148+
)
149+
ts2 = sc2ts.add_matching_results(
150+
samples, ts, "2021", num_mismatches=0.0, max_hmm_cost=0.0
151+
)
152+
assert ts2.num_trees == ts.num_trees
153+
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state
154+
assert ts2.num_mutations == 1
155+
var = next(ts2.variants())
156+
assert var.alleles[var.genotypes[0]] == "X"
157+
72158

73159
class TestMatchTsinfer:
74160
def match_tsinfer(self, samples, ts, haplotypes, **kwargs):

tests/test_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,17 @@ def make_recombinant_tree(self, num_samples=1):
3434
L = ts.sequence_length
3535
x = L / 2
3636
samples = util.get_samples(ts, [[(0, x, 2), (x, L, 3)]] * num_samples)
37-
ts_rec = sc2ts.add_matching_results(samples, ts, "2021")
37+
ts_rec = sc2ts.add_matching_results(
38+
samples, ts, "2021", num_mismatches=None, num_samples=None
39+
)
3840
assert ts_rec.num_trees == 2
3941
return ts_rec
4042

4143
@pytest.mark.parametrize(
4244
"ts",
4345
# Should probably add sc2ts.initial_ts() here, but see
4446
# https://github.com/jeromekelleher/sc2ts/issues/152
45-
[util.example_binary(1), util.example_binary(2), util.example_binary(3)]
47+
[util.example_binary(1), util.example_binary(2), util.example_binary(3)],
4648
)
4749
def test_no_recombinants(self, ts):
4850
ts2 = utils.detach_singleton_recombinants(ts)
@@ -52,9 +54,7 @@ def test_one_sample_recombinant(self):
5254
ts = self.make_recombinant_tree()
5355
assert ts.num_samples == 3
5456
re_nodes = [
55-
node.id
56-
for node in ts.nodes()
57-
if node.flags & sc2ts.NODE_IS_RECOMBINANT
57+
node.id for node in ts.nodes() if node.flags & sc2ts.NODE_IS_RECOMBINANT
5858
]
5959
assert len(re_nodes) == 1
6060
re_node = re_nodes[0]

0 commit comments

Comments
 (0)