@@ -64,6 +64,8 @@ def add(self, samples, date, num_mismatches):
         for j, sample in enumerate(samples):
             d = sample.asdict()
             assert sample.date == date
+            # FIXME we want to be more selective about what we're storing
+            # here, as we're including the alignment too.
             pkl = pickle.dumps(sample)
             # BZ2 compressing drops this by ~10X, so worth it.
             pkl_compressed = bz2.compress(pkl)
@@ -75,6 +77,10 @@ def add(self, samples, date, num_mismatches):
                 pkl_compressed,
             )
             data.append(args)
+            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
+            logger.debug(
+                f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}"
+            )
         # Batch insert, for efficiency.
         with self.conn:
             self.conn.executemany(sql, data)
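
A note on the storage scheme in the two hunks above: each Sample is pickled whole and bz2-compressed before insertion, which is what the new FIXME is about (the alignment rides along inside the blob). A minimal standalone sketch of the round-trip, using a plain dict as a hypothetical stand-in for a Sample:

import bz2
import pickle

record = {"strain": "ABC-123", "date": "2020-06-01"}  # hypothetical stand-in
pkl = pickle.dumps(record)
pkl_compressed = bz2.compress(pkl)  # the ~10X saving noted in the comment
assert pickle.loads(bz2.decompress(pkl_compressed)) == record
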
@@ -124,7 +130,11 @@ def get(self, where_clause):
         for row in self.conn.execute(sql):
             pkl = row.pop("pickle")
             sample = pickle.loads(bz2.decompress(pkl))
-            logger.debug(f"MatchDb got: {row}")
+            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
+            logger.debug(
+                f"MatchDb got: {sample.strain} {sample.date} {pango} "
+                f"hmm_cost={row['hmm_cost']}"
+            )
             # print(row)
             yield sample
 
@@ -149,19 +159,20 @@ def initialise(db_path):
         )
         return MatchDb(db_path)
 
-
     def print_all(self):
         """
         Debug method to print out full state of the DB.
         """
         import pandas as pd
+
         data = []
         with self.conn:
             for row in self.conn.execute("SELECT * from samples"):
                 data.append(row)
         df = pd.DataFrame(row, index=["strain"])
         print(df)
 
+
 def mirror(x, L):
     return L - x
 
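
For reference, mirror() underpins the mirror_coordinates option that appears in the matching code further down this diff. Reflecting a position twice returns the original, as this standalone check illustrates (the sequence length here is illustrative only):

L = 29904  # illustrative sequence length

def mirror(x, L):
    return L - x

assert mirror(mirror(123, L), L) == 123
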
@@ -253,7 +264,7 @@ def last_date(ts):
         # reference but not as a sample
         u = ts.num_nodes - 1
         node = ts.node(u)
-        assert node.time == 0
+        # assert node.time == 0
         return parse_date(node.metadata["date"])
     else:
         samples = ts.samples()
@@ -336,6 +347,10 @@ class Sample:
     mutations: List = dataclasses.field(default_factory=list)
     alignment_qc: Dict = dataclasses.field(default_factory=dict)
     masked_sites: List = dataclasses.field(default_factory=list)
+    # FIXME need a better name for this, as it's a different thing from
+    # the original alignment. Haplotype is probably good, as it's
+    # what it would be in the tskit/tsinfer world.
+    alignment: List = None
 
     # def __repr__(self):
     #     return self.strain
@@ -352,18 +367,6 @@ def breakpoints(self):
     def parents(self):
         return [seg.parent for seg in self.path]
 
-    # @property
-    # def date(self):
-    #     return parse_date(self.metadata["date"])
-
-    # @property
-    # def submission_date(self):
-    #     return parse_date(self.metadata["date_submitted"])
-
-    # @property
-    # def submission_delay(self):
-    #     return (self.submission_date - self.date).days
-
     def get_hmm_cost(self, num_mismatches):
         # Note that Recombinant objects have total_cost.
         # This bit of code is sort of repeated.
@@ -424,70 +427,84 @@ def daily_extend(
         last_ts = ts
 
 
-def preprocess_and_match_alignments(
+def preprocess(
     date,
     *,
+    base_ts,
     metadata_db,
     alignment_store,
-    match_db,
-    base_ts,
-    num_mismatches=None,
-    show_progress=False,
-    num_threads=None,
-    precision=None,
     max_daily_samples=None,
-    mirror_coordinates=False,
+    show_progress=False,
 ):
-    if num_mismatches is None:
-        # Default to no recombination
-        num_mismatches = 1000
-
     samples = []
-    for md in metadata_db.get(date):
-        samples.append(Sample(md["strain"], md["date"], md))
-    if len(samples) == 0:
-        logger.warn(f"Zero samples for {date}")
-        return
+    metadata_matches = list(metadata_db.get(date))
+
+    if len(metadata_matches) == 0:
+        logger.warn(f"Zero metadata matches for {date}")
+        return []
+
+    if date.endswith("01-01"):
+        logger.warning(f"Skipping {len(metadata_matches)} samples for {date}")
+        return []
+
     # TODO implement this.
     assert max_daily_samples is None
 
-    # Note: there's not a lot of point in making the G matrix here,
-    # we should just pass on the encoded alignments to the matching
-    # algorithm directly through the Sample class, and let it
-    # do the low-level haplotype storage.
-    G = np.zeros((base_ts.num_sites, len(samples)), dtype=np.int8)
     keep_sites = base_ts.sites_position.astype(int)
     problematic_sites = core.get_problematic_sites()
+    samples = []
 
-    samples_iter = enumerate(samples)
     with tqdm.tqdm(
-        samples_iter,
-        desc=f"Fetch:{date}",
-        total=len(samples),
+        metadata_matches,
+        desc=f"Preprocess:{date}",
         disable=not show_progress,
     ) as bar:
-        for j, sample in bar:
-            logger.debug(f"Getting alignment for {sample.strain}")
-            alignment = alignment_store[sample.strain]
-            sample.alignment = alignment
-            logger.debug("Encoding alignment")
+        for md in bar:
+            strain = md["strain"]
+            logger.debug(f"Getting alignment for {strain}")
+            try:
+                alignment = alignment_store[strain]
+            except KeyError:
+                logger.debug(f"No alignment stored for {strain}")
+                continue
+
+            sample = Sample(strain, date, metadata=md)
             ma = alignments.encode_and_mask(alignment)
             # Always mask the problematic_sites as well. We need to do this
             # for follow-up matching to inspect recombinants, as tsinfer
             # needs us to keep all sites in the table when doing mirrored
             # coordinates.
             ma.alignment[problematic_sites] = -1
-            G[:, j] = ma.alignment[keep_sites]
             sample.alignment_qc = ma.qc_summary()
             sample.masked_sites = ma.masked_sites
+            sample.alignment = ma.alignment[keep_sites]
+            samples.append(sample)
 
-    masked_per_sample = np.mean([len(sample.masked_sites)])
-    logger.info(f"Masked average of {masked_per_sample:.2f} nucleotides per sample")
+    logger.info(
+        f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
+    )
+    return samples
+
+
+def match_samples(
+    date,
+    samples,
+    *,
+    match_db,
+    base_ts,
+    num_mismatches=None,
+    show_progress=False,
+    num_threads=None,
+    precision=None,
+    mirror_coordinates=False,
+):
+    if num_mismatches is None:
+        # Default to no recombination
+        num_mismatches = 1000
 
     match_tsinfer(
         samples=samples,
         ts=base_ts,
-        genotypes=G,
         num_mismatches=num_mismatches,
         precision=precision,
         num_threads=num_threads,
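
The net effect of this hunk is to split the old preprocess_and_match_alignments() into a fetch/encode phase and a matching phase. A hedged sketch of the intended call sequence (the extend() hunk below is the authoritative version; metadata_db, alignment_store, base_ts and match_db are assumed to be initialised elsewhere):

samples = preprocess(
    "2020-06-01",
    base_ts=base_ts,
    metadata_db=metadata_db,
    alignment_store=alignment_store,
)
if len(samples) > 0:
    match_samples(
        "2020-06-01",
        samples,
        match_db=match_db,
        base_ts=base_ts,
    )
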
@@ -515,21 +532,36 @@ def extend(
     precision=None,
     rng=None,
 ):
+    logger.info(
+        f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};"
+        f"mutations={base_ts.num_mutations}"
+    )
     # TODO not sure whether we'll keep these params. Making sure they're not
     # used for now
     assert max_submission_delay is None
 
-    preprocess_and_match_alignments(
+    samples = preprocess(
         date,
         metadata_db=metadata_db,
         alignment_store=alignment_store,
         base_ts=base_ts,
+        max_daily_samples=max_daily_samples,
+        show_progress=show_progress,
+    )
+
+    if len(samples) == 0:
+        logger.warning(f"Nothing to do for {date}")
+        return base_ts
+
+    match_samples(
+        date,
+        samples,
+        base_ts=base_ts,
         match_db=match_db,
         num_mismatches=num_mismatches,
         show_progress=show_progress,
         num_threads=num_threads,
         precision=precision,
-        max_daily_samples=max_daily_samples,
     )
 
     match_db.create_mask_table(base_ts)
@@ -574,6 +606,10 @@ def match_path_ts(samples, ts, path, reversions):
     path = samples[0].path
     site_id_map = {}
     first_sample = len(tables.nodes)
+    logger.debug(
+        f"Adding group of {len(samples)} with path={path} and "
+        f"reversions={reversions}"
+    )
     for sample in samples:
         assert sample.path == path
         metadata = {
@@ -596,6 +632,10 @@ def match_path_ts(samples, ts, path, reversions):
     # Now add the mutations
     for node_id, sample in enumerate(samples, first_sample):
         # metadata = {**sample.metadata, "sc2ts_qc": sample.alignment_qc}
+        logger.debug(
+            f"Adding {sample.strain}:{sample.date} with "
+            f"{len(sample.mutations)} mutations"
+        )
         for mut in sample.mutations:
             tables.mutations.add_row(
                 site=site_id_map[mut.site_id],
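
For readers unfamiliar with the tskit tables API used in this hunk: mutations are added row by row against previously registered sites and nodes. A self-contained toy example, unrelated to the sc2ts tables being built above:

import tskit

tables = tskit.TableCollection(sequence_length=100)
node = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
site = tables.sites.add_row(position=42, ancestral_state="A")
tables.mutations.add_row(site=site, node=node, derived_state="T")
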
@@ -1210,14 +1250,14 @@ def resize_copy(array, new_size):
 def match_tsinfer(
     samples,
     ts,
-    genotypes,
     *,
     num_mismatches,
     precision=None,
     num_threads=0,
     show_progress=False,
     mirror_coordinates=False,
 ):
+    genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T
     input_ts = ts
     if mirror_coordinates:
         ts = mirror_ts_coordinates(ts)
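
This last hunk moves construction of the genotype matrix inside match_tsinfer(): each sample's haplotype (already restricted to keep_sites in preprocess()) becomes one column. A toy illustration of the shape convention, with made-up values:

import numpy as np

haplotypes = [
    np.array([0, 1, -1], dtype=np.int8),  # -1 marks a masked site
    np.array([1, 1, 0], dtype=np.int8),
]
G = np.array(haplotypes, dtype=np.int8).T
assert G.shape == (3, 2)  # (num_sites, num_samples)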