Commit 5876424

Batch of updates for running on full Viridian data
1 parent 81a8bc6 commit 5876424

File tree

1 file changed (+90 −53 lines)

sc2ts/inference.py

Lines changed: 90 additions & 53 deletions
@@ -64,6 +64,8 @@ def add(self, samples, date, num_mismatches):
         for j, sample in enumerate(samples):
             d = sample.asdict()
             assert sample.date == date
+            # FIXME we want to be more selective about what we're storing
+            # here, as we're including the alignment too.
             pkl = pickle.dumps(sample)
             # BZ2 compressing drops this by ~10X, so worth it.
             pkl_compressed = bz2.compress(pkl)
@@ -75,6 +77,10 @@ def add(self, samples, date, num_mismatches):
                 pkl_compressed,
             )
             data.append(args)
+            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
+            logger.debug(
+                f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}"
+            )
         # Batch insert, for efficiency.
         with self.conn:
             self.conn.executemany(sql, data)
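
Note: the storage format here is just standard-library pickling plus bz2 compression, so the "~10X" claim is easy to check independently. A minimal sketch with a hypothetical record, not part of the commit:

import bz2
import pickle

# Hypothetical stand-in for a pickled Sample, including its alignment.
record = {"strain": "SRR1234567", "alignment": "ACGT" * 7500}
pkl = pickle.dumps(record)
pkl_compressed = bz2.compress(pkl)
print(f"{len(pkl)} -> {len(pkl_compressed)} bytes")
# The same round trip that MatchDb.get() performs on retrieval:
assert pickle.loads(bz2.decompress(pkl_compressed)) == record
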
@@ -124,7 +130,11 @@ def get(self, where_clause):
         for row in self.conn.execute(sql):
             pkl = row.pop("pickle")
             sample = pickle.loads(bz2.decompress(pkl))
-            logger.debug(f"MatchDb got: {row}")
+            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
+            logger.debug(
+                f"MatchDb got: {sample.strain} {sample.date} {pango} "
+                f"hmm_cost={row['hmm_cost']}"
+            )
             # print(row)
             yield sample
 
@@ -149,19 +159,20 @@ def initialise(db_path):
         )
         return MatchDb(db_path)
 
-
     def print_all(self):
         """
        Debug method to print out full state of the DB.
        """
         import pandas as pd
+
         data = []
         with self.conn:
             for row in self.conn.execute("SELECT * from samples"):
                 data.append(row)
         df = pd.DataFrame(row, index=["strain"])
         print(df)
 
+
 def mirror(x, L):
     return L - x
 
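Note: print_all still builds the frame from row, the last row seen in the loop, rather than the accumulated data list. A sketch of building the frame from all rows, as an assumption about the intent rather than what the commit does:

import pandas as pd

# Hypothetical rows, as returned by a dict-producing row factory.
data = [
    {"strain": "s1", "hmm_cost": 2, "date": "2020-06-01"},
    {"strain": "s2", "hmm_cost": 0, "date": "2020-06-01"},
]
df = pd.DataFrame(data).set_index("strain")
print(df)
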
@@ -253,7 +264,7 @@ def last_date(ts):
         # reference but not as a sample
         u = ts.num_nodes - 1
         node = ts.node(u)
-        assert node.time == 0
+        # assert node.time == 0
         return parse_date(node.metadata["date"])
     else:
         samples = ts.samples()
@@ -336,6 +347,7 @@ class Sample:
     mutations: List = dataclasses.field(default_factory=list)
     alignment_qc: Dict = dataclasses.field(default_factory=dict)
     masked_sites: List = dataclasses.field(default_factory=list)
+    alignment: List = None
 
     # def __repr__(self):
     #     return self.strain
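
Note: the new alignment field means each Sample now carries its encoded haplotype through preprocessing instead of a shared G matrix. A trimmed stand-in for the dataclass, showing only the fields visible in this diff (the real class has more):

import dataclasses
from typing import Dict, List

@dataclasses.dataclass
class Sample:
    strain: str
    date: str
    metadata: Dict = dataclasses.field(default_factory=dict)
    mutations: List = dataclasses.field(default_factory=list)
    alignment_qc: Dict = dataclasses.field(default_factory=dict)
    masked_sites: List = dataclasses.field(default_factory=list)
    alignment: List = None  # filled in by preprocess() with the encoded alignment

s = Sample("SRR1234567", "2020-06-01", metadata={"Viridian_pangolin": "B.1"})
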
@@ -352,18 +364,6 @@ def breakpoints(self):
     def parents(self):
         return [seg.parent for seg in self.path]
 
-    # @property
-    # def date(self):
-    #     return parse_date(self.metadata["date"])
-
-    # @property
-    # def submission_date(self):
-    #     return parse_date(self.metadata["date_submitted"])
-
-    # @property
-    # def submission_delay(self):
-    #     return (self.submission_date - self.date).days
-
     def get_hmm_cost(self, num_mismatches):
         # Note that Recombinant objects have total_cost.
         # This bit of code is sort of repeated.
@@ -424,70 +424,84 @@ def daily_extend(
         last_ts = ts
 
 
-def preprocess_and_match_alignments(
+def preprocess(
     date,
     *,
+    base_ts,
     metadata_db,
     alignment_store,
-    match_db,
-    base_ts,
-    num_mismatches=None,
-    show_progress=False,
-    num_threads=None,
-    precision=None,
     max_daily_samples=None,
-    mirror_coordinates=False,
+    show_progress=False,
 ):
-    if num_mismatches is None:
-        # Default to no recombination
-        num_mismatches = 1000
-
     samples = []
-    for md in metadata_db.get(date):
-        samples.append(Sample(md["strain"], md["date"], md))
-    if len(samples) == 0:
-        logger.warn(f"Zero samples for {date}")
-        return
+    metadata_matches = list(metadata_db.get(date))
+
+    if len(metadata_matches) == 0:
+        logger.warn(f"Zero metadata matches for {date}")
+        return []
+
+    if date.endswith("01-01"):
+        logger.warning(f"Skipping {len(metadata_matches)} samples for {date}")
+        return []
+
     # TODO implement this.
     assert max_daily_samples is None
 
-    # Note: there's not a lot of point in making the G matrix here,
-    # we should just pass on the encoded alignments to the matching
-    # algorithm directly through the Sample class, and let it
-    # do the low-level haplotype storage.
-    G = np.zeros((base_ts.num_sites, len(samples)), dtype=np.int8)
     keep_sites = base_ts.sites_position.astype(int)
     problematic_sites = core.get_problematic_sites()
+    samples = []
 
-    samples_iter = enumerate(samples)
     with tqdm.tqdm(
-        samples_iter,
-        desc=f"Fetch:{date}",
-        total=len(samples),
+        metadata_matches,
+        desc=f"Preprocess:{date}",
         disable=not show_progress,
     ) as bar:
-        for j, sample in bar:
-            logger.debug(f"Getting alignment for {sample.strain}")
-            alignment = alignment_store[sample.strain]
-            sample.alignment = alignment
-            logger.debug("Encoding alignment")
+        for md in bar:
+            strain = md["strain"]
+            logger.debug(f"Getting alignment for {strain}")
+            try:
+                alignment = alignment_store[strain]
+            except KeyError:
+                logger.debug(f"No alignment stored for {strain}")
+                continue
+
+            sample = Sample(strain, date, metadata=md)
             ma = alignments.encode_and_mask(alignment)
             # Always mask the problematic_sites as well. We need to do this
             # for follow-up matching to inspect recombinants, as tsinfer
             # needs us to keep all sites in the table when doing mirrored
             # coordinates.
             ma.alignment[problematic_sites] = -1
-            G[:, j] = ma.alignment[keep_sites]
             sample.alignment_qc = ma.qc_summary()
             sample.masked_sites = ma.masked_sites
+            sample.alignment = ma.alignment[keep_sites]
+            samples.append(sample)
 
-    masked_per_sample = np.mean([len(sample.masked_sites)])
-    logger.info(f"Masked average of {masked_per_sample:.2f} nucleotides per sample")
+    logger.info(
+        f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
+    )
+    return samples
+
+
+def match_samples(
+    date,
+    samples,
+    *,
+    match_db,
+    base_ts,
+    num_mismatches=None,
+    show_progress=False,
+    num_threads=None,
+    precision=None,
+    mirror_coordinates=False,
+):
+    if num_mismatches is None:
+        # Default to no recombination
+        num_mismatches = 1000
 
     match_tsinfer(
         samples=samples,
         ts=base_ts,
-        genotypes=G,
         num_mismatches=num_mismatches,
         precision=precision,
         num_threads=num_threads,
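
Note: one detail of the loop rewrite above is that tqdm takes its total from len() of a sized iterable, so passing the metadata_matches list directly makes the old total=len(samples) argument unnecessary. A small illustration with toy data, not from the commit:

import tqdm

metadata_matches = [{"strain": "s1"}, {"strain": "s2"}]  # toy stand-ins
with tqdm.tqdm(metadata_matches, desc="Preprocess:2020-06-01") as bar:
    for md in bar:
        pass  # fetch and encode the alignment for md["strain"] here
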
@@ -515,21 +529,36 @@ def extend(
     precision=None,
     rng=None,
 ):
+    logger.info(
+        f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};"
+        f"mutations={base_ts.num_mutations}"
+    )
     # TODO not sure whether we'll keep these params. Making sure they're not
     # used for now
     assert max_submission_delay is None
 
-    preprocess_and_match_alignments(
+    samples = preprocess(
         date,
         metadata_db=metadata_db,
         alignment_store=alignment_store,
         base_ts=base_ts,
+        max_daily_samples=max_daily_samples,
+        show_progress=show_progress,
+    )
+
+    if len(samples) == 0:
+        logger.warning(f"Nothing to do for {date}")
+        return base_ts
+
+    match_samples(
+        date,
+        samples,
+        base_ts=base_ts,
         match_db=match_db,
         num_mismatches=num_mismatches,
         show_progress=show_progress,
         num_threads=num_threads,
         precision=precision,
-        max_daily_samples=max_daily_samples,
     )
 
     match_db.create_mask_table(base_ts)
@@ -574,6 +603,10 @@ def match_path_ts(samples, ts, path, reversions):
     path = samples[0].path
     site_id_map = {}
     first_sample = len(tables.nodes)
+    logger.debug(
+        f"Adding group of {len(samples)} with path={path} and "
+        f"reversions={reversions}"
+    )
     for sample in samples:
         assert sample.path == path
         metadata = {
@@ -596,6 +629,10 @@ def match_path_ts(samples, ts, path, reversions):
     # Now add the mutations
     for node_id, sample in enumerate(samples, first_sample):
         # metadata = {**sample.metadata, "sc2ts_qc": sample.alignment_qc}
+        logger.debug(
+            f"Adding {sample.strain}:{sample.date} with "
+            f"{len(sample.mutations)} mutations"
+        )
         for mut in sample.mutations:
             tables.mutations.add_row(
                 site=site_id_map[mut.site_id],
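
Note: the changes in these two hunks are pure observability, adding per-group and per-sample debug messages. Assuming the module logger is a standard logging.getLogger instance, the new messages can be surfaced with the usual configuration:

import logging

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
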
@@ -1210,14 +1247,14 @@ def resize_copy(array, new_size):
 def match_tsinfer(
     samples,
     ts,
-    genotypes,
     *,
     num_mismatches,
     precision=None,
     num_threads=0,
     show_progress=False,
     mirror_coordinates=False,
 ):
+    genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T
     input_ts = ts
     if mirror_coordinates:
         ts = mirror_ts_coordinates(ts)
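
Note: moving the genotype-matrix construction inside match_tsinfer keeps the call sites simpler. Each Sample.alignment is a row of length num_sites, and the transpose restores the (num_sites, num_samples) layout the old G matrix had. A toy check of the shape logic, not from the commit:

import numpy as np

num_sites = 4
alignments = [[0, 1, 2, -1], [1, 1, 0, 3]]  # two toy encoded alignments
genotypes = np.array(alignments, dtype=np.int8).T
assert genotypes.shape == (num_sites, len(alignments))
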
