
Commit c24cd33

Merge pull request #217 from jeromekelleher/run-full-viridian
Run full viridian
Parents: f4855f2 + 726f717 · Commit: c24cd33

File tree (3 files changed, +113 -67 lines): sc2ts/cli.py, sc2ts/inference.py, tests/test_inference.py


sc2ts/cli.py

Lines changed: 6 additions & 1 deletion
@@ -75,7 +75,12 @@ def setup_logging(verbosity, log_file=None):
     # at the console output. For development this is better than having
     # to go to the log to see the traceback, but for production it may
     # be better to let daiquiri record the errors as well.
-    daiquiri.setup(level=log_level, outputs=outputs, set_excepthook=False)
+    daiquiri.setup(outputs=outputs, set_excepthook=False)
+    # Only show stuff coming from sc2ts. Sometimes it's handy to look
+    # at the tsinfer logs too, so we could add an option to set its
+    # levels
+    logger = logging.getLogger("sc2ts")
+    logger.setLevel(log_level)


 # TODO add options to list keys, dump specific alignments etc
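
The net effect of this change: daiquiri no longer receives a global level, so third-party loggers stay at daiquiri's default, and only the sc2ts logger is raised to the requested verbosity. A minimal, self-contained sketch of that pattern (the Stderr output and the DEBUG level are illustrative assumptions, not the exact values used in cli.py):

    import logging

    import daiquiri

    # Configure output handlers only; the root logger keeps daiquiri's default
    # (WARNING) level because no level argument is passed.
    daiquiri.setup(outputs=[daiquiri.output.Stderr()], set_excepthook=False)

    # Raise verbosity for sc2ts alone; other libraries are unaffected.
    logging.getLogger("sc2ts").setLevel(logging.DEBUG)

    logging.getLogger("sc2ts").debug("visible")       # emitted
    logging.getLogger("tsinfer").debug("suppressed")  # filtered by the root level
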

sc2ts/inference.py

Lines changed: 93 additions & 53 deletions
@@ -64,6 +64,8 @@ def add(self, samples, date, num_mismatches):
         for j, sample in enumerate(samples):
             d = sample.asdict()
             assert sample.date == date
+            # FIXME we want to be more selective about what we're storing
+            # here, as we're including the alignment too.
             pkl = pickle.dumps(sample)
             # BZ2 compressing drops this by ~10X, so worth it.
             pkl_compressed = bz2.compress(pkl)
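
The FIXME above notes that the whole Sample, including the new per-site alignment array, ends up in the pickle. Purely as an illustration of that comment's intent (this helper is hypothetical, not part of the commit), one way to be more selective would be to drop the alignment before pickling:

    import bz2
    import copy
    import pickle

    def pickle_without_alignment(sample):
        # Hypothetical helper: store a shallow copy with the large per-site
        # alignment array blanked out, then compress as before.
        slim = copy.copy(sample)
        slim.alignment = None
        return bz2.compress(pickle.dumps(slim))
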
@@ -75,6 +77,10 @@ def add(self, samples, date, num_mismatches):
                 pkl_compressed,
             )
             data.append(args)
+            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
+            logger.debug(
+                f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}"
+            )
         # Batch insert, for efficiency.
         with self.conn:
             self.conn.executemany(sql, data)
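
For context, the storage pattern here (pickle, bz2-compress, batch INSERT via executemany) can be exercised with a small self-contained sketch; the table name and columns below are simplified stand-ins for the real MatchDb schema:

    import bz2
    import pickle
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE samples (strain TEXT, hmm_cost REAL, pickle BLOB)")

    records = [{"strain": "ERR001", "hmm_cost": 1.0}, {"strain": "ERR002", "hmm_cost": 0.0}]
    data = []
    for rec in records:
        pkl = pickle.dumps(rec)
        # Compress the pickled blob before storage, as MatchDb does.
        data.append((rec["strain"], rec["hmm_cost"], bz2.compress(pkl)))

    # Batch insert, for efficiency.
    with conn:
        conn.executemany("INSERT INTO samples VALUES (?, ?, ?)", data)
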
@@ -124,7 +130,11 @@ def get(self, where_clause):
         for row in self.conn.execute(sql):
             pkl = row.pop("pickle")
             sample = pickle.loads(bz2.decompress(pkl))
-            logger.debug(f"MatchDb got: {row}")
+            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
+            logger.debug(
+                f"MatchDb got: {sample.strain} {sample.date} {pango} "
+                f"hmm_cost={row['hmm_cost']}"
+            )
             # print(row)
             yield sample
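
Reading the rows back reverses that process. Continuing the simplified sqlite3 sketch above, the decompress-and-unpickle step plus the defensive metadata lookup used in the new debug message look roughly like this:

    conn.row_factory = sqlite3.Row  # rows can then be indexed by column name
    for row in conn.execute("SELECT * FROM samples"):
        sample = pickle.loads(bz2.decompress(row["pickle"]))
        # In sc2ts the Pangolin call is looked up defensively, with a default.
        pango = sample.get("Viridian_pangolin", "Unknown")
        print(row["strain"], row["hmm_cost"], pango)
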

@@ -149,19 +159,20 @@ def initialise(db_path):
         )
         return MatchDb(db_path)

-
     def print_all(self):
         """
         Debug method to print out full state of the DB.
         """
         import pandas as pd
+
         data = []
         with self.conn:
             for row in self.conn.execute("SELECT * from samples"):
                 data.append(row)
         df = pd.DataFrame(row, index=["strain"])
         print(df)

+
 def mirror(x, L):
     return L - x

@@ -253,7 +264,7 @@ def last_date(ts):
         # reference but not as a sample
         u = ts.num_nodes - 1
         node = ts.node(u)
-        assert node.time == 0
+        # assert node.time == 0
         return parse_date(node.metadata["date"])
     else:
         samples = ts.samples()
@@ -336,6 +347,10 @@ class Sample:
     mutations: List = dataclasses.field(default_factory=list)
     alignment_qc: Dict = dataclasses.field(default_factory=dict)
     masked_sites: List = dataclasses.field(default_factory=list)
+    # FIXME need a better name for this, as it's a different thing
+    # the original alignment. Haplotype is probably good, as it's
+    # what it would be in the tskit/tsinfer world.
+    alignment: List = None

     # def __repr__(self):
     #     return self.strain
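
For orientation, the Sample dataclass now carries the encoded haplotype (the new alignment field) alongside its QC results. A stripped-down sketch of the fields visible in this diff; the exact types and any remaining fields of the real class are assumptions:

    import dataclasses
    from typing import Dict, List

    @dataclasses.dataclass
    class Sample:
        strain: str
        date: str
        metadata: Dict = dataclasses.field(default_factory=dict)
        path: List = dataclasses.field(default_factory=list)
        mutations: List = dataclasses.field(default_factory=list)
        alignment_qc: Dict = dataclasses.field(default_factory=dict)
        masked_sites: List = dataclasses.field(default_factory=list)
        # Encoded haplotype restricted to the tree sequence's sites; the FIXME
        # above suggests "haplotype" would be a better name.
        alignment: List = None
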
@@ -352,18 +367,6 @@ def breakpoints(self):
     def parents(self):
        return [seg.parent for seg in self.path]

-    # @property
-    # def date(self):
-    #     return parse_date(self.metadata["date"])
-
-    # @property
-    # def submission_date(self):
-    #     return parse_date(self.metadata["date_submitted"])
-
-    # @property
-    # def submission_delay(self):
-    #     return (self.submission_date - self.date).days
-
     def get_hmm_cost(self, num_mismatches):
         # Note that Recombinant objects have total_cost.
         # This bit of code is sort of repeated.
@@ -424,70 +427,84 @@ def daily_extend(
         last_ts = ts


-def preprocess_and_match_alignments(
+def preprocess(
     date,
     *,
+    base_ts,
     metadata_db,
     alignment_store,
-    match_db,
-    base_ts,
-    num_mismatches=None,
-    show_progress=False,
-    num_threads=None,
-    precision=None,
     max_daily_samples=None,
-    mirror_coordinates=False,
+    show_progress=False,
 ):
-    if num_mismatches is None:
-        # Default to no recombination
-        num_mismatches = 1000
-
     samples = []
-    for md in metadata_db.get(date):
-        samples.append(Sample(md["strain"], md["date"], md))
-    if len(samples) == 0:
-        logger.warn(f"Zero samples for {date}")
-        return
+    metadata_matches = list(metadata_db.get(date))
+
+    if len(metadata_matches) == 0:
+        logger.warn(f"Zero metadata matches for {date}")
+        return []
+
+    if date.endswith("01-01"):
+        logger.warning(f"Skipping {len(metadata_matches)} samples for {date}")
+        return []
+
     # TODO implement this.
     assert max_daily_samples is None

-    # Note: there's not a lot of point in making the G matrix here,
-    # we should just pass on the encoded alignments to the matching
-    # algorithm directly through the Sample class, and let it
-    # do the low-level haplotype storage.
-    G = np.zeros((base_ts.num_sites, len(samples)), dtype=np.int8)
     keep_sites = base_ts.sites_position.astype(int)
     problematic_sites = core.get_problematic_sites()
+    samples = []

-    samples_iter = enumerate(samples)
     with tqdm.tqdm(
-        samples_iter,
-        desc=f"Fetch:{date}",
-        total=len(samples),
+        metadata_matches,
+        desc=f"Preprocess:{date}",
         disable=not show_progress,
     ) as bar:
-        for j, sample in bar:
-            logger.debug(f"Getting alignment for {sample.strain}")
-            alignment = alignment_store[sample.strain]
-            sample.alignment = alignment
-            logger.debug("Encoding alignment")
+        for md in bar:
+            strain = md["strain"]
+            logger.debug(f"Getting alignment for {strain}")
+            try:
+                alignment = alignment_store[strain]
+            except KeyError:
+                logger.debug(f"No alignment stored for {strain}")
+                continue
+
+            sample = Sample(strain, date, metadata=md)
             ma = alignments.encode_and_mask(alignment)
             # Always mask the problematic_sites as well. We need to do this
             # for follow-up matching to inspect recombinants, as tsinfer
             # needs us to keep all sites in the table when doing mirrored
             # coordinates.
             ma.alignment[problematic_sites] = -1
-            G[:, j] = ma.alignment[keep_sites]
             sample.alignment_qc = ma.qc_summary()
             sample.masked_sites = ma.masked_sites
+            sample.alignment = ma.alignment[keep_sites]
+            samples.append(sample)

-    masked_per_sample = np.mean([len(sample.masked_sites)])
-    logger.info(f"Masked average of {masked_per_sample:.2f} nucleotides per sample")
+    logger.info(
+        f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
+    )
+    return samples
+
+
+def match_samples(
+    date,
+    samples,
+    *,
+    match_db,
+    base_ts,
+    num_mismatches=None,
+    show_progress=False,
+    num_threads=None,
+    precision=None,
+    mirror_coordinates=False,
+):
+    if num_mismatches is None:
+        # Default to no recombination
+        num_mismatches = 1000

     match_tsinfer(
         samples=samples,
         ts=base_ts,
-        genotypes=G,
         num_mismatches=num_mismatches,
         precision=precision,
         num_threads=num_threads,
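
The key change in preprocess() is that the per-day genotype matrix G is gone: each Sample now carries its own encoded haplotype, already masked and restricted to the sites present in the base tree sequence. A small numpy sketch of that masking and subsetting step, using made-up positions in place of core.get_problematic_sites() and base_ts.sites_position:

    import numpy as np

    # Toy stand-ins: an encoded alignment over 10 positions, two "problematic"
    # positions to mask, and four positions kept as sites in the tree sequence.
    encoded = np.arange(10, dtype=np.int8)
    problematic_sites = np.array([2, 7])
    keep_sites = np.array([1, 2, 5, 8])

    encoded[problematic_sites] = -1   # mask problematic positions
    haplotype = encoded[keep_sites]   # restrict to tree-sequence sites
    print(haplotype)                  # [ 1 -1  5  8]
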
@@ -515,21 +532,36 @@ def extend(
     precision=None,
     rng=None,
 ):
+    logger.info(
+        f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};"
+        f"mutations={base_ts.num_mutations}"
+    )
     # TODO not sure whether we'll keep these params. Making sure they're not
     # used for now
     assert max_submission_delay is None

-    preprocess_and_match_alignments(
+    samples = preprocess(
         date,
         metadata_db=metadata_db,
         alignment_store=alignment_store,
         base_ts=base_ts,
+        max_daily_samples=max_daily_samples,
+        show_progress=show_progress,
+    )
+
+    if len(samples) == 0:
+        logger.warning(f"Nothing to do for {date}")
+        return base_ts
+
+    match_samples(
+        date,
+        samples,
+        base_ts=base_ts,
         match_db=match_db,
         num_mismatches=num_mismatches,
         show_progress=show_progress,
         num_threads=num_threads,
         precision=precision,
-        max_daily_samples=max_daily_samples,
     )

     match_db.create_mask_table(base_ts)
@@ -574,6 +606,10 @@ def match_path_ts(samples, ts, path, reversions):
     path = samples[0].path
     site_id_map = {}
     first_sample = len(tables.nodes)
+    logger.debug(
+        f"Adding group of {len(samples)} with path={path} and "
+        f"reversions={reversions}"
+    )
     for sample in samples:
         assert sample.path == path
         metadata = {
@@ -596,6 +632,10 @@ def match_path_ts(samples, ts, path, reversions):
     # Now add the mutations
     for node_id, sample in enumerate(samples, first_sample):
         # metadata = {**sample.metadata, "sc2ts_qc": sample.alignment_qc}
+        logger.debug(
+            f"Adding {sample.strain}:{sample.date} with "
+            f"{len(sample.mutations)} mutations"
+        )
         for mut in sample.mutations:
             tables.mutations.add_row(
                 site=site_id_map[mut.site_id],
@@ -1210,14 +1250,14 @@ def resize_copy(array, new_size):
 def match_tsinfer(
     samples,
     ts,
-    genotypes,
     *,
     num_mismatches,
     precision=None,
     num_threads=0,
     show_progress=False,
     mirror_coordinates=False,
 ):
+    genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T
     input_ts = ts
     if mirror_coordinates:
         ts = mirror_ts_coordinates(ts)
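
match_tsinfer() now assembles the genotype matrix internally from those per-sample haplotypes, stacking one row per sample and transposing so that rows correspond to sites, the same shape the old G array had. A minimal sketch with toy haplotypes:

    import numpy as np

    # Three samples, four sites each (-1 marks a masked/missing call).
    haplotypes = [
        np.array([0, 1, 1, 0], dtype=np.int8),
        np.array([0, 0, 1, -1], dtype=np.int8),
        np.array([1, 1, 0, 0], dtype=np.int8),
    ]

    # Same construction as the added line above: the transpose yields a
    # (num_sites, num_samples) matrix.
    genotypes = np.array(haplotypes, dtype=np.int8).T
    print(genotypes.shape)  # (4, 3)
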

tests/test_inference.py

Lines changed: 14 additions & 13 deletions
@@ -175,11 +175,9 @@ def test_two_samples_one_mutation_one_filtered(self, tmp_path):


 class TestMatchTsinfer:
-    def match_tsinfer(self, samples, ts, haplotypes, **kwargs):
-        assert len(samples) == len(haplotypes)
-        G = np.array(haplotypes).T
+    def match_tsinfer(self, samples, ts, **kwargs):
         sc2ts.inference.match_tsinfer(
-            samples=samples, ts=ts, genotypes=G, num_mismatches=1000, **kwargs
+            samples=samples, ts=ts, num_mismatches=1000, **kwargs
         )

     @pytest.mark.parametrize("mirror", [False, True])
@@ -189,10 +187,11 @@ def test_match_reference(self, mirror):
         tables.sites.truncate(20)
         ts = tables.tree_sequence()
         samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
-        samples[0].alignment = sc2ts.core.get_reference_sequence()
-        ma = sc2ts.alignments.encode_and_mask(samples[0].alignment)
+        alignment = sc2ts.core.get_reference_sequence()
+        ma = sc2ts.alignments.encode_and_mask(alignment)
         h = ma.alignment[ts.sites_position.astype(int)]
-        self.match_tsinfer(samples, ts, [h], mirror_coordinates=mirror)
+        samples[0].alignment = h
+        self.match_tsinfer(samples, ts, mirror_coordinates=mirror)
         assert samples[0].breakpoints == [0, ts.sequence_length]
         assert samples[0].parents == [ts.num_nodes - 1]
         assert len(samples[0].mutations) == 0
@@ -205,12 +204,13 @@ def test_match_reference_one_mutation(self, mirror, site_id):
         tables.sites.truncate(20)
         ts = tables.tree_sequence()
         samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
-        samples[0].alignment = sc2ts.core.get_reference_sequence()
-        ma = sc2ts.alignments.encode_and_mask(samples[0].alignment)
+        alignment = sc2ts.core.get_reference_sequence()
+        ma = sc2ts.alignments.encode_and_mask(alignment)
         h = ma.alignment[ts.sites_position.astype(int)]
         # Mutate to gap
         h[site_id] = sc2ts.core.ALLELES.index("-")
-        self.match_tsinfer(samples, ts, [h], mirror_coordinates=mirror)
+        samples[0].alignment = h
+        self.match_tsinfer(samples, ts, mirror_coordinates=mirror)
         assert samples[0].breakpoints == [0, ts.sequence_length]
         assert samples[0].parents == [ts.num_nodes - 1]
         assert len(samples[0].mutations) == 1
@@ -230,11 +230,12 @@ def test_match_reference_all_same(self, mirror, allele):
         tables.sites.truncate(20)
         ts = tables.tree_sequence()
         samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
-        samples[0].alignment = sc2ts.core.get_reference_sequence()
-        ma = sc2ts.alignments.encode_and_mask(samples[0].alignment)
+        alignment = sc2ts.core.get_reference_sequence()
+        ma = sc2ts.alignments.encode_and_mask(alignment)
         ref = ma.alignment[ts.sites_position.astype(int)]
         h = np.zeros_like(ref) + allele
-        self.match_tsinfer(samples, ts, [h], mirror_coordinates=mirror)
+        samples[0].alignment = h
+        self.match_tsinfer(samples, ts, mirror_coordinates=mirror)
         assert samples[0].breakpoints == [0, ts.sequence_length]
         assert samples[0].parents == [ts.num_nodes - 1]
         muts = samples[0].mutations
