@@ -64,6 +64,8 @@ def add(self, samples, date, num_mismatches):
         for j, sample in enumerate(samples):
             d = sample.asdict()
             assert sample.date == date
+            # FIXME we want to be more selective about what we're storing
+            # here, as we're including the alignment too.
             pkl = pickle.dumps(sample)
             # BZ2 compressing drops this by ~10X, so worth it.
             pkl_compressed = bz2.compress(pkl)
@@ -75,6 +77,10 @@ def add(self, samples, date, num_mismatches):
                 pkl_compressed,
             )
             data.append(args)
+            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
+            logger.debug(
+                f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}"
+            )
         # Batch insert, for efficiency.
         with self.conn:
             self.conn.executemany(sql, data)
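
The `add()` hunks above pair bz2-compressed pickles with a single batched insert. A minimal standalone sketch of the same pattern, using an invented two-column schema rather than the real MatchDb one:

```python
import bz2
import pickle
import sqlite3

# Hypothetical schema, for illustration only; the real MatchDb table differs.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE samples (strain TEXT, pickle BLOB)")

data = []
for strain, obj in [("ERR001", {"date": "2020-02-01"}), ("ERR002", {"date": "2020-02-01"})]:
    pkl = pickle.dumps(obj)
    # bz2 trades some CPU time for a much smaller blob on repetitive data.
    data.append((strain, bz2.compress(pkl)))

# Batch insert: one transaction and one executemany call for the whole batch.
with conn:
    conn.executemany("INSERT INTO samples VALUES (?, ?)", data)

# The read side mirrors MatchDb.get: decompress, then unpickle.
for strain, blob in conn.execute("SELECT strain, pickle FROM samples"):
    print(strain, pickle.loads(bz2.decompress(blob)))
```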
@@ -124,7 +130,11 @@ def get(self, where_clause):
         for row in self.conn.execute(sql):
             pkl = row.pop("pickle")
             sample = pickle.loads(bz2.decompress(pkl))
-            logger.debug(f"MatchDb got: {row}")
+            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
+            logger.debug(
+                f"MatchDb got: {sample.strain} {sample.date} {pango} "
+                f"hmm_cost={row['hmm_cost']}"
+            )
             # print(row)
             yield sample
 
@@ -149,19 +159,20 @@ def initialise(db_path):
         )
         return MatchDb(db_path)
 
-
     def print_all(self):
         """
        Debug method to print out full state of the DB.
         """
         import pandas as pd
+
         data = []
         with self.conn:
             for row in self.conn.execute("SELECT * from samples"):
                 data.append(row)
         df = pd.DataFrame(row, index=["strain"])
         print(df)
 
+
 def mirror(x, L):
     return L - x
 
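
A small sketch of the pandas idiom that `print_all()` relies on, with made-up rows standing in for `SELECT * from samples` results; framing the accumulated `data` list shows every row, whereas the code above frames only the final `row` left over from the loop:

```python
import pandas as pd

# Stand-in rows; in MatchDb these would be dict-like sqlite rows.
data = [
    {"strain": "ERR001", "hmm_cost": 2},
    {"strain": "ERR002", "hmm_cost": 0},
]
df = pd.DataFrame(data).set_index("strain")  # frames all rows, not just the last
print(df)
```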
@@ -253,7 +264,7 @@ def last_date(ts):
         # reference but not as a sample
         u = ts.num_nodes - 1
         node = ts.node(u)
-        assert node.time == 0
+        # assert node.time == 0
         return parse_date(node.metadata["date"])
     else:
         samples = ts.samples()
@@ -336,6 +347,7 @@ class Sample:
     mutations: List = dataclasses.field(default_factory=list)
     alignment_qc: Dict = dataclasses.field(default_factory=dict)
     masked_sites: List = dataclasses.field(default_factory=list)
+    alignment: List = None
 
     # def __repr__(self):
     #     return self.strain
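
The new `alignment` field defaults to `None` while the list- and dict-valued fields use `default_factory`. A stub dataclass (not the full `Sample`) showing why that distinction matters:

```python
import dataclasses
from typing import List

@dataclasses.dataclass
class SampleStub:
    strain: str
    masked_sites: List = dataclasses.field(default_factory=list)  # fresh list per instance
    alignment: List = None  # an immutable default is safe without a factory

a, b = SampleStub("ERR001"), SampleStub("ERR002")
a.masked_sites.append(123)
assert b.masked_sites == []  # no mutable state shared between instances
```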
@@ -352,18 +364,6 @@ def breakpoints(self):
     def parents(self):
         return [seg.parent for seg in self.path]
 
-    # @property
-    # def date(self):
-    #     return parse_date(self.metadata["date"])
-
-    # @property
-    # def submission_date(self):
-    #     return parse_date(self.metadata["date_submitted"])
-
-    # @property
-    # def submission_delay(self):
-    #     return (self.submission_date - self.date).days
-
     def get_hmm_cost(self, num_mismatches):
         # Note that Recombinant objects have total_cost.
         # This bit of code is sort of repeated.
@@ -424,70 +424,84 @@ def daily_extend(
         last_ts = ts
 
 
-def preprocess_and_match_alignments(
+def preprocess(
     date,
     *,
+    base_ts,
     metadata_db,
     alignment_store,
-    match_db,
-    base_ts,
-    num_mismatches=None,
-    show_progress=False,
-    num_threads=None,
-    precision=None,
     max_daily_samples=None,
-    mirror_coordinates=False,
+    show_progress=False,
 ):
-    if num_mismatches is None:
-        # Default to no recombination
-        num_mismatches = 1000
-
     samples = []
-    for md in metadata_db.get(date):
-        samples.append(Sample(md["strain"], md["date"], md))
-    if len(samples) == 0:
-        logger.warn(f"Zero samples for {date}")
-        return
+    metadata_matches = list(metadata_db.get(date))
+
+    if len(metadata_matches) == 0:
+        logger.warn(f"Zero metadata matches for {date}")
+        return []
+
+    if date.endswith("01-01"):
+        logger.warning(f"Skipping {len(metadata_matches)} samples for {date}")
+        return []
+
     # TODO implement this.
     assert max_daily_samples is None
 
-    # Note: there's not a lot of point in making the G matrix here,
-    # we should just pass on the encoded alignments to the matching
-    # algorithm directly through the Sample class, and let it
-    # do the low-level haplotype storage.
-    G = np.zeros((base_ts.num_sites, len(samples)), dtype=np.int8)
     keep_sites = base_ts.sites_position.astype(int)
     problematic_sites = core.get_problematic_sites()
+    samples = []
 
-    samples_iter = enumerate(samples)
     with tqdm.tqdm(
-        samples_iter,
-        desc=f"Fetch:{date}",
-        total=len(samples),
+        metadata_matches,
+        desc=f"Preprocess:{date}",
         disable=not show_progress,
     ) as bar:
-        for j, sample in bar:
-            logger.debug(f"Getting alignment for {sample.strain}")
-            alignment = alignment_store[sample.strain]
-            sample.alignment = alignment
-            logger.debug("Encoding alignment")
+        for md in bar:
+            strain = md["strain"]
+            logger.debug(f"Getting alignment for {strain}")
+            try:
+                alignment = alignment_store[strain]
+            except KeyError:
+                logger.debug(f"No alignment stored for {strain}")
+                continue
+
+            sample = Sample(strain, date, metadata=md)
             ma = alignments.encode_and_mask(alignment)
             # Always mask the problematic_sites as well. We need to do this
             # for follow-up matching to inspect recombinants, as tsinfer
             # needs us to keep all sites in the table when doing mirrored
             # coordinates.
             ma.alignment[problematic_sites] = -1
-            G[:, j] = ma.alignment[keep_sites]
             sample.alignment_qc = ma.qc_summary()
             sample.masked_sites = ma.masked_sites
+            sample.alignment = ma.alignment[keep_sites]
+            samples.append(sample)
 
-    masked_per_sample = np.mean([len(sample.masked_sites)])
-    logger.info(f"Masked average of {masked_per_sample:.2f} nucleotides per sample")
+    logger.info(
+        f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
+    )
+    return samples
+
+
+def match_samples(
+    date,
+    samples,
+    *,
+    match_db,
+    base_ts,
+    num_mismatches=None,
+    show_progress=False,
+    num_threads=None,
+    precision=None,
+    mirror_coordinates=False,
+):
+    if num_mismatches is None:
+        # Default to no recombination
+        num_mismatches = 1000
 
     match_tsinfer(
         samples=samples,
         ts=base_ts,
-        genotypes=G,
         num_mismatches=num_mismatches,
         precision=precision,
         num_threads=num_threads,
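
Within `preprocess()`, masking happens in full reference coordinates before the alignment is restricted to the sites in `base_ts`. A toy illustration of that ordering, with invented coordinates and values:

```python
import numpy as np

# Invented data: a 6-site "alignment", two problematic sites, four kept sites.
alignment = np.array([0, 1, 2, 3, 1, 0], dtype=np.int8)
problematic_sites = np.array([1, 4])
keep_sites = np.array([0, 1, 2, 5])

alignment[problematic_sites] = -1         # mask first, in full coordinates
sample_alignment = alignment[keep_sites]  # then restrict to base_ts sites
print(sample_alignment)  # [ 0 -1  2  0]
```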
@@ -515,21 +529,36 @@ def extend(
     precision=None,
     rng=None,
 ):
+    logger.info(
+        f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};"
+        f"mutations={base_ts.num_mutations}"
+    )
     # TODO not sure whether we'll keep these params. Making sure they're not
     # used for now
     assert max_submission_delay is None
 
-    preprocess_and_match_alignments(
+    samples = preprocess(
         date,
         metadata_db=metadata_db,
         alignment_store=alignment_store,
         base_ts=base_ts,
+        max_daily_samples=max_daily_samples,
+        show_progress=show_progress,
+    )
+
+    if len(samples) == 0:
+        logger.warning(f"Nothing to do for {date}")
+        return base_ts
+
+    match_samples(
+        date,
+        samples,
+        base_ts=base_ts,
         match_db=match_db,
         num_mismatches=num_mismatches,
         show_progress=show_progress,
         num_threads=num_threads,
         precision=precision,
-        max_daily_samples=max_daily_samples,
     )
 
     match_db.create_mask_table(base_ts)
@@ -574,6 +603,10 @@ def match_path_ts(samples, ts, path, reversions):
     path = samples[0].path
     site_id_map = {}
     first_sample = len(tables.nodes)
+    logger.debug(
+        f"Adding group of {len(samples)} with path={path} and "
+        f"reversions={reversions}"
+    )
     for sample in samples:
         assert sample.path == path
         metadata = {
@@ -596,6 +629,10 @@ def match_path_ts(samples, ts, path, reversions):
     # Now add the mutations
     for node_id, sample in enumerate(samples, first_sample):
         # metadata = {**sample.metadata, "sc2ts_qc": sample.alignment_qc}
+        logger.debug(
+            f"Adding {sample.strain}:{sample.date} with "
+            f"{len(sample.mutations)} mutations"
+        )
         for mut in sample.mutations:
             tables.mutations.add_row(
                 site=site_id_map[mut.site_id],
@@ -1210,14 +1247,14 @@ def resize_copy(array, new_size):
 def match_tsinfer(
     samples,
     ts,
-    genotypes,
     *,
     num_mismatches,
     precision=None,
     num_threads=0,
     show_progress=False,
     mirror_coordinates=False,
 ):
+    genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T
     input_ts = ts
     if mirror_coordinates:
         ts = mirror_ts_coordinates(ts)
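
Stacking the `Sample.alignment` rows and transposing, as `match_tsinfer` now does, reproduces the genotype matrix that `preprocess_and_match_alignments()` previously filled column by column. A quick equivalence check on dummy data:

```python
import numpy as np

num_sites, num_samples = 4, 3
# Dummy per-sample alignments, standing in for Sample.alignment values.
sample_alignments = [
    np.array([0, 1, -1, 2], dtype=np.int8) + j for j in range(num_samples)
]

# Old approach: allocate (num_sites, num_samples) and fill one column per sample.
G_old = np.zeros((num_sites, num_samples), dtype=np.int8)
for j, a in enumerate(sample_alignments):
    G_old[:, j] = a

# New approach: stack rows and transpose at match time.
G_new = np.array(sample_alignments, dtype=np.int8).T
assert np.array_equal(G_old, G_new)
```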