Skip to content

Commit d336065

Browse files
committed
Dump pickled objects of samples filtered by HMM cost
1 parent a67ba5b commit d336065

File tree

4 files changed

+31
-18
lines changed

4 files changed

+31
-18
lines changed

sc2ts/cli.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import contextlib
99
import dataclasses
1010
import datetime
11+
import pickle
1112

1213
import tqdm
1314
import tskit
@@ -147,6 +148,11 @@ def add_provenance(ts, output_file):
147148
tables.dump(output_file)
148149

149150

151+
def dump_samples(samples, output_file):
152+
with open(output_file, "wb") as f:
153+
pickle.dump(samples, file=f)
154+
155+
150156
@click.command()
151157
@click.argument("alignments", type=click.Path(exists=True, dir_okay=False))
152158
@click.argument("metadata", type=click.Path(exists=True, dir_okay=False))
@@ -229,9 +235,11 @@ def daily_extend(
229235
num_threads=num_threads,
230236
show_progress=not no_progress,
231237
)
232-
for ts, date in ts_iter:
233-
output = output_prefix + date + ".ts"
234-
add_provenance(ts, output)
238+
for ts, excluded_samples, date in ts_iter:
239+
output_ts = output_prefix + date + ".ts"
240+
add_provenance(ts, output_ts)
241+
output_excluded_samples = output_prefix + date + ".excluded_samples.pickle"
242+
dump_samples(excluded_samples, output_excluded_samples)
235243

236244

237245
@click.command()

sc2ts/inference.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def daily_extend(
263263
start_day = last_date(base_ts)
264264
last_ts = base_ts
265265
for date in metadata_db.get_days(start_day):
266-
ts = extend(
266+
ts, excluded_samples = extend(
267267
alignment_store=alignment_store,
268268
metadata_db=metadata_db,
269269
date=date,
@@ -277,7 +277,7 @@ def daily_extend(
277277
precision=precision,
278278
rng=rng,
279279
)
280-
yield ts, date
280+
yield ts, excluded_samples, date
281281
last_ts = ts
282282

283283

@@ -380,7 +380,7 @@ def extend(
380380
)
381381
ts = increment_time(date, base_ts)
382382

383-
return add_matching_results(
383+
ts, excluded_samples = add_matching_results(
384384
samples=samples,
385385
ts=ts,
386386
date=date,
@@ -389,6 +389,8 @@ def extend(
389389
show_progress=show_progress,
390390
)
391391

392+
return ts, excluded_samples
393+
392394

393395
def match_path_ts(samples, ts, path, reversions):
394396
"""
@@ -449,9 +451,11 @@ def add_matching_results(
449451

450452
# Group matches by path and set of immediate reversions.
451453
grouped_matches = collections.defaultdict(list)
454+
excluded_samples = []
452455
site_masked_samples = np.zeros(int(ts.sequence_length), dtype=int)
453456
for sample in samples:
454457
if sample.get_hmm_cost(num_mismatches) > max_hmm_cost:
458+
excluded_samples.append(sample)
455459
continue
456460
site_masked_samples[sample.masked_sites] += 1
457461
path = tuple(sample.path)
@@ -532,7 +536,8 @@ def add_matching_results(
532536
# print("AFTER")
533537
# print(ts.draw_text())
534538
ts = coalesce_mutations(ts, attach_nodes)
535-
return ts
539+
540+
return ts, excluded_samples
536541

537542

538543
def solve_num_mismatches(ts, k):

tests/test_inference.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ class TestAddMatchingResults:
1212
def add_matching_results(
1313
self, samples, ts, date="2020-01-01", num_mismatches=None, max_hmm_cost=None
1414
):
15-
ts2 = sc2ts.add_matching_results(
16-
samples,
17-
ts,
18-
date,
19-
num_mismatches,
20-
max_hmm_cost,
15+
ts2, _ = sc2ts.add_matching_results(
16+
samples=samples,
17+
ts=ts,
18+
date=date,
19+
num_mismatches=num_mismatches,
20+
max_hmm_cost=max_hmm_cost,
2121
)
2222
assert ts2.num_samples == len(samples) + ts.num_samples
2323
for u, sample in zip(ts2.samples()[-len(samples) :], samples):
@@ -76,7 +76,7 @@ def test_one_sample_recombinant_filtered(self):
7676
x = L / 2
7777
samples = util.get_samples(ts, [[(0, x, 2), (x, L, 3)]])
7878
# Note that it is calling the function in the main module.
79-
ts2 = sc2ts.add_matching_results(
79+
ts2, _ = sc2ts.add_matching_results(
8080
samples, ts, "2021", num_mismatches=1e3, max_hmm_cost=1e3 - 1
8181
)
8282
assert ts2.num_trees == 1
@@ -97,7 +97,7 @@ def test_two_samples_recombinant_one_filtered(self):
9797
], # Filtered
9898
]
9999
samples = util.get_samples(ts, new_paths)
100-
ts2 = sc2ts.add_matching_results(
100+
ts2, _ = sc2ts.add_matching_results(
101101
samples, ts, "2021", num_mismatches=3, max_hmm_cost=4
102102
)
103103
assert ts2.num_trees == 2
@@ -124,7 +124,7 @@ def test_one_sample_one_mutation_filtered(self):
124124
samples = util.get_samples(
125125
ts, [[(0, ts.sequence_length, 1)]], mutations=[[(0, "X")]]
126126
)
127-
ts2 = sc2ts.add_matching_results(
127+
ts2, _ = sc2ts.add_matching_results(
128128
samples, ts, "2021", num_mismatches=0.0, max_hmm_cost=0.0
129129
)
130130
assert ts2.num_trees == ts.num_trees
@@ -148,7 +148,7 @@ def test_two_samples_one_mutation_one_filtered(self):
148148
paths=new_paths,
149149
mutations=new_mutations,
150150
)
151-
ts2 = sc2ts.add_matching_results(
151+
ts2, _ = sc2ts.add_matching_results(
152152
samples, ts, "2021", num_mismatches=3, max_hmm_cost=1
153153
)
154154
assert ts2.num_trees == ts.num_trees

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def make_recombinant_tree(self, num_samples=1):
3434
L = ts.sequence_length
3535
x = L / 2
3636
samples = util.get_samples(ts, [[(0, x, 2), (x, L, 3)]] * num_samples)
37-
ts_rec = sc2ts.add_matching_results(
37+
ts_rec, _ = sc2ts.add_matching_results(
3838
samples, ts, "2021", num_mismatches=None, max_hmm_cost=None
3939
)
4040
assert ts_rec.num_trees == 2

0 commit comments

Comments
 (0)