Skip to content

Commit c7ff5ce

Browse files
Patch up tests for changed base_ts
1 parent 7bd4356 commit c7ff5ce

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

sc2ts/core.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,16 @@ def get_problematic_sites():
7979
__cached_reference = None
8080

8181

82-
def get_reference_sequence():
82+
def get_reference_sequence(as_array=False):
8383
global __cached_reference
8484
if __cached_reference is None:
8585
reader = pyfaidx.Fasta(str(data_path / "reference.fasta"))
86-
__cached_reference = "X" + str(reader[REFERENCE_GENBANK])
87-
return __cached_reference
86+
__cached_reference = reader[REFERENCE_GENBANK]
87+
if as_array:
88+
h = np.array(__cached_reference).astype(str)
89+
return np.append(["X"], h)
90+
else:
91+
return "X" + str(__cached_reference)
8892

8993

9094
__cached_genes = None

tests/test_inference.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def test_one_sample_one_mutation(self, tmp_path):
148148
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state
149149
assert ts2.num_mutations == 1
150150
var = next(ts2.variants())
151-
assert var.alleles[var.genotypes[0]] == "X"
151+
assert var.alleles[var.genotypes[1]] == "X"
152152

153153
def test_one_sample_one_mutation_filtered(self, tmp_path):
154154
ts = sc2ts.initial_ts()
@@ -187,7 +187,7 @@ def test_two_samples_one_mutation_one_filtered(self, tmp_path):
187187
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state
188188
assert ts2.num_mutations == 1
189189
var = next(ts2.variants())
190-
assert var.alleles[var.genotypes[0]] == "X"
190+
assert var.alleles[var.genotypes[1]] == "X"
191191

192192

193193
class TestMatchTsinfer:
@@ -203,7 +203,7 @@ def test_match_reference(self, mirror):
203203
tables.sites.truncate(20)
204204
ts = tables.tree_sequence()
205205
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
206-
alignment = sc2ts.core.get_reference_sequence()
206+
alignment = sc2ts.core.get_reference_sequence(as_array=True)
207207
ma = sc2ts.alignments.encode_and_mask(alignment)
208208
h = ma.alignment[ts.sites_position.astype(int)]
209209
samples[0].alignment = h
@@ -220,7 +220,7 @@ def test_match_reference_one_mutation(self, mirror, site_id):
220220
tables.sites.truncate(20)
221221
ts = tables.tree_sequence()
222222
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
223-
alignment = sc2ts.core.get_reference_sequence()
223+
alignment = sc2ts.core.get_reference_sequence(as_array=True)
224224
ma = sc2ts.alignments.encode_and_mask(alignment)
225225
h = ma.alignment[ts.sites_position.astype(int)]
226226
# Mutate to gap
@@ -246,7 +246,7 @@ def test_match_reference_all_same(self, mirror, allele):
246246
tables.sites.truncate(20)
247247
ts = tables.tree_sequence()
248248
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
249-
alignment = sc2ts.core.get_reference_sequence()
249+
alignment = sc2ts.core.get_reference_sequence(as_array=True)
250250
ma = sc2ts.alignments.encode_and_mask(alignment)
251251
ref = ma.alignment[ts.sites_position.astype(int)]
252252
h = np.zeros_like(ref) + allele
@@ -267,7 +267,7 @@ def match_path_ts(self, samples, ts):
267267
# FIXME this API is terrible
268268
ts2 = sc2ts.match_path_ts(samples, ts, samples[0].path, [])
269269
assert ts2.num_samples == len(samples)
270-
for u, sample in zip(ts.samples(), samples):
270+
for u, sample in zip(ts.samples()[1:], samples):
271271
node = ts.node(u)
272272
assert node.time == 0
273273
assert node.metadata == sample.metadata

tests/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class TestPadSites:
1010
def check_site_padding(self, ts):
1111
ts = sc2ts.utils.pad_sites(ts)
12-
ref = sc2ts.core.get_reference_sequence()
12+
ref = sc2ts.core.get_reference_sequence(as_array=True)
1313
assert ts.num_sites == len(ref) - 1
1414
ancestral_state = ts.tables.sites.ancestral_state.view("S1").astype(str)
1515
assert np.all(ancestral_state == ref[1:])
@@ -67,7 +67,7 @@ def test_no_recombinants(self, ts, tmp_path):
6767

6868
def test_one_sample_recombinant(self, tmp_path):
6969
ts = self.make_recombinant_tree(tmp_path / "match.db")
70-
assert ts.num_samples == 3
70+
assert ts.num_samples == 4
7171
re_nodes = [
7272
node.id for node in ts.nodes() if node.flags & sc2ts.NODE_IS_RECOMBINANT
7373
]
@@ -93,6 +93,6 @@ def test_one_sample_recombinant(self, tmp_path):
9393
def test_two_sample_recombinant(self, tmp_path):
9494
"""Test that we don't detach anything if the recombinant node is not a singleton"""
9595
ts = self.make_recombinant_tree(num_samples=2, db_path=tmp_path / "match.db")
96-
assert ts.num_samples == 4
96+
assert ts.num_samples == 5
9797
ts2 = utils.detach_singleton_recombinants(ts)
9898
ts.tables.assert_equals(ts2.tables, ignore_provenance=True)

0 commit comments

Comments
 (0)