@@ -175,11 +175,9 @@ def test_two_samples_one_mutation_one_filtered(self, tmp_path):
175
175
176
176
177
177
class TestMatchTsinfer :
178
- def match_tsinfer (self , samples , ts , haplotypes , ** kwargs ):
179
- assert len (samples ) == len (haplotypes )
180
- G = np .array (haplotypes ).T
178
+ def match_tsinfer (self , samples , ts , ** kwargs ):
181
179
sc2ts .inference .match_tsinfer (
182
- samples = samples , ts = ts , genotypes = G , num_mismatches = 1000 , ** kwargs
180
+ samples = samples , ts = ts , num_mismatches = 1000 , ** kwargs
183
181
)
184
182
185
183
@pytest .mark .parametrize ("mirror" , [False , True ])
@@ -189,10 +187,11 @@ def test_match_reference(self, mirror):
189
187
tables .sites .truncate (20 )
190
188
ts = tables .tree_sequence ()
191
189
samples = util .get_samples (ts , [[(0 , ts .sequence_length , 1 )]])
192
- samples [ 0 ]. alignment = sc2ts .core .get_reference_sequence ()
193
- ma = sc2ts .alignments .encode_and_mask (samples [ 0 ]. alignment )
190
+ alignment = sc2ts .core .get_reference_sequence ()
191
+ ma = sc2ts .alignments .encode_and_mask (alignment )
194
192
h = ma .alignment [ts .sites_position .astype (int )]
195
- self .match_tsinfer (samples , ts , [h ], mirror_coordinates = mirror )
193
+ samples [0 ].alignment = h
194
+ self .match_tsinfer (samples , ts , mirror_coordinates = mirror )
196
195
assert samples [0 ].breakpoints == [0 , ts .sequence_length ]
197
196
assert samples [0 ].parents == [ts .num_nodes - 1 ]
198
197
assert len (samples [0 ].mutations ) == 0
@@ -205,12 +204,13 @@ def test_match_reference_one_mutation(self, mirror, site_id):
205
204
tables .sites .truncate (20 )
206
205
ts = tables .tree_sequence ()
207
206
samples = util .get_samples (ts , [[(0 , ts .sequence_length , 1 )]])
208
- samples [ 0 ]. alignment = sc2ts .core .get_reference_sequence ()
209
- ma = sc2ts .alignments .encode_and_mask (samples [ 0 ]. alignment )
207
+ alignment = sc2ts .core .get_reference_sequence ()
208
+ ma = sc2ts .alignments .encode_and_mask (alignment )
210
209
h = ma .alignment [ts .sites_position .astype (int )]
211
210
# Mutate to gap
212
211
h [site_id ] = sc2ts .core .ALLELES .index ("-" )
213
- self .match_tsinfer (samples , ts , [h ], mirror_coordinates = mirror )
212
+ samples [0 ].alignment = h
213
+ self .match_tsinfer (samples , ts , mirror_coordinates = mirror )
214
214
assert samples [0 ].breakpoints == [0 , ts .sequence_length ]
215
215
assert samples [0 ].parents == [ts .num_nodes - 1 ]
216
216
assert len (samples [0 ].mutations ) == 1
@@ -230,11 +230,12 @@ def test_match_reference_all_same(self, mirror, allele):
230
230
tables .sites .truncate (20 )
231
231
ts = tables .tree_sequence ()
232
232
samples = util .get_samples (ts , [[(0 , ts .sequence_length , 1 )]])
233
- samples [ 0 ]. alignment = sc2ts .core .get_reference_sequence ()
234
- ma = sc2ts .alignments .encode_and_mask (samples [ 0 ]. alignment )
233
+ alignment = sc2ts .core .get_reference_sequence ()
234
+ ma = sc2ts .alignments .encode_and_mask (alignment )
235
235
ref = ma .alignment [ts .sites_position .astype (int )]
236
236
h = np .zeros_like (ref ) + allele
237
- self .match_tsinfer (samples , ts , [h ], mirror_coordinates = mirror )
237
+ samples [0 ].alignment = h
238
+ self .match_tsinfer (samples , ts , mirror_coordinates = mirror )
238
239
assert samples [0 ].breakpoints == [0 , ts .sequence_length ]
239
240
assert samples [0 ].parents == [ts .num_nodes - 1 ]
240
241
muts = samples [0 ].mutations
0 commit comments