Skip to content

Commit 8b59740

Browse files
Fixup tests
More fixups. Patch up remaining tests.
1 parent f84f74f commit 8b59740

File tree

4 files changed

+105
-45
lines changed

4 files changed

+105
-45
lines changed

sc2ts/inference.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
class MatchDb:
2929
def __init__(self, path):
3030
uri = f"file:{path}"
31-
# uri += "?mode=rw"
3231
self.uri = uri
3332
self.conn = sqlite3.connect(uri, uri=True)
3433
self.conn.row_factory = metadata.dict_factory
@@ -151,6 +150,18 @@ def initialise(db_path):
151150
return MatchDb(db_path)
152151

153152

153+
def print_all(self):
    """
    Debug method to print out the full state of the DB.

    Reads every row of the ``samples`` table through ``self.conn`` and
    prints them together as one pandas DataFrame. An empty table prints
    an empty DataFrame rather than raising.
    """
    # Local import, as in the original, so pandas is only needed when
    # this debug helper is actually called.
    import pandas as pd

    with self.conn:
        data = list(self.conn.execute("SELECT * from samples"))
    # Build the frame from *all* collected rows. The previous code passed
    # the loop variable ``row``, which showed only the last row and raised
    # NameError when the table was empty.
    df = pd.DataFrame(data)
    # NOTE(review): the original passed index=["strain"], presumably
    # meaning "index the output by strain" -- confirm. Only applied when
    # the rows actually expose a "strain" column.
    if "strain" in df.columns:
        df = df.set_index("strain")
    print(df)
164+
154165
def mirror(x, L):
    """
    Return ``L - x``, i.e. ``x`` mirrored with respect to ``L``.
    """
    reflected = L - x
    return reflected
156167

@@ -363,8 +374,8 @@ def asdict(self):
363374
"strain": self.strain,
364375
"path": self.path,
365376
"mutations": self.mutations,
366-
"masked_sites": self.masked_sites.tolist(),
367-
"alignment_qc": self.alignment_qc,
377+
# "masked_sites": self.masked_sites.tolist(),
378+
# "alignment_qc": self.alignment_qc,
368379
}
369380

370381

tests/test_inference.py

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,38 @@
1010

1111
class TestAddMatchingResults:
1212
def add_matching_results(
13-
self, samples, ts, date="2020-01-01", num_mismatches=None, max_hmm_cost=None
13+
self,
14+
samples,
15+
ts,
16+
db_path,
17+
date="2020-01-01",
18+
num_mismatches=1000,
19+
max_hmm_cost=1e7,
1420
):
21+
# This is pretty ugly, need to figure out how to neatly factor this
22+
# model of Sample object vs metadata vs alignment QC
23+
for sample in samples:
24+
sample.date = date
25+
sample.metadata["date"] = date
26+
sample.metadata["strain"] = sample.strain
27+
28+
match_db = util.get_match_db(ts, db_path, samples, date, num_mismatches)
29+
# print("Match DB", len(match_db))
30+
# match_db.print_all()
1531
ts2 = sc2ts.add_matching_results(
16-
samples=samples,
32+
f"hmm_cost <= {max_hmm_cost}",
33+
match_db=match_db,
1734
ts=ts,
1835
date=date,
19-
num_mismatches=num_mismatches,
20-
max_hmm_cost=max_hmm_cost,
2136
)
22-
assert ts2.num_samples == len(samples) + ts.num_samples
23-
for u, sample in zip(ts2.samples()[-len(samples) :], samples):
24-
node = ts2.node(u)
25-
assert node.time == 0
37+
# assert ts2.num_samples == len(samples) + ts.num_samples
38+
# for u, sample in zip(ts2.samples()[-len(samples) :], samples):
39+
# node = ts2.node(u)
40+
# assert node.time == 0
2641
assert ts2.num_sites == ts.num_sites
2742
return ts2
2843

29-
def test_one_sample(self):
44+
def test_one_sample(self, tmp_path):
3045
# 4.00┊ 0 ┊
3146
# ┊ ┃ ┊
3247
# 3.00┊ 1 ┊
@@ -37,12 +52,12 @@ def test_one_sample(self):
3752
# 0 29904
3853
ts = util.example_binary(2)
3954
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
40-
ts2 = self.add_matching_results(samples, ts)
55+
ts2 = self.add_matching_results(samples, ts, tmp_path / "match.db")
4156
assert ts2.num_trees == 1
4257
tree = ts2.first()
4358
assert tree.parent_dict == {1: 0, 4: 1, 2: 4, 3: 4, 5: 1}
4459

45-
def test_one_sample_recombinant(self):
60+
def test_one_sample_recombinant(self, tmp_path):
4661
# 4.00┊ 0 ┊
4762
# ┊ ┃ ┊
4863
# 3.00┊ 1 ┊
@@ -55,14 +70,16 @@ def test_one_sample_recombinant(self):
5570
L = ts.sequence_length
5671
x = L / 2
5772
samples = util.get_samples(ts, [[(0, x, 2), (x, L, 3)]])
58-
ts2 = self.add_matching_results(samples, ts, "2021")
73+
date = "2021-01-05"
74+
ts2 = self.add_matching_results(samples, ts, tmp_path / "match.db", date=date)
75+
5976
assert ts2.num_trees == 2
6077
assert ts2.first().parent_dict == {1: 0, 4: 1, 2: 4, 3: 4, 6: 2, 5: 6}
6178
assert ts2.last().parent_dict == {1: 0, 4: 1, 2: 4, 3: 4, 6: 3, 5: 6}
6279
assert ts2.node(6).flags == sc2ts.NODE_IS_RECOMBINANT
63-
assert ts2.node(6).metadata == {"date_added": "2021"}
80+
assert ts2.node(6).metadata == {"date_added": date}
6481

65-
def test_one_sample_recombinant_filtered(self):
82+
def test_one_sample_recombinant_filtered(self, tmp_path):
6683
# 4.00┊ 0 ┊
6784
# ┊ ┃ ┊
6885
# 3.00┊ 1 ┊
@@ -75,15 +92,14 @@ def test_one_sample_recombinant_filtered(self):
7592
L = ts.sequence_length
7693
x = L / 2
7794
samples = util.get_samples(ts, [[(0, x, 2), (x, L, 3)]])
78-
# Note that it is calling the function in the main module.
79-
ts2 = sc2ts.add_matching_results(
80-
samples, ts, "2021", num_mismatches=1e3, max_hmm_cost=1e3 - 1
95+
ts2 = self.add_matching_results(
96+
samples, ts, tmp_path / "match.db", num_mismatches=1e3, max_hmm_cost=1e3 - 1
8197
)
8298
assert ts2.num_trees == 1
8399
assert ts2.num_nodes == ts.num_nodes
84100
assert ts2.num_samples == ts.num_samples
85101

86-
def test_two_samples_recombinant_one_filtered(self):
102+
def test_two_samples_recombinant_one_filtered(self, tmp_path):
87103
ts = util.example_binary(2)
88104
L = ts.sequence_length
89105
x = L / 2
@@ -97,19 +113,19 @@ def test_two_samples_recombinant_one_filtered(self):
97113
], # Filtered
98114
]
99115
samples = util.get_samples(ts, new_paths)
100-
ts2 = sc2ts.add_matching_results(
101-
samples, ts, "2021", num_mismatches=3, max_hmm_cost=4
116+
ts2 = self.add_matching_results(
117+
samples, ts, tmp_path / "match.db", num_mismatches=3, max_hmm_cost=4
102118
)
103119
assert ts2.num_trees == 2
104120
assert ts2.num_samples == ts.num_samples + 1
105121

106-
def test_one_sample_one_mutation(self):
122+
def test_one_sample_one_mutation(self, tmp_path):
107123
ts = sc2ts.initial_ts()
108124
ts = sc2ts.increment_time("2020-01-01", ts)
109125
samples = util.get_samples(
110126
ts, [[(0, ts.sequence_length, 1)]], mutations=[[(0, "X")]]
111127
)
112-
ts2 = self.add_matching_results(samples, ts)
128+
ts2 = self.add_matching_results(samples, ts, tmp_path / "match.db")
113129
assert ts2.num_trees == 1
114130
tree = ts2.first()
115131
assert tree.parent_dict == {1: 0, 2: 1}
@@ -118,20 +134,20 @@ def test_one_sample_one_mutation(self):
118134
var = next(ts2.variants())
119135
assert var.alleles[var.genotypes[0]] == "X"
120136

121-
def test_one_sample_one_mutation_filtered(self):
137+
def test_one_sample_one_mutation_filtered(self, tmp_path):
122138
ts = sc2ts.initial_ts()
123139
ts = sc2ts.increment_time("2020-01-01", ts)
124140
samples = util.get_samples(
125141
ts, [[(0, ts.sequence_length, 1)]], mutations=[[(0, "X")]]
126142
)
127-
ts2 = sc2ts.add_matching_results(
128-
samples, ts, "2021", num_mismatches=0.0, max_hmm_cost=0.0
143+
ts2 = self.add_matching_results(
144+
samples, ts, tmp_path / "match.db", num_mismatches=0.0, max_hmm_cost=0.0
129145
)
130146
assert ts2.num_trees == ts.num_trees
131147
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state
132148
assert ts2.num_mutations == 0
133149

134-
def test_two_samples_one_mutation_one_filtered(self):
150+
def test_two_samples_one_mutation_one_filtered(self, tmp_path):
135151
ts = sc2ts.initial_ts()
136152
ts = sc2ts.increment_time("2020-01-01", ts)
137153
x = int(ts.sequence_length / 2)
@@ -148,8 +164,8 @@ def test_two_samples_one_mutation_one_filtered(self):
148164
paths=new_paths,
149165
mutations=new_mutations,
150166
)
151-
ts2= sc2ts.add_matching_results(
152-
samples, ts, "2021", num_mismatches=3, max_hmm_cost=1
167+
ts2 = self.add_matching_results(
168+
samples, ts, tmp_path / "match.db", num_mismatches=3, max_hmm_cost=1
153169
)
154170
assert ts2.num_trees == ts.num_trees
155171
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state
@@ -162,7 +178,9 @@ class TestMatchTsinfer:
162178
def match_tsinfer(self, samples, ts, haplotypes, **kwargs):
163179
assert len(samples) == len(haplotypes)
164180
G = np.array(haplotypes).T
165-
sc2ts.inference.match_tsinfer(samples=samples, ts=ts, genotypes=G, **kwargs)
181+
sc2ts.inference.match_tsinfer(
182+
samples=samples, ts=ts, genotypes=G, num_mismatches=1000, **kwargs
183+
)
166184

167185
@pytest.mark.parametrize("mirror", [False, True])
168186
def test_match_reference(self, mirror):
@@ -351,8 +369,12 @@ def test_n_samples_metadata(self):
351369
ts = sc2ts.initial_ts()
352370
samples = []
353371
for j in range(10):
372+
strain = f"x{j}"
373+
date = "2021-01-01"
354374
samples.append(
355375
sc2ts.Sample(
376+
strain=strain,
377+
date=date,
356378
metadata={f"x{j}": j, f"y{j}": list(range(j))},
357379
path=[(0, ts.sequence_length, 1)],
358380
mutations=[],

tests/test_utils.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_initial(self):
1919

2020

2121
class TestDetachSingletonRecombinants:
22-
def make_recombinant_tree(self, num_samples=1):
22+
def make_recombinant_tree(self, db_path, num_samples=1):
2323
# Make a tree sequence by adding num_samples samples under a
2424
# single recombination node. Start with the following tree:
2525
# 4.00┊ 0 ┊
@@ -34,8 +34,23 @@ def make_recombinant_tree(self, num_samples=1):
3434
L = ts.sequence_length
3535
x = L / 2
3636
samples = util.get_samples(ts, [[(0, x, 2), (x, L, 3)]] * num_samples)
37+
date = "2021-01-01"
38+
39+
# This is pretty ugly, need to figure out how to neatly factor this
40+
# model of Sample object vs metadata vs alignment QC
41+
# NOTE: code copied from test_inference.py
42+
for sample in samples:
43+
sample.date = date
44+
sample.metadata["date"] = date
45+
sample.metadata["strain"] = sample.strain
46+
match_db = util.get_match_db(ts, db_path, samples, date, num_mismatches=1000)
47+
# print("Match DB", len(match_db))
48+
# match_db.print_all()
3749
ts_rec = sc2ts.add_matching_results(
38-
samples, ts, "2021", num_mismatches=None, max_hmm_cost=None
50+
"True",
51+
match_db=match_db,
52+
ts=ts,
53+
date=date,
3954
)
4055
assert ts_rec.num_trees == 2
4156
return ts_rec
@@ -46,12 +61,12 @@ def make_recombinant_tree(self, num_samples=1):
4661
# https://github.com/jeromekelleher/sc2ts/issues/152
4762
[util.example_binary(1), util.example_binary(2), util.example_binary(3)],
4863
)
49-
def test_no_recombinants(self, ts):
64+
def test_no_recombinants(self, ts, tmp_path):
5065
ts2 = utils.detach_singleton_recombinants(ts)
5166
ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
5267

53-
def test_one_sample_recombinant(self):
54-
ts = self.make_recombinant_tree()
68+
def test_one_sample_recombinant(self, tmp_path):
69+
ts = self.make_recombinant_tree(tmp_path / "match.db")
5570
assert ts.num_samples == 3
5671
re_nodes = [
5772
node.id for node in ts.nodes() if node.flags & sc2ts.NODE_IS_RECOMBINANT
@@ -75,9 +90,9 @@ def test_one_sample_recombinant(self):
7590
assert ts3.num_samples == ts.num_samples - 1
7691
assert ts3.num_nodes == ts.num_nodes - 2 # both sample and re node gone
7792

78-
def test_two_sample_recombinant(self):
93+
def test_two_sample_recombinant(self, tmp_path):
7994
"""Test that we don't detach anything if the recombinant node is not a singleton"""
80-
ts = self.make_recombinant_tree(num_samples=2)
95+
ts = self.make_recombinant_tree(num_samples=2, db_path=tmp_path / "match.db")
8196
assert ts.num_samples == 4
8297
ts2 = utils.detach_singleton_recombinants(ts)
8398
ts.tables.assert_equals(ts2.tables, ignore_provenance=True)

tests/util.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# NOTE: the current API in which we update the Sample objects is
88
# really horrible and we need to refactor to make it more testable.
99
# This function is a symptom of that.
10-
def get_samples(ts, paths, mutations=None):
10+
def get_samples(ts, paths, mutations=None, date=None):
1111
if mutations is None:
1212
mutations = [[] for _ in paths]
1313

@@ -18,12 +18,21 @@ def get_samples(ts, paths, mutations=None):
1818
(ts.sites_position[site], state) for (site, state) in sample_mutations
1919
]
2020
updated_mutations.append(updated)
21-
samples = [sc2ts.Sample() for _ in paths]
21+
data = "2020-12-29" if date is None else date
22+
samples = [sc2ts.Sample(f"strain_{j}", date) for j, _ in enumerate(paths)]
2223
sc2ts.update_path_info(samples, ts, paths, updated_mutations)
2324
return samples
2425

2526

26-
def example_binary(n):
27+
def get_match_db(ts, db_path, samples, date, num_mismatches):
    """
    Create a match DB at ``db_path``, fill it from ``samples`` and return it.

    The DB is initialised via ``sc2ts.MatchDb.initialise``, the samples are
    added with the given ``date`` and ``num_mismatches``, and the mask table
    is then created from ``ts``.
    """
    # MatchDb.initialise already returns a MatchDb opened on db_path, so use
    # that object instead of discarding it and opening a second connection
    # to the same file.
    match_db = sc2ts.MatchDb.initialise(db_path)
    match_db.add(samples, date, num_mismatches)
    match_db.create_mask_table(ts)
    return match_db
33+
34+
35+
def example_binary(n, date="2020-01-01"):
2736
base = sc2ts.initial_ts()
2837
tables = base.dump_tables()
2938
tree = tskit.Tree.generate_balanced(n, span=base.sequence_length)
@@ -32,12 +41,15 @@ def example_binary(n):
3241
tables.nodes.time += np.max(binary_tables.nodes.time) + 1
3342
binary_tables.edges.child += len(tables.nodes)
3443
binary_tables.edges.parent += len(tables.nodes)
35-
for node in binary_tables.nodes:
36-
tables.nodes.append(node.replace(metadata={}))
44+
for j, node in enumerate(binary_tables.nodes):
45+
md = {}
46+
if node.flags == tskit.NODE_IS_SAMPLE:
47+
md["strain"] = f"x{j}"
48+
md["date"] = date
49+
tables.nodes.append(node.replace(metadata=md))
3750
for edge in binary_tables.edges:
3851
tables.edges.append(edge)
3952
# FIXME brittle
4053
tables.edges.add_row(0, base.sequence_length, parent=1, child=tree.root + 2)
4154
tables.sort()
4255
return tables.tree_sequence()
43-

0 commit comments

Comments
 (0)