
Commit 23ad2d3

Improve match testing and infrastructure

Add metadata to exact match samples. Closes #238
1 parent 7d5a5be commit 23ad2d3


4 files changed: +221 −102 lines


run.sh

Lines changed: 3 additions & 3 deletions
@@ -9,17 +9,17 @@ num_threads=8
 
 # Paths
 datadir=testrun
-run_id=tmp-dev
+run_id=tmp-dev-hp
 # run_id=upgma-mds-$max_daily_samples-md-$max_submission_delay-mm-$mismatches
 resultsdir=results/$run_id
 results_prefix=$resultsdir/$run_id-
 logfile=logs/$run_id.log
 
 alignments=$datadir/alignments.db
 metadata=$datadir/metadata.db
-matches=$resultsdir/matces.db
+matches=$resultsdir/matches.db
 
-dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31`
+dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31 | head -n 14`
 echo $dates
 
 options="--num-threads $num_threads -vv -l $logfile "

sc2ts/inference.py

Lines changed: 41 additions & 38 deletions
@@ -404,35 +404,36 @@ def match_samples(
     # Default to no recombination
     num_mismatches = 1000
 
-    remaining_samples = samples
     # FIXME Something wrong here, we don't seem to get precisely the same
     # ARG for some reason. Need to track it down
     # Also: should only run the things at low precision that have that HMM cost.
     # Start out by setting everything to have 0 mutations and work up from there.
 
-    for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]:
-        match_tsinfer(
-            samples=remaining_samples,
-            ts=base_ts,
-            num_mismatches=num_mismatches,
-            precision=precision,
-            num_threads=num_threads,
-            show_progress=show_progress,
-            mirror_coordinates=mirror_coordinates,
-        )
-        samples_to_rerun = []
-        for sample in remaining_samples:
-            hmm_cost = sample.get_hmm_cost(num_mismatches)
-            # print(f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}")
-            logger.debug(
-                f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
-            )
-            if hmm_cost > cost:
-                sample.path.clear()
-                sample.mutations.clear()
-                samples_to_rerun.append(sample)
-        remaining_samples = samples_to_rerun
-
+    # remaining_samples = samples
+    # for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]:
+    #     match_tsinfer(
+    #         samples=remaining_samples,
+    #         ts=base_ts,
+    #         num_mismatches=num_mismatches,
+    #         precision=precision,
+    #         num_threads=num_threads,
+    #         show_progress=show_progress,
+    #         mirror_coordinates=mirror_coordinates,
+    #     )
+    #     samples_to_rerun = []
+    #     for sample in remaining_samples:
+    #         hmm_cost = sample.get_hmm_cost(num_mismatches)
+    #         # print(f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}")
+    #         logger.debug(
+    #             f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
+    #         )
+    #         if hmm_cost > cost:
+    #             sample.path.clear()
+    #             sample.mutations.clear()
+    #             samples_to_rerun.append(sample)
+    #     remaining_samples = samples_to_rerun
+
+    samples_to_rerun = samples
     match_tsinfer(
         samples=samples_to_rerun,
         ts=base_ts,
@@ -605,6 +606,18 @@ def update_top_level_metadata(ts, date):
     return tables.tree_sequence()
 
 
+def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, time=0):
+    metadata = {
+        **sample.metadata,
+        "sc2ts": {
+            "qc": sample.alignment_qc,
+            "path": [x.asdict() for x in sample.path],
+            "mutations": [x.asdict() for x in sample.mutations],
+        },
+    }
+    return tables.nodes.add_row(flags=flags, time=time, metadata=metadata)
+
+
 def match_path_ts(samples, ts, path, reversions):
     """
     Given the specified list of samples with equal copying paths,
@@ -623,17 +636,7 @@ def match_path_ts(samples, ts, path, reversions):
     )
     for sample in samples:
         assert sample.path == path
-        metadata = {
-            **sample.metadata,
-            "sc2ts": {
-                "qc": sample.alignment_qc,
-                "path": [x.asdict() for x in sample.path],
-                "mutations": [x.asdict() for x in sample.mutations],
-            },
-        }
-        node_id = tables.nodes.add_row(
-            flags=tskit.NODE_IS_SAMPLE, time=0, metadata=metadata
-        )
+        node_id = add_sample_to_tables(sample, tables)
         tables.edges.add_row(0, ts.sequence_length, parent=0, child=node_id)
         for mut in sample.mutations:
             if mut.site_id not in site_id_map:
@@ -671,10 +674,10 @@ def add_exact_matches(match_db, ts, date):
     for sample in samples:
        assert len(sample.path) == 1
        assert len(sample.mutations) == 0
-        node_id = tables.nodes.add_row(
+        node_id = add_sample_to_tables(
+            sample,
+            tables,
             flags=tskit.NODE_IS_SAMPLE | core.NODE_IS_EXACT_MATCH,
-            time=0,
-            metadata=sample.metadata,
         )
         parent = sample.path[0].parent
         logger.debug(f"ARG add exact match {sample.strain}:{node_id}->{parent}")
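Editor's note on the refactor above: the two inline tables.nodes.add_row calls are replaced by the shared add_sample_to_tables helper, so exact-match samples now record the same structured "sc2ts" metadata (QC, copying path, mutations) as other samples instead of bare sample.metadata. The sketch below is illustrative only; the Sample/PathSegment stand-ins and the NODE_IS_EXACT_MATCH bit value are assumptions for the example, not sc2ts internals.

import dataclasses
import tskit

NODE_IS_EXACT_MATCH = 1 << 21  # stand-in value; the real flag lives in sc2ts.core


@dataclasses.dataclass
class PathSegment:
    # Minimal stand-in for one segment of a sample's copying path.
    left: int
    right: int
    parent: int

    def asdict(self):
        return dataclasses.asdict(self)


@dataclasses.dataclass
class Sample:
    # Minimal stand-in for sc2ts's internal sample object.
    strain: str
    metadata: dict
    alignment_qc: dict
    path: list
    mutations: list


def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, time=0):
    # Same shape as the helper added in this commit: the original sample
    # metadata plus an "sc2ts" section with QC, copying path and mutations.
    metadata = {
        **sample.metadata,
        "sc2ts": {
            "qc": sample.alignment_qc,
            "path": [x.asdict() for x in sample.path],
            "mutations": [x.asdict() for x in sample.mutations],
        },
    }
    return tables.nodes.add_row(flags=flags, time=time, metadata=metadata)


# Usage sketch: an exact match (single path segment, no mutations) now keeps
# its full metadata block on the node.
tables = tskit.TableCollection(sequence_length=29904)
tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
sample = Sample(
    strain="SAMPLE_1",  # hypothetical strain name
    metadata={"strain": "SAMPLE_1", "date": "2020-02-10"},
    alignment_qc={"num_masked_sites": 133},
    path=[PathSegment(0, 29904, parent=1)],
    mutations=[],
)
node_id = add_sample_to_tables(
    sample, tables, flags=tskit.NODE_IS_SAMPLE | NODE_IS_EXACT_MATCH
)
print(tables.nodes[node_id].metadata["sc2ts"])  # {'qc': ..., 'path': [...], 'mutations': []}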

tests/conftest.py

Lines changed: 33 additions & 12 deletions
@@ -35,6 +35,7 @@ def fx_alignment_store(fx_data_cache, fx_alignments_fasta):
         a.append(fasta, show_progress=False)
     return sc2ts.AlignmentStore(cache_path)
 
+
 @pytest.fixture
 def fx_metadata_db(fx_data_cache):
     cache_path = fx_data_cache / "metadata.db"
@@ -44,26 +45,46 @@ def fx_metadata_db(fx_data_cache):
     return sc2ts.MetadataDb(cache_path)
 
 
+# TODO make this a session fixture cacheing the tree sequences.
 @pytest.fixture
-def fx_ts_2020_02_10(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store):
-    target_date = "2020-02-10"
-    cache_path = fx_data_cache / f"{target_date}.ts"
+def fx_ts_map(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store):
+    dates = [
+        "2020-01-01",
+        "2020-01-19",
+        "2020-01-24",
+        "2020-01-25",
+        "2020-01-28",
+        "2020-01-29",
+        "2020-01-30",
+        "2020-01-31",
+        "2020-02-01",
+        "2020-02-02",
+        "2020-02-03",
+        "2020-02-04",
+        "2020-02-05",
+        "2020-02-06",
+        "2020-02-07",
+        "2020-02-08",
+        "2020-02-09",
+        "2020-02-10",
+        "2020-02-11",
+        "2020-02-13",
+    ]
+    cache_path = fx_data_cache / f"{dates[-1]}.ts"
     if not cache_path.exists():
         last_ts = sc2ts.initial_ts()
         match_db = sc2ts.MatchDb.initialise(tmp_path / "match.db")
-        for date in fx_metadata_db.date_sample_counts():
-            print("INFERRING", date)
+        for date in dates:
             last_ts = sc2ts.extend(
                 alignment_store=fx_alignment_store,
                 metadata_db=fx_metadata_db,
                 base_ts=last_ts,
                 date=date,
                 match_db=match_db,
-                min_group_size=2,
             )
-            if date == target_date:
-                break
-        last_ts.dump(cache_path)
-    return tskit.load(cache_path)
-
-
+            print(
+                f"INFERRED {date} nodes={last_ts.num_nodes} mutations={last_ts.num_mutations}"
+            )
+            cache_path = fx_data_cache / f"{date}.ts"
+            last_ts.dump(cache_path)
+    return {date: tskit.load(fx_data_cache / f"{date}.ts") for date in dates}
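A hypothetical usage sketch (not part of this commit) of the new fx_ts_map fixture, which replaces the single-date fx_ts_2020_02_10 fixture with a dictionary mapping each inference date to the tree sequence cached for that day; the test names and assertions are illustrative assumptions about the test dataset.

# Hypothetical tests consuming the fx_ts_map fixture defined above.
def test_fx_ts_map_contains_all_dates(fx_ts_map):
    # The fixture is keyed by date string, one entry per inferred day.
    assert "2020-01-01" in fx_ts_map
    assert "2020-02-13" in fx_ts_map


def test_fx_ts_map_final_day(fx_ts_map):
    # Assumes the test dataset contains samples on these dates and that the
    # daily extend step only grows the ARG.
    ts = fx_ts_map["2020-02-13"]
    assert ts.num_samples > 0
    assert ts.num_nodes >= fx_ts_map["2020-01-01"].num_nodes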
