Skip to content

Commit cca80b9

Browse files
Various updates getting inference working again
Closes #218 Closes #223
1 parent 5c80d4f commit cca80b9

File tree

1 file changed

+38
-46
lines changed

1 file changed

+38
-46
lines changed

sc2ts/inference.py

Lines changed: 38 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -110,12 +110,7 @@ def create_mask_table(self, ts):
110110
# the rows in the DB that *are* in the ts, as a separate
111111
# transaction once we know that the trees have been saved to disk.
112112
logger.info("Loading used samples into DB")
113-
# TODO this is inefficient - need some logging to see how much time
114-
# we're spending here.
115-
# One thing we can do is to store the list of strain IDs in the
116-
# tree sequence top-level metadata, which we could even store using
117-
# some numpy tricks to make it fast.
118-
samples = [(ts.node(u).metadata["strain"],) for u in ts.samples()]
113+
samples = [(strain,) for strain in ts.metadata["sc2ts"]["samples_strain"]]
119114
logger.debug(f"Got {len(samples)} from ts")
120115
with self.conn:
121116
self.conn.execute("DROP TABLE IF EXISTS used_samples")
@@ -224,11 +219,17 @@ def initial_ts():
224219
tables.reference_sequence.metadata_schema = tskit.MetadataSchema(base_schema)
225220
tables.reference_sequence.metadata = {
226221
"genbank_id": core.REFERENCE_GENBANK,
227-
"notes": "X prepended to alignment to map from 1-based to 0-based coordinates"
222+
"notes": "X prepended to alignment to map from 1-based to 0-based coordinates",
228223
}
229224
tables.reference_sequence.data = reference
230225

231226
tables.metadata_schema = tskit.MetadataSchema(base_schema)
227+
tables.metadata = {
228+
"sc2ts": {
229+
"date": core.REFERENCE_DATE,
230+
"samples_strain": [core.REFERENCE_STRAIN],
231+
}
232+
}
232233

233234
# TODO gene annotations to top level
234235
# TODO add known fields to the schemas and document them.
@@ -245,7 +246,9 @@ def initial_ts():
245246
# in later versions when we remove the dependence on tsinfer.
246247
tables.nodes.add_row(time=1, metadata={"strain": "Vestigial_ignore"})
247248
tables.nodes.add_row(
248-
flags=tskit.NODE_IS_SAMPLE, time=0, metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE}
249+
flags=tskit.NODE_IS_SAMPLE,
250+
time=0,
251+
metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE},
249252
)
250253
tables.edges.add_row(0, L, 0, 1)
251254
return tables.tree_sequence()
@@ -255,42 +258,8 @@ def parse_date(date):
255258
return datetime.datetime.fromisoformat(date)
256259

257260

258-
def filter_samples(samples, alignment_store, max_submission_delay=None):
259-
if max_submission_delay is None:
260-
max_submission_delay = 10**8 # Arbitrary large number of days.
261-
not_in_store = 0
262-
num_filtered = 0
263-
ret = []
264-
for sample in samples:
265-
if sample.strain not in alignment_store:
266-
logger.warn(f"{sample.strain} not in alignment store")
267-
not_in_store += 1
268-
continue
269-
if sample.submission_delay < max_submission_delay:
270-
ret.append(sample)
271-
else:
272-
num_filtered += 1
273-
if not_in_store == len(samples):
274-
raise ValueError("All samples for day missing")
275-
logger.info(
276-
f"Filtered {num_filtered} samples with "
277-
f"max_submission_delay >= {max_submission_delay}"
278-
)
279-
return ret
280-
281-
282261
def last_date(ts):
283-
if ts.num_samples == 0:
284-
# Special case for the initial ts which contains the
285-
# reference but not as a sample
286-
u = ts.num_nodes - 1
287-
node = ts.node(u)
288-
# assert node.time == 0
289-
return parse_date(node.metadata["date"])
290-
else:
291-
samples = ts.samples()
292-
samples_t0 = samples[ts.nodes_time[samples] == 0]
293-
return max([parse_date(ts.node(u).metadata["date"]) for u in samples_t0])
262+
return parse_date(ts.metadata["sc2ts"]["date"])
294263

295264

296265
def increment_time(date, ts):
@@ -562,6 +531,14 @@ def match_samples(
562531
match_db.add(samples, date, num_mismatches)
563532

564533

534+
def check_base_ts(ts):
535+
md = ts.metadata
536+
assert "sc2ts" in md
537+
sc2ts_md = md["sc2ts"]
538+
assert "date" in sc2ts_md
539+
assert len(sc2ts_md["samples_strain"]) == ts.num_samples
540+
541+
565542
def extend(
566543
*,
567544
alignment_store,
@@ -579,9 +556,10 @@ def extend(
579556
precision=None,
580557
rng=None,
581558
):
559+
check_base_ts(base_ts)
582560
logger.info(
583-
f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};"
584-
f"mutations={base_ts.num_mutations}"
561+
f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};"
562+
f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}"
585563
)
586564
# TODO not sure whether we'll keep these params. Making sure they're not
587565
# used for now
@@ -640,7 +618,21 @@ def extend(
640618
min_group_size=min_group_size,
641619
show_progress=show_progress,
642620
)
643-
return ts
621+
return update_top_level_metadata(ts, date)
622+
623+
624+
def update_top_level_metadata(ts, date):
625+
tables = ts.dump_tables()
626+
md = tables.metadata
627+
md["sc2ts"]["date"] = date
628+
samples_strain = md["sc2ts"]["samples_strain"]
629+
new_samples = ts.samples()[len(samples_strain) :]
630+
for u in new_samples:
631+
node = ts.node(u)
632+
samples_strain.append(node.metadata["strain"])
633+
md["sc2ts"]["samples_strain"] = samples_strain
634+
tables.metadata = md
635+
return tables.tree_sequence()
644636

645637

646638
def match_path_ts(samples, ts, path, reversions):

0 commit comments

Comments
 (0)