Skip to content

Commit 5c80d4f

Browse files
Merge pull request #244 from jeromekelleher/add_ref_as_sample
Add reference as sample, and include reference sequence in ts
2 parents: 8739cad + c7ff5ce; commit: 5c80d4f

File tree

4 files changed

+41
-14
lines changed

4 files changed

+41
-14
lines changed

sc2ts/core.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -79,12 +79,16 @@ def get_problematic_sites():
7979
__cached_reference = None
8080

8181

82-
def get_reference_sequence():
82+
def get_reference_sequence(as_array=False):
8383
global __cached_reference
8484
if __cached_reference is None:
85-
reader = FastaReader(data_path / "reference.fasta")
85+
reader = pyfaidx.Fasta(str(data_path / "reference.fasta"))
8686
__cached_reference = reader[REFERENCE_GENBANK]
87-
return __cached_reference
87+
if as_array:
88+
h = np.array(__cached_reference).astype(str)
89+
return np.append(["X"], h)
90+
else:
91+
return "X" + str(__cached_reference)
8892

8993

9094
__cached_genes = None

sc2ts/inference.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import datetime
55
import dataclasses
66
import collections
7+
import json
78
import pickle
89
import os
910
import sqlite3
@@ -220,6 +221,12 @@ def initial_ts():
220221
tables = tskit.TableCollection(L)
221222
tables.time_units = core.TIME_UNITS
222223
base_schema = tskit.MetadataSchema.permissive_json().schema
224+
tables.reference_sequence.metadata_schema = tskit.MetadataSchema(base_schema)
225+
tables.reference_sequence.metadata = {
226+
"genbank_id": core.REFERENCE_GENBANK,
227+
"notes": "X prepended to alignment to map from 1-based to 0-based coordinates"
228+
}
229+
tables.reference_sequence.data = reference
223230

224231
tables.metadata_schema = tskit.MetadataSchema(base_schema)
225232

@@ -235,10 +242,10 @@ def initial_ts():
235242
tables.sites.add_row(pos, reference[pos], metadata={"masked_samples": 0})
236243
# TODO should probably make the ultimate ancestor time something less
237244
# plausible or at least configurable. However, this will be removed
238-
# in later versions when we remove the dependence on tskit.
245+
# in later versions when we remove the dependence on tsinfer.
239246
tables.nodes.add_row(time=1, metadata={"strain": "Vestigial_ignore"})
240247
tables.nodes.add_row(
241-
time=0, metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE}
248+
flags=tskit.NODE_IS_SAMPLE, time=0, metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE}
242249
)
243250
tables.edges.add_row(0, L, 0, 1)
244251
return tables.tree_sequence()

tests/test_inference.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,22 @@
88
import util
99

1010

11+
class TestInitialTs:
12+
def test_reference_sequence(self):
13+
ts = sc2ts.initial_ts()
14+
assert ts.reference_sequence.metadata["genbank_id"] == "MN908947"
15+
assert ts.reference_sequence.data == sc2ts.core.get_reference_sequence()
16+
17+
def test_reference_sample(self):
18+
ts = sc2ts.initial_ts()
19+
assert ts.num_samples == 1
20+
node = ts.node(ts.samples()[0])
21+
assert node.time == 0
22+
assert node.metadata == {"date": "2019-12-26", "strain": "Wuhan/Hu-1/2019"}
23+
alignment = next(ts.alignments())
24+
assert alignment == sc2ts.core.get_reference_sequence()
25+
26+
1127
class TestAddMatchingResults:
1228
def add_matching_results(
1329
self,
@@ -132,7 +148,7 @@ def test_one_sample_one_mutation(self, tmp_path):
132148
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state
133149
assert ts2.num_mutations == 1
134150
var = next(ts2.variants())
135-
assert var.alleles[var.genotypes[0]] == "X"
151+
assert var.alleles[var.genotypes[1]] == "X"
136152

137153
def test_one_sample_one_mutation_filtered(self, tmp_path):
138154
ts = sc2ts.initial_ts()
@@ -171,7 +187,7 @@ def test_two_samples_one_mutation_one_filtered(self, tmp_path):
171187
assert ts2.site(0).ancestral_state == ts.site(0).ancestral_state
172188
assert ts2.num_mutations == 1
173189
var = next(ts2.variants())
174-
assert var.alleles[var.genotypes[0]] == "X"
190+
assert var.alleles[var.genotypes[1]] == "X"
175191

176192

177193
class TestMatchTsinfer:
@@ -187,7 +203,7 @@ def test_match_reference(self, mirror):
187203
tables.sites.truncate(20)
188204
ts = tables.tree_sequence()
189205
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
190-
alignment = sc2ts.core.get_reference_sequence()
206+
alignment = sc2ts.core.get_reference_sequence(as_array=True)
191207
ma = sc2ts.alignments.encode_and_mask(alignment)
192208
h = ma.alignment[ts.sites_position.astype(int)]
193209
samples[0].alignment = h
@@ -204,7 +220,7 @@ def test_match_reference_one_mutation(self, mirror, site_id):
204220
tables.sites.truncate(20)
205221
ts = tables.tree_sequence()
206222
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
207-
alignment = sc2ts.core.get_reference_sequence()
223+
alignment = sc2ts.core.get_reference_sequence(as_array=True)
208224
ma = sc2ts.alignments.encode_and_mask(alignment)
209225
h = ma.alignment[ts.sites_position.astype(int)]
210226
# Mutate to gap
@@ -230,7 +246,7 @@ def test_match_reference_all_same(self, mirror, allele):
230246
tables.sites.truncate(20)
231247
ts = tables.tree_sequence()
232248
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
233-
alignment = sc2ts.core.get_reference_sequence()
249+
alignment = sc2ts.core.get_reference_sequence(as_array=True)
234250
ma = sc2ts.alignments.encode_and_mask(alignment)
235251
ref = ma.alignment[ts.sites_position.astype(int)]
236252
h = np.zeros_like(ref) + allele
@@ -251,7 +267,7 @@ def match_path_ts(self, samples, ts):
251267
# FIXME this API is terrible
252268
ts2 = sc2ts.match_path_ts(samples, ts, samples[0].path, [])
253269
assert ts2.num_samples == len(samples)
254-
for u, sample in zip(ts.samples(), samples):
270+
for u, sample in zip(ts.samples()[1:], samples):
255271
node = ts.node(u)
256272
assert node.time == 0
257273
assert node.metadata == sample.metadata

tests/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
class TestPadSites:
1010
def check_site_padding(self, ts):
1111
ts = sc2ts.utils.pad_sites(ts)
12-
ref = sc2ts.core.get_reference_sequence()
12+
ref = sc2ts.core.get_reference_sequence(as_array=True)
1313
assert ts.num_sites == len(ref) - 1
1414
ancestral_state = ts.tables.sites.ancestral_state.view("S1").astype(str)
1515
assert np.all(ancestral_state == ref[1:])
@@ -67,7 +67,7 @@ def test_no_recombinants(self, ts, tmp_path):
6767

6868
def test_one_sample_recombinant(self, tmp_path):
6969
ts = self.make_recombinant_tree(tmp_path / "match.db")
70-
assert ts.num_samples == 3
70+
assert ts.num_samples == 4
7171
re_nodes = [
7272
node.id for node in ts.nodes() if node.flags & sc2ts.NODE_IS_RECOMBINANT
7373
]
@@ -93,6 +93,6 @@ def test_one_sample_recombinant(self, tmp_path):
9393
def test_two_sample_recombinant(self, tmp_path):
9494
"""Test that we don't detach anything if the recombinant node is not a singleton"""
9595
ts = self.make_recombinant_tree(num_samples=2, db_path=tmp_path / "match.db")
96-
assert ts.num_samples == 4
96+
assert ts.num_samples == 5
9797
ts2 = utils.detach_singleton_recombinants(ts)
9898
ts.tables.assert_equals(ts2.tables, ignore_provenance=True)

0 commit comments

Comments
 (0)