Skip to content

Commit 13e23c6

Browse files
Add match_db
1 parent 37a7cd3 commit 13e23c6

File tree

2 files changed

+152
-32
lines changed

2 files changed

+152
-32
lines changed

sc2ts/cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,13 @@ def daily_extend(
238238
"""
239239
setup_logging(verbose, log_file)
240240
rng = random.Random(random_seed)
241+
match_db_path = f"{output_prefix}match.db"
241242
if base is None:
242243
base_ts = inference.initial_ts()
244+
match_db = inference.MatchDb.initialise(match_db_path)
243245
else:
244246
base_ts = tskit.load(base)
247+
match_db = inference.MatchDb(match_db_path)
245248

246249
if excluded_samples_dir is None:
247250
excluded_samples_dir = output_prefix
@@ -253,6 +256,7 @@ def daily_extend(
253256
alignment_store=alignment_store,
254257
metadata_db=metadata_db,
255258
base_ts=base_ts,
259+
match_db=match_db,
256260
num_mismatches=num_mismatches,
257261
max_hmm_cost=max_hmm_cost,
258262
min_group_size=min_group_size,

sc2ts/inference.py

Lines changed: 148 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from __future__ import annotations
2+
import bz2
23
import logging
34
import datetime
45
import dataclasses
56
import collections
67
import io
78
import pickle
89
import os
10+
import sqlite3
11+
import pathlib
12+
import json
913

1014
import tqdm
1115
import tskit
@@ -18,10 +22,121 @@
1822

1923
from . import core
2024
from . import alignments
25+
from . import metadata
2126

2227
logger = logging.getLogger(__name__)
2328

2429

30+
class MatchDb:
    """
    SQLite-backed store of matched samples.

    Each row of the ``samples`` table records a sample's strain, the date
    on which it was matched, its HMM cost and a bz2-compressed pickle of
    the sample object. The ``insertion_date`` column is not written by
    :meth:`add`, so it is NULL for newly added rows.

    Instances can be used as context managers; the underlying connection
    is closed on exit.
    """

    def __init__(self, path):
        """
        Opens a connection to the database file at the specified path.
        """
        uri = f"file:{path}"
        # uri += "?mode=rw"
        self.uri = uri
        self.conn = sqlite3.connect(uri, uri=True)
        # Rows are accessed by column name throughout this class.
        self.conn.row_factory = metadata.dict_factory

    def __len__(self):
        """
        Returns the number of rows in the samples table.
        """
        sql = "SELECT COUNT(*) FROM samples"
        with self.conn:
            row = self.conn.execute(sql).fetchone()
            return row["COUNT(*)"]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __str__(self):
        return f"MatchDb at {self.uri} contains {len(self)} samples"

    def close(self):
        """
        Closes the underlying database connection.
        """
        self.conn.close()

    def add(self, sample, date, num_mismatches):
        """
        Adds the specified matched sample to this MatchDb.

        The sample is stored as a bz2-compressed pickle alongside its
        strain, the specified match date, and the HMM cost computed for
        the specified num_mismatches value.
        """
        sql = """\
            INSERT INTO samples (
            strain, match_date, hmm_cost, pickle)
            VALUES (?, ?, ?, ?)
            """

        pkl = pickle.dumps(sample)
        # BZ2 compressing drops this by ~10X, so worth it.
        pkl_compressed = bz2.compress(pkl)
        args = (
            sample.strain,
            date,
            sample.get_hmm_cost(num_mismatches),
            pkl_compressed,
        )
        with self.conn:
            self.conn.execute(sql, args)

    def create_mask_table(self, ts):
        """
        Rebuilds the used_samples table from the strains of the sample
        nodes in the specified tree sequence. Strains listed in this table
        are excluded from the results of :meth:`get`.
        """
        # TODO this is inefficient - need some logging to see how much time
        # we're spending here.
        samples = [(ts.node(u).metadata["strain"],) for u in ts.samples()]
        with self.conn:
            self.conn.execute("DROP TABLE IF EXISTS used_samples")
            self.conn.execute(
                "CREATE TABLE used_samples (strain TEXT, PRIMARY KEY (strain))"
            )
            self.conn.executemany("INSERT INTO used_samples VALUES (?)", samples)

    def get(self, where_clause):
        """
        Yields the unpickled samples for all rows that satisfy the specified
        SQL condition and whose strain is not in the used_samples table.

        The used_samples table must exist (see :meth:`create_mask_table`).

        NOTE: where_clause is interpolated into the query verbatim, so it
        must only ever be built from trusted values. In SQL, a NULL column
        has to be tested with ``IS NULL``; ``== NULL`` never matches.
        """
        sql = (
            "SELECT * FROM samples LEFT JOIN used_samples "
            "ON samples.strain = used_samples.strain "
            f"WHERE used_samples.strain IS NULL AND {where_clause}"
        )
        with self.conn:
            logger.debug(f"MatchDb run: {sql}")
            for row in self.conn.execute(sql):
                pkl = row.pop("pickle")
                # NOTE: unpickling executes arbitrary code; only open
                # database files that come from a trusted source.
                sample = pickle.loads(bz2.decompress(pkl))
                logger.debug(f"MatchDb got: {row}")
                yield sample

    @staticmethod
    def initialise(db_path):
        """
        Creates a new, empty database at the specified path, replacing any
        existing file, and returns a MatchDb connected to it.
        """
        db_path = pathlib.Path(db_path)
        if db_path.exists():
            db_path.unlink()
        # insertion_date must be declared here: it is indexed below, and
        # SQLite rejects an index on a column the table does not have.
        sql = """\
            CREATE TABLE samples (
            strain TEXT,
            match_date TEXT,
            hmm_cost REAL,
            insertion_date TEXT,
            pickle BLOB,
            PRIMARY KEY (strain))
            """

        conn = sqlite3.connect(db_path)
        try:
            # The connection context manager only commits or rolls back;
            # it does not close the connection, hence the try/finally.
            with conn:
                conn.execute(sql)
                conn.execute(
                    "CREATE INDEX [ix_samples_match_date] on 'samples' "
                    "([match_date]);"
                )
                conn.execute(
                    "CREATE INDEX [ix_samples_insertion_date] on 'samples' "
                    "([insertion_date]);"
                )
        finally:
            conn.close()
        return MatchDb(db_path)
138+
139+
25140
def mirror(x, L):
    """
    Returns the coordinate x reflected within a span of length L,
    that is, L - x.
    """
    reflected = L - x
    return reflected
27142

@@ -255,6 +370,7 @@ def daily_extend(
255370
alignment_store,
256371
metadata_db,
257372
base_ts,
373+
match_db,
258374
num_mismatches=None,
259375
max_hmm_cost=None,
260376
min_group_size=None,
@@ -292,6 +408,7 @@ def daily_extend(
292408
metadata_db=metadata_db,
293409
date=date,
294410
base_ts=last_ts,
411+
match_db=match_db,
295412
num_mismatches=num_mismatches,
296413
max_hmm_cost=max_hmm_cost,
297414
min_group_size=min_group_size,
@@ -395,6 +512,7 @@ def extend(
395512
metadata_db,
396513
date,
397514
base_ts,
515+
match_db,
398516
num_mismatches=None,
399517
max_hmm_cost=None,
400518
min_group_size=None,
@@ -407,6 +525,7 @@ def extend(
407525
reconsidered_samples=None,
408526
):
409527
date_samples = [Sample(md) for md in metadata_db.get(date)]
528+
# TODO remove the max_submission_delay #203
410529
samples = filter_samples(date_samples, alignment_store, max_submission_delay)
411530

412531
if max_daily_samples is not None and len(samples) > max_daily_samples:
@@ -420,7 +539,7 @@ def extend(
420539
logger.info(f"Got {len(samples)} samples")
421540

422541
# Note num_mismatches is assigned a default value in match_tsinfer.
423-
samples = match(
542+
match(
424543
samples,
425544
alignment_store=alignment_store,
426545
base_ts=base_ts,
@@ -429,29 +548,39 @@ def extend(
429548
num_threads=num_threads,
430549
precision=precision,
431550
)
551+
552+
match_db.create_mask_table(base_ts)
553+
554+
# FIXME
555+
if max_hmm_cost is None:
556+
# By default, arbitrarily high.
557+
max_hmm_cost = 1e6
558+
559+
num_mismatches = 1000 if num_mismatches is None else num_mismatches
560+
561+
for sample in samples:
562+
match_db.add(sample, date, num_mismatches)
563+
432564
ts = increment_time(date, base_ts)
433565

434566
ts = add_matching_results(
435-
samples=samples,
567+
f"match_date=='{date}' and hmm_cost<={max_hmm_cost}",
436568
ts=ts,
569+
match_db=match_db,
437570
date=date,
438-
num_mismatches=num_mismatches,
439-
max_hmm_cost=max_hmm_cost,
440-
min_group_size=None,
571+
min_group_size=1,
441572
show_progress=show_progress,
442573
)
443574

444-
# ts, _, added_back_samples = add_matching_results(
445-
# samples=reconsidered_samples,
446-
# ts=ts,
447-
# date=date,
448-
# num_mismatches=num_mismatches,
449-
# max_hmm_cost=None,
450-
# min_group_size=min_group_size,
451-
# show_progress=show_progress,
452-
# )
453-
454-
return ts # , excluded_samples, added_back_samples
575+
ts = add_matching_results(
576+
f"match_date<'{date}' and insertion_date==NULL",
577+
ts=ts,
578+
match_db=match_db,
579+
date=date,
580+
min_group_size=3,
581+
show_progress=show_progress,
582+
)
583+
return ts
455584

456585

457586
def match_path_ts(samples, ts, path, reversions):
@@ -501,33 +630,20 @@ def match_path_ts(samples, ts, path, reversions):
501630

502631

503632
def add_matching_results(
504-
samples,
633+
where_clause,
634+
match_db,
505635
ts,
506636
date,
507-
num_mismatches,
508-
max_hmm_cost,
509637
min_group_size=1,
510638
show_progress=False,
511639
):
512-
if num_mismatches is None:
513-
# Note that this is the default assigned in match_tsinfer.
514-
num_mismatches = 1e3
515-
516-
if max_hmm_cost is None:
517-
# By default, arbitraily high.
518-
max_hmm_cost = 1e6
519-
520-
if min_group_size is None:
521-
min_group_size = 1
640+
samples = match_db.get(where_clause)
522641

523642
# Group matches by path and set of immediate reversions.
524643
grouped_matches = collections.defaultdict(list)
525644
excluded_samples = []
526645
site_masked_samples = np.zeros(int(ts.sequence_length), dtype=int)
527646
for sample in samples:
528-
if sample.get_hmm_cost(num_mismatches) > max_hmm_cost:
529-
excluded_samples.append(sample)
530-
continue
531647
site_masked_samples[sample.masked_sites] += 1
532648
path = tuple(sample.path)
533649
reversions = tuple(

0 commit comments

Comments
 (0)