@@ -1,11 +1,15 @@
 from __future__ import annotations
+import bz2
 import logging
 import datetime
 import dataclasses
 import collections
 import io
 import pickle
 import os
+import sqlite3
+import pathlib
+import json

 import tqdm
 import tskit
@@ -18,10 +22,121 @@

 from . import core
 from . import alignments
+from . import metadata

 logger = logging.getLogger(__name__)


+class MatchDb:
+    def __init__(self, path):
+        uri = f"file:{path}"
+        # uri += "?mode=rw"
+        self.uri = uri
+        self.conn = sqlite3.connect(uri, uri=True)
+        self.conn.row_factory = metadata.dict_factory
+
+    def __len__(self):
+        sql = "SELECT COUNT(*) FROM samples"
+        with self.conn:
+            row = self.conn.execute(sql).fetchone()
+            return row["COUNT(*)"]
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+    def __str__(self):
+        return f"MatchDb at {self.uri} contains {len(self)} samples"
+
+    def close(self):
+        self.conn.close()
+
+    def add(self, sample, date, num_mismatches):
+        """
+        Adds the specified matched sample to this MatchDb.
+        """
+        sql = """\
+        INSERT INTO samples (
+        strain, match_date, hmm_cost, pickle)
+        VALUES (?, ?, ?, ?)
+        """
+        pkl = pickle.dumps(sample)
+        # BZ2 compression reduces the pickle size by roughly 10X, so it's worth it.
+        pkl_compressed = bz2.compress(pkl)
+        args = (
+            sample.strain,
+            date,
+            sample.get_hmm_cost(num_mismatches),
+            pkl_compressed,
+        )
+        with self.conn:
+            self.conn.execute(sql, args)
+
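+    # create_mask_table() snapshots the strains already present in the
+    # tree sequence ts; get() anti-joins against this table so that
+    # previously inserted samples are never yielded again.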
+    def create_mask_table(self, ts):
+        # TODO this is inefficient - need some logging to see how much time
+        # we're spending here.
+        samples = [(ts.node(u).metadata["strain"],) for u in ts.samples()]
+        with self.conn:
+            self.conn.execute("DROP TABLE IF EXISTS used_samples")
+            self.conn.execute(
+                "CREATE TABLE used_samples (strain TEXT, PRIMARY KEY (strain))"
+            )
+            self.conn.executemany("INSERT INTO used_samples VALUES (?)", samples)
+
+    def get(self, where_clause):
+        # Anti-join against used_samples (built by create_mask_table)
+        # so that samples already in the tree sequence are skipped.
+        sql = (
+            "SELECT * FROM samples LEFT JOIN used_samples "
+            "ON samples.strain = used_samples.strain "
+            f"WHERE used_samples.strain IS NULL AND {where_clause}"
+        )
+        with self.conn:
+            logger.debug(f"MatchDb run: {sql}")
+            for row in self.conn.execute(sql):
+                pkl = row.pop("pickle")
+                sample = pickle.loads(bz2.decompress(pkl))
+                logger.debug(f"MatchDb got: {row}")
+                yield sample
+
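+    # NB: initialise() is destructive - it deletes any existing database
+    # file at db_path before creating a fresh schema.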
+    @staticmethod
+    def initialise(db_path):
+        db_path = pathlib.Path(db_path)
+        if db_path.exists():
+            db_path.unlink()
+        # insertion_date is indexed and queried below; it is presumably
+        # filled in when a sample is inserted into the ARG.
+        sql = """\
+        CREATE TABLE samples (
+        strain TEXT,
+        match_date TEXT,
+        hmm_cost REAL,
+        insertion_date TEXT,
+        pickle BLOB,
+        PRIMARY KEY (strain))
+        """
+        with sqlite3.connect(db_path) as conn:
+            conn.execute(sql)
+            conn.execute(
+                "CREATE INDEX [ix_samples_match_date] ON 'samples' ([match_date]);"
+            )
+            conn.execute(
+                "CREATE INDEX [ix_samples_insertion_date] ON 'samples' "
+                "([insertion_date]);"
+            )
+        return MatchDb(db_path)
+
+
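+# A rough lifecycle sketch of the class above. The names "matches.db",
+# matched_samples and base_ts are illustrative placeholders, not
+# definitions from this module:
+#
+#   with MatchDb.initialise("matches.db") as match_db:
+#       for sample in matched_samples:
+#           match_db.add(sample, "2021-06-30", num_mismatches=1000)
+#       match_db.create_mask_table(base_ts)
+#       stored = list(match_db.get("match_date=='2021-06-30'"))
+
+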
 def mirror(x, L):
     return L - x

@@ -255,6 +370,7 @@ def daily_extend(
     alignment_store,
     metadata_db,
     base_ts,
+    match_db,
     num_mismatches=None,
     max_hmm_cost=None,
     min_group_size=None,
@@ -292,6 +408,7 @@ def daily_extend(
         metadata_db=metadata_db,
         date=date,
         base_ts=last_ts,
+        match_db=match_db,
         num_mismatches=num_mismatches,
         max_hmm_cost=max_hmm_cost,
         min_group_size=min_group_size,
@@ -395,6 +512,7 @@ def extend(
     metadata_db,
     date,
     base_ts,
+    match_db,
     num_mismatches=None,
     max_hmm_cost=None,
     min_group_size=None,
@@ -407,6 +525,7 @@ def extend(
     reconsidered_samples=None,
 ):
     date_samples = [Sample(md) for md in metadata_db.get(date)]
+    # TODO remove the max_submission_delay #203
     samples = filter_samples(date_samples, alignment_store, max_submission_delay)

     if max_daily_samples is not None and len(samples) > max_daily_samples:
@@ -420,7 +539,7 @@ def extend(
     logger.info(f"Got {len(samples)} samples")

     # Note num_mismatches is assigned a default value in match_tsinfer.
-    samples = match(
+    match(
         samples,
         alignment_store=alignment_store,
         base_ts=base_ts,
@@ -429,29 +548,39 @@ def extend(
         num_threads=num_threads,
         precision=precision,
     )
+
+    match_db.create_mask_table(base_ts)
+
+    # FIXME
+    if max_hmm_cost is None:
+        # By default, arbitrarily high.
+        max_hmm_cost = 1e6
+
+    num_mismatches = 1000 if num_mismatches is None else num_mismatches
+
+    for sample in samples:
+        match_db.add(sample, date, num_mismatches)
+
     ts = increment_time(date, base_ts)
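+
+    # Two retrieval passes follow: first today's matches with acceptable
+    # HMM cost, then older matches that are still awaiting insertion.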
     ts = add_matching_results(
-        samples=samples,
+        f"match_date=='{date}' and hmm_cost<={max_hmm_cost}",
         ts=ts,
+        match_db=match_db,
         date=date,
-        num_mismatches=num_mismatches,
-        max_hmm_cost=max_hmm_cost,
-        min_group_size=None,
+        min_group_size=1,
         show_progress=show_progress,
     )

-    # ts, _, added_back_samples = add_matching_results(
-    #     samples=reconsidered_samples,
-    #     ts=ts,
-    #     date=date,
-    #     num_mismatches=num_mismatches,
-    #     max_hmm_cost=None,
-    #     min_group_size=min_group_size,
-    #     show_progress=show_progress,
-    # )
-
-    return ts  # , excluded_samples, added_back_samples
+    # Second pass: retry older matches that have not yet been inserted,
+    # but require a larger group before committing them.
+    ts = add_matching_results(
+        f"match_date<'{date}' and insertion_date IS NULL",
+        ts=ts,
+        match_db=match_db,
+        date=date,
+        min_group_size=3,
+        show_progress=show_progress,
+    )
+    return ts


 def match_path_ts(samples, ts, path, reversions):
@@ -501,33 +630,20 @@ def match_path_ts(samples, ts, path, reversions):


 def add_matching_results(
-    samples,
+    where_clause,
+    match_db,
     ts,
     date,
-    num_mismatches,
-    max_hmm_cost,
     min_group_size=1,
     show_progress=False,
 ):
-    if num_mismatches is None:
-        # Note that this is the default assigned in match_tsinfer.
-        num_mismatches = 1e3
-
-    if max_hmm_cost is None:
-        # By default, arbitraily high.
-        max_hmm_cost = 1e6
-
-    if min_group_size is None:
-        min_group_size = 1
+    samples = match_db.get(where_clause)

     # Group matches by path and set of immediate reversions.
     grouped_matches = collections.defaultdict(list)
     excluded_samples = []
     site_masked_samples = np.zeros(int(ts.sequence_length), dtype=int)
     for sample in samples:
-        if sample.get_hmm_cost(num_mismatches) > max_hmm_cost:
-            excluded_samples.append(sample)
-            continue
         site_masked_samples[sample.masked_sites] += 1
         path = tuple(sample.path)
         reversions = tuple(