Skip to content

Commit 83fcb1a

Browse files
Merge pull request #246 from jeromekelleher/improvements
Improvements
2 parents 5c80d4f + 43017c1 commit 83fcb1a

File tree

3 files changed

+72
-78
lines changed

3 files changed

+72
-78
lines changed

sc2ts/cli.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import datetime
1212
import pickle
1313

14+
import numpy as np
1415
import tqdm
1516
import tskit
1617
import tszip
@@ -22,6 +23,8 @@
2223
from . import core
2324
from . import inference
2425

26+
logger = logging.getLogger(__name__)
27+
2528

2629
def get_environment():
2730
"""
@@ -230,6 +233,12 @@ def dump_samples(samples, output_file):
230233
@click.option("--num-threads", default=0, type=int, help="Number of match threads")
231234
@click.option("--random-seed", default=42, type=int, help="Random seed for subsampling")
232235
@click.option("--stop-date", default="2030-01-01", type=str, help="Stopping date")
236+
@click.option(
237+
"--additional-problematic-sites",
238+
default=None,
239+
type=str,
240+
help="File containing the list of additional problematic sites to exclude.",
241+
)
233242
@click.option("-p", "--precision", default=None, type=int, help="Match precision")
234243
@click.option("--no-progress", default=False, type=bool, help="Don't show progress")
235244
@click.option("-v", "--verbose", count=True)
@@ -248,6 +257,7 @@ def daily_extend(
248257
num_threads,
249258
random_seed,
250259
stop_date,
260+
additional_problematic_sites,
251261
precision,
252262
no_progress,
253263
verbose,
@@ -259,13 +269,27 @@ def daily_extend(
259269
setup_logging(verbose, log_file)
260270
rng = random.Random(random_seed)
261271

272+
additional_problematic = []
273+
if additional_problematic_sites is not None:
274+
additional_problematic = (
275+
np.loadtxt(additional_problematic_sites).astype(int).tolist()
276+
)
277+
logger.info(
278+
f"Excluding additional {len(additional_problematic)} problematic sites"
279+
)
280+
262281
match_db_path = f"{output_prefix}match.db"
263282
if base is None:
264-
base_ts = inference.initial_ts()
283+
base_ts = inference.initial_ts(additional_problematic)
265284
match_db = inference.MatchDb.initialise(match_db_path)
266285
else:
267286
base_ts = tskit.load(base)
268287

288+
assert (
289+
base_ts.metadata["sc2ts"]["additional_problematic_sites"]
290+
== additional_problematic
291+
)
292+
269293
with contextlib.ExitStack() as exit_stack:
270294
alignment_store = exit_stack.enter_context(sc2ts.AlignmentStore(alignments))
271295
metadata_db = exit_stack.enter_context(sc2ts.MetadataDb(metadata))

sc2ts/core.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,7 @@ def __len__(self):
5050

5151

5252
def get_problematic_sites():
53-
base = np.loadtxt(data_path / "problematic_sites.txt", dtype=np.int64)
54-
# Temporary to try out removing these outliers. See
55-
# https://github.com/jeromekelleher/sc2ts/issues/231#issuecomment-2306665447
56-
# In reality we'd probably want to provide an additional file of extra sites
57-
# to remove.
58-
additional = [
59-
7851,
60-
10323,
61-
11750,
62-
17040,
63-
21137,
64-
21846,
65-
22917,
66-
22995,
67-
26681,
68-
27384,
69-
27638,
70-
27752,
71-
28254,
72-
28271,
73-
29614,
74-
]
75-
full = np.append(base, additional)
76-
return np.sort(full)
53+
return np.loadtxt(data_path / "problematic_sites.txt", dtype=np.int64)
7754

7855

7956
__cached_reference = None

sc2ts/inference.py

Lines changed: 46 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import datetime
55
import dataclasses
66
import collections
7-
import json
87
import pickle
98
import os
109
import sqlite3
@@ -77,7 +76,6 @@ def add(self, samples, date, num_mismatches):
7776
data = []
7877
hmm_cost = np.zeros(len(samples))
7978
for j, sample in enumerate(samples):
80-
d = sample.asdict()
8179
assert sample.date == date
8280
# FIXME we want to be more selective about what we're storing
8381
# here, as we're including the alignment too.
@@ -110,12 +108,7 @@ def create_mask_table(self, ts):
110108
# the rows in the DB that *are* in the ts, as a separate
111109
# transaction once we know that the trees have been saved to disk.
112110
logger.info("Loading used samples into DB")
113-
# TODO this is inefficient - need some logging to see how much time
114-
# we're spending here.
115-
# One thing we can do is to store the list of strain IDs in the
116-
# tree sequence top-level metadata, which we could even store using
117-
# some numpy tricks to make it fast.
118-
samples = [(ts.node(u).metadata["strain"],) for u in ts.samples()]
111+
samples = [(strain,) for strain in ts.metadata["sc2ts"]["samples_strain"]]
119112
logger.debug(f"Got {len(samples)} from ts")
120113
with self.conn:
121114
self.conn.execute("DROP TABLE IF EXISTS used_samples")
@@ -212,26 +205,35 @@ def mirror_ts_coordinates(ts):
212205
return tables.tree_sequence()
213206

214207

215-
def initial_ts():
208+
def initial_ts(additional_problematic_sites=list()):
216209
reference = core.get_reference_sequence()
217210
L = core.REFERENCE_SEQUENCE_LENGTH
218211
assert L == len(reference)
219-
problematic_sites = set(core.get_problematic_sites())
212+
problematic_sites = set(core.get_problematic_sites()) | set(additional_problematic_sites)
220213

221214
tables = tskit.TableCollection(L)
222215
tables.time_units = core.TIME_UNITS
216+
217+
# TODO add known fields to the schemas and document them.
218+
223219
base_schema = tskit.MetadataSchema.permissive_json().schema
224220
tables.reference_sequence.metadata_schema = tskit.MetadataSchema(base_schema)
225221
tables.reference_sequence.metadata = {
226222
"genbank_id": core.REFERENCE_GENBANK,
227-
"notes": "X prepended to alignment to map from 1-based to 0-based coordinates"
223+
"notes": "X prepended to alignment to map from 1-based to 0-based coordinates",
228224
}
229225
tables.reference_sequence.data = reference
230226

231227
tables.metadata_schema = tskit.MetadataSchema(base_schema)
232-
233228
# TODO gene annotations to top level
234-
# TODO add known fields to the schemas and document them.
229+
tables.metadata = {
230+
"sc2ts": {
231+
"date": core.REFERENCE_DATE,
232+
"samples_strain": [core.REFERENCE_STRAIN],
233+
"additional_problematic_sites": additional_problematic_sites,
234+
}
235+
}
236+
235237
tables.nodes.metadata_schema = tskit.MetadataSchema(base_schema)
236238
tables.sites.metadata_schema = tskit.MetadataSchema(base_schema)
237239
tables.mutations.metadata_schema = tskit.MetadataSchema(base_schema)
@@ -245,7 +247,9 @@ def initial_ts():
245247
# in later versions when we remove the dependence on tsinfer.
246248
tables.nodes.add_row(time=1, metadata={"strain": "Vestigial_ignore"})
247249
tables.nodes.add_row(
248-
flags=tskit.NODE_IS_SAMPLE, time=0, metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE}
250+
flags=tskit.NODE_IS_SAMPLE,
251+
time=0,
252+
metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE},
249253
)
250254
tables.edges.add_row(0, L, 0, 1)
251255
return tables.tree_sequence()
@@ -255,42 +259,8 @@ def parse_date(date):
255259
return datetime.datetime.fromisoformat(date)
256260

257261

258-
def filter_samples(samples, alignment_store, max_submission_delay=None):
259-
if max_submission_delay is None:
260-
max_submission_delay = 10**8 # Arbitrary large number of days.
261-
not_in_store = 0
262-
num_filtered = 0
263-
ret = []
264-
for sample in samples:
265-
if sample.strain not in alignment_store:
266-
logger.warn(f"{sample.strain} not in alignment store")
267-
not_in_store += 1
268-
continue
269-
if sample.submission_delay < max_submission_delay:
270-
ret.append(sample)
271-
else:
272-
num_filtered += 1
273-
if not_in_store == len(samples):
274-
raise ValueError("All samples for day missing")
275-
logger.info(
276-
f"Filtered {num_filtered} samples with "
277-
f"max_submission_delay >= {max_submission_delay}"
278-
)
279-
return ret
280-
281-
282262
def last_date(ts):
283-
if ts.num_samples == 0:
284-
# Special case for the initial ts which contains the
285-
# reference but not as a sample
286-
u = ts.num_nodes - 1
287-
node = ts.node(u)
288-
# assert node.time == 0
289-
return parse_date(node.metadata["date"])
290-
else:
291-
samples = ts.samples()
292-
samples_t0 = samples[ts.nodes_time[samples] == 0]
293-
return max([parse_date(ts.node(u).metadata["date"]) for u in samples_t0])
263+
return parse_date(ts.metadata["sc2ts"]["date"])
294264

295265

296266
def increment_time(date, ts):
@@ -343,7 +313,7 @@ def validate(ts, alignment_store, show_progress=False):
343313
Check that all the samples in the specified tree sequence are correctly
344314
representing the original alignments.
345315
"""
346-
samples = ts.samples()
316+
samples = ts.samples()[1:]
347317
chunk_size = 10**3
348318
offset = 0
349319
num_chunks = ts.num_samples // chunk_size
@@ -562,6 +532,14 @@ def match_samples(
562532
match_db.add(samples, date, num_mismatches)
563533

564534

535+
def check_base_ts(ts):
536+
md = ts.metadata
537+
assert "sc2ts" in md
538+
sc2ts_md = md["sc2ts"]
539+
assert "date" in sc2ts_md
540+
assert len(sc2ts_md["samples_strain"]) == ts.num_samples
541+
542+
565543
def extend(
566544
*,
567545
alignment_store,
@@ -579,9 +557,10 @@ def extend(
579557
precision=None,
580558
rng=None,
581559
):
560+
check_base_ts(base_ts)
582561
logger.info(
583-
f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};"
584-
f"mutations={base_ts.num_mutations}"
562+
f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};"
563+
f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}"
585564
)
586565
# TODO not sure whether we'll keep these params. Making sure they're not
587566
# used for now
@@ -640,7 +619,21 @@ def extend(
640619
min_group_size=min_group_size,
641620
show_progress=show_progress,
642621
)
643-
return ts
622+
return update_top_level_metadata(ts, date)
623+
624+
625+
def update_top_level_metadata(ts, date):
626+
tables = ts.dump_tables()
627+
md = tables.metadata
628+
md["sc2ts"]["date"] = date
629+
samples_strain = md["sc2ts"]["samples_strain"]
630+
new_samples = ts.samples()[len(samples_strain) :]
631+
for u in new_samples:
632+
node = ts.node(u)
633+
samples_strain.append(node.metadata["strain"])
634+
md["sc2ts"]["samples_strain"] = samples_strain
635+
tables.metadata = md
636+
return tables.tree_sequence()
644637

645638

646639
def match_path_ts(samples, ts, path, reversions):

0 commit comments

Comments
 (0)