10
10
11
11
class TestAddMatchingResults :
12
12
def add_matching_results (
13
- self , samples , ts , date = "2020-01-01" , num_mismatches = None , max_hmm_cost = None
13
+ self ,
14
+ samples ,
15
+ ts ,
16
+ db_path ,
17
+ date = "2020-01-01" ,
18
+ num_mismatches = 1000 ,
19
+ max_hmm_cost = 1e7 ,
14
20
):
21
+ # This is pretty ugly, need to figure out how to neatly factor this
22
+ # model of Sample object vs metadata vs alignment QC
23
+ for sample in samples :
24
+ sample .date = date
25
+ sample .metadata ["date" ] = date
26
+ sample .metadata ["strain" ] = sample .strain
27
+
28
+ match_db = util .get_match_db (ts , db_path , samples , date , num_mismatches )
29
+ # print("Match DB", len(match_db))
30
+ # match_db.print_all()
15
31
ts2 = sc2ts .add_matching_results (
16
- samples = samples ,
32
+ f"hmm_cost <= { max_hmm_cost } " ,
33
+ match_db = match_db ,
17
34
ts = ts ,
18
35
date = date ,
19
- num_mismatches = num_mismatches ,
20
- max_hmm_cost = max_hmm_cost ,
21
36
)
22
- assert ts2 .num_samples == len (samples ) + ts .num_samples
23
- for u , sample in zip (ts2 .samples ()[- len (samples ) :], samples ):
24
- node = ts2 .node (u )
25
- assert node .time == 0
37
+ # assert ts2.num_samples == len(samples) + ts.num_samples
38
+ # for u, sample in zip(ts2.samples()[-len(samples) :], samples):
39
+ # node = ts2.node(u)
40
+ # assert node.time == 0
26
41
assert ts2 .num_sites == ts .num_sites
27
42
return ts2
28
43
29
- def test_one_sample (self ):
44
+ def test_one_sample (self , tmp_path ):
30
45
# 4.00┊ 0 ┊
31
46
# ┊ ┃ ┊
32
47
# 3.00┊ 1 ┊
@@ -37,12 +52,12 @@ def test_one_sample(self):
37
52
# 0 29904
38
53
ts = util .example_binary (2 )
39
54
samples = util .get_samples (ts , [[(0 , ts .sequence_length , 1 )]])
40
- ts2 = self .add_matching_results (samples , ts )
55
+ ts2 = self .add_matching_results (samples , ts , tmp_path / "match.db" )
41
56
assert ts2 .num_trees == 1
42
57
tree = ts2 .first ()
43
58
assert tree .parent_dict == {1 : 0 , 4 : 1 , 2 : 4 , 3 : 4 , 5 : 1 }
44
59
45
- def test_one_sample_recombinant (self ):
60
+ def test_one_sample_recombinant (self , tmp_path ):
46
61
# 4.00┊ 0 ┊
47
62
# ┊ ┃ ┊
48
63
# 3.00┊ 1 ┊
@@ -55,14 +70,16 @@ def test_one_sample_recombinant(self):
55
70
L = ts .sequence_length
56
71
x = L / 2
57
72
samples = util .get_samples (ts , [[(0 , x , 2 ), (x , L , 3 )]])
58
- ts2 = self .add_matching_results (samples , ts , "2021" )
73
+ date = "2021-01-05"
74
+ ts2 = self .add_matching_results (samples , ts , tmp_path / "match.db" , date = date )
75
+
59
76
assert ts2 .num_trees == 2
60
77
assert ts2 .first ().parent_dict == {1 : 0 , 4 : 1 , 2 : 4 , 3 : 4 , 6 : 2 , 5 : 6 }
61
78
assert ts2 .last ().parent_dict == {1 : 0 , 4 : 1 , 2 : 4 , 3 : 4 , 6 : 3 , 5 : 6 }
62
79
assert ts2 .node (6 ).flags == sc2ts .NODE_IS_RECOMBINANT
63
- assert ts2 .node (6 ).metadata == {"date_added" : "2021" }
80
+ assert ts2 .node (6 ).metadata == {"date_added" : date }
64
81
65
- def test_one_sample_recombinant_filtered (self ):
82
+ def test_one_sample_recombinant_filtered (self , tmp_path ):
66
83
# 4.00┊ 0 ┊
67
84
# ┊ ┃ ┊
68
85
# 3.00┊ 1 ┊
@@ -75,15 +92,14 @@ def test_one_sample_recombinant_filtered(self):
75
92
L = ts .sequence_length
76
93
x = L / 2
77
94
samples = util .get_samples (ts , [[(0 , x , 2 ), (x , L , 3 )]])
78
- # Note that it is calling the function in the main module.
79
- ts2 = sc2ts .add_matching_results (
80
- samples , ts , "2021" , num_mismatches = 1e3 , max_hmm_cost = 1e3 - 1
95
+ ts2 = self .add_matching_results (
96
+ samples , ts , tmp_path / "match.db" , num_mismatches = 1e3 , max_hmm_cost = 1e3 - 1
81
97
)
82
98
assert ts2 .num_trees == 1
83
99
assert ts2 .num_nodes == ts .num_nodes
84
100
assert ts2 .num_samples == ts .num_samples
85
101
86
- def test_two_samples_recombinant_one_filtered (self ):
102
+ def test_two_samples_recombinant_one_filtered (self , tmp_path ):
87
103
ts = util .example_binary (2 )
88
104
L = ts .sequence_length
89
105
x = L / 2
@@ -97,19 +113,19 @@ def test_two_samples_recombinant_one_filtered(self):
97
113
], # Filtered
98
114
]
99
115
samples = util .get_samples (ts , new_paths )
100
- ts2 = sc2ts .add_matching_results (
101
- samples , ts , "2021 " , num_mismatches = 3 , max_hmm_cost = 4
116
+ ts2 = self .add_matching_results (
117
+ samples , ts , tmp_path / "match.db " , num_mismatches = 3 , max_hmm_cost = 4
102
118
)
103
119
assert ts2 .num_trees == 2
104
120
assert ts2 .num_samples == ts .num_samples + 1
105
121
106
- def test_one_sample_one_mutation (self ):
122
+ def test_one_sample_one_mutation (self , tmp_path ):
107
123
ts = sc2ts .initial_ts ()
108
124
ts = sc2ts .increment_time ("2020-01-01" , ts )
109
125
samples = util .get_samples (
110
126
ts , [[(0 , ts .sequence_length , 1 )]], mutations = [[(0 , "X" )]]
111
127
)
112
- ts2 = self .add_matching_results (samples , ts )
128
+ ts2 = self .add_matching_results (samples , ts , tmp_path / "match.db" )
113
129
assert ts2 .num_trees == 1
114
130
tree = ts2 .first ()
115
131
assert tree .parent_dict == {1 : 0 , 2 : 1 }
@@ -118,20 +134,20 @@ def test_one_sample_one_mutation(self):
118
134
var = next (ts2 .variants ())
119
135
assert var .alleles [var .genotypes [0 ]] == "X"
120
136
121
- def test_one_sample_one_mutation_filtered (self ):
137
+ def test_one_sample_one_mutation_filtered (self , tmp_path ):
122
138
ts = sc2ts .initial_ts ()
123
139
ts = sc2ts .increment_time ("2020-01-01" , ts )
124
140
samples = util .get_samples (
125
141
ts , [[(0 , ts .sequence_length , 1 )]], mutations = [[(0 , "X" )]]
126
142
)
127
- ts2 = sc2ts .add_matching_results (
128
- samples , ts , "2021 " , num_mismatches = 0.0 , max_hmm_cost = 0.0
143
+ ts2 = self .add_matching_results (
144
+ samples , ts , tmp_path / "match.db " , num_mismatches = 0.0 , max_hmm_cost = 0.0
129
145
)
130
146
assert ts2 .num_trees == ts .num_trees
131
147
assert ts2 .site (0 ).ancestral_state == ts .site (0 ).ancestral_state
132
148
assert ts2 .num_mutations == 0
133
149
134
- def test_two_samples_one_mutation_one_filtered (self ):
150
+ def test_two_samples_one_mutation_one_filtered (self , tmp_path ):
135
151
ts = sc2ts .initial_ts ()
136
152
ts = sc2ts .increment_time ("2020-01-01" , ts )
137
153
x = int (ts .sequence_length / 2 )
@@ -148,8 +164,8 @@ def test_two_samples_one_mutation_one_filtered(self):
148
164
paths = new_paths ,
149
165
mutations = new_mutations ,
150
166
)
151
- ts2 = sc2ts .add_matching_results (
152
- samples , ts , "2021 " , num_mismatches = 3 , max_hmm_cost = 1
167
+ ts2 = self .add_matching_results (
168
+ samples , ts , tmp_path / "match.db " , num_mismatches = 3 , max_hmm_cost = 1
153
169
)
154
170
assert ts2 .num_trees == ts .num_trees
155
171
assert ts2 .site (0 ).ancestral_state == ts .site (0 ).ancestral_state
@@ -162,7 +178,9 @@ class TestMatchTsinfer:
162
178
def match_tsinfer (self , samples , ts , haplotypes , ** kwargs ):
163
179
assert len (samples ) == len (haplotypes )
164
180
G = np .array (haplotypes ).T
165
- sc2ts .inference .match_tsinfer (samples = samples , ts = ts , genotypes = G , ** kwargs )
181
+ sc2ts .inference .match_tsinfer (
182
+ samples = samples , ts = ts , genotypes = G , num_mismatches = 1000 , ** kwargs
183
+ )
166
184
167
185
@pytest .mark .parametrize ("mirror" , [False , True ])
168
186
def test_match_reference (self , mirror ):
@@ -351,8 +369,12 @@ def test_n_samples_metadata(self):
351
369
ts = sc2ts .initial_ts ()
352
370
samples = []
353
371
for j in range (10 ):
372
+ strain = f"x{ j } "
373
+ date = "2021-01-01"
354
374
samples .append (
355
375
sc2ts .Sample (
376
+ strain = strain ,
377
+ date = date ,
356
378
metadata = {f"x{ j } " : j , f"y{ j } " : list (range (j ))},
357
379
path = [(0 , ts .sequence_length , 1 )],
358
380
mutations = [],
0 commit comments