Skip to content

Commit 2b54993

Browse files
szhan authored and jeromekelleher committed
Reconsider filtered samples from the past N consecutive days
1 parent a62086d commit 2b54993

File tree

2 files changed

+115
-5
lines changed

2 files changed

+115
-5
lines changed

sc2ts/cli.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,18 @@ def dump_samples(samples, output_file):
169169
)
170170
@click.option("--num-mismatches", default=None, type=float, help="num-mismatches")
171171
@click.option("--max-hmm-cost", default=None, type=float, help="max-hmm-cost")
172+
@click.option(
173+
"--min-group-size",
174+
default=None,
175+
type=int,
176+
help="Minimum size of groups of reconsidered samples",
177+
)
178+
@click.option(
179+
"--num-past-days",
180+
default=None,
181+
type=int,
182+
help="Number of past days to retrieve filtered samples",
183+
)
172184
@click.option(
173185
"--max-submission-delay",
174186
default=None,
@@ -187,6 +199,15 @@ def dump_samples(samples, output_file):
187199
"is greater than this, randomly subsample."
188200
),
189201
)
202+
@click.option(
203+
"--excluded_samples_dir",
204+
default=None,
205+
type=click.Path(file_okay=False, dir_okay=True),
206+
help=(
207+
"Directory containing pickled files of excluded samples. "
208+
"By default, it is set to output_prefx."
209+
),
210+
)
190211
@click.option("--num-threads", default=0, type=int, help="Number of match threads")
191212
@click.option("--random-seed", default=42, type=int, help="Random seed for subsampling")
192213
@click.option("-p", "--precision", default=None, type=int, help="Match precision")
@@ -200,8 +221,11 @@ def daily_extend(
200221
base,
201222
num_mismatches,
202223
max_hmm_cost,
224+
min_group_size,
225+
num_past_days,
203226
max_submission_delay,
204227
max_daily_samples,
228+
excluded_samples_dir,
205229
num_threads,
206230
random_seed,
207231
precision,
@@ -219,6 +243,9 @@ def daily_extend(
219243
else:
220244
base_ts = tskit.load(base)
221245

246+
if excluded_samples_dir is None:
247+
excluded_samples_dir = output_prefix
248+
222249
with contextlib.ExitStack() as exit_stack:
223250
alignment_store = exit_stack.enter_context(sc2ts.AlignmentStore(alignments))
224251
metadata_db = exit_stack.enter_context(sc2ts.MetadataDb(metadata))
@@ -228,12 +255,15 @@ def daily_extend(
228255
base_ts=base_ts,
229256
num_mismatches=num_mismatches,
230257
max_hmm_cost=max_hmm_cost,
258+
min_group_size=min_group_size,
259+
num_past_days=num_past_days,
231260
max_submission_delay=max_submission_delay,
232261
max_daily_samples=max_daily_samples,
233262
rng=rng,
234263
precision=precision,
235264
num_threads=num_threads,
236265
show_progress=not no_progress,
266+
excluded_sample_dir=excluded_samples_dir,
237267
)
238268
for ts, excluded_samples, date in ts_iter:
239269
output_ts = output_prefix + date + ".ts"

sc2ts/inference.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import dataclasses
55
import collections
66
import io
7+
import pickle
8+
import os
79

810
import tqdm
911
import tskit
@@ -110,11 +112,13 @@ def last_date(ts):
110112
# Special case for the initial ts which contains the
111113
# reference but not as a sample
112114
u = ts.num_nodes - 1
115+
node = ts.node(u)
116+
assert node.time == 0
117+
return parse_date(node.metadata["date"])
113118
else:
114-
u = ts.samples()[-1]
115-
node = ts.node(u)
116-
assert node.time == 0
117-
return parse_date(node.metadata["date"])
119+
samples = ts.samples()
120+
samples_t0 = samples[ts.nodes_time[samples] == 0]
121+
return max([parse_date(ts.node(u).metadata["date"]) for u in samples_t0])
118122

119123

120124
def increment_time(date, ts):
@@ -253,14 +257,31 @@ def daily_extend(
253257
base_ts,
254258
num_mismatches=None,
255259
max_hmm_cost=None,
260+
min_group_size=None,
261+
num_past_days=None,
256262
show_progress=False,
257263
max_submission_delay=None,
258264
max_daily_samples=None,
259265
num_threads=None,
260266
precision=None,
261267
rng=None,
268+
excluded_sample_dir=None,
262269
):
263270
start_day = last_date(base_ts)
271+
272+
reconsidered_samples = collections.deque()
273+
earliest_date = start_day - datetime.timedelta(days=1)
274+
if base_ts is not None:
275+
next_day = start_day + datetime.timedelta(days=1)
276+
reconsidered_samples.extend(
277+
fetch_samples_from_pickle_file(
278+
date=next_day,
279+
num_past_days=num_past_days,
280+
in_dir=excluded_sample_dir,
281+
)
282+
)
283+
earliest_date = next_day - datetime.timedelta(days=num_past_days)
284+
264285
last_ts = base_ts
265286
for date in metadata_db.get_days(start_day):
266287
ts, excluded_samples = extend(
@@ -270,14 +291,25 @@ def daily_extend(
270291
base_ts=last_ts,
271292
num_mismatches=num_mismatches,
272293
max_hmm_cost=max_hmm_cost,
294+
min_group_size=min_group_size,
273295
show_progress=show_progress,
274296
max_submission_delay=max_submission_delay,
275297
max_daily_samples=max_daily_samples,
276298
num_threads=num_threads,
277299
precision=precision,
278300
rng=rng,
301+
reconsidered_samples=reconsidered_samples,
279302
)
280303
yield ts, excluded_samples, date
304+
305+
# Update list of reconsidered samples.
306+
if len(reconsidered_samples) > 0:
307+
while reconsidered_samples[0].date == earliest_date:
308+
reconsidered_samples.popleft()
309+
reconsidered_samples.extend(excluded_samples)
310+
311+
earliest_date += datetime.timedelta(days=1)
312+
281313
last_ts = ts
282314

283315

@@ -348,12 +380,14 @@ def extend(
348380
base_ts,
349381
num_mismatches=None,
350382
max_hmm_cost=None,
383+
min_group_size=None,
351384
show_progress=False,
352385
max_submission_delay=None,
353386
max_daily_samples=None,
354387
num_threads=None,
355388
precision=None,
356389
rng=None,
390+
reconsidered_samples=None,
357391
):
358392
date_samples = [Sample(md) for md in metadata_db.get(date)]
359393
samples = filter_samples(date_samples, alignment_store, max_submission_delay)
@@ -386,6 +420,17 @@ def extend(
386420
date=date,
387421
num_mismatches=num_mismatches,
388422
max_hmm_cost=max_hmm_cost,
423+
min_group_size=None,
424+
show_progress=show_progress,
425+
)
426+
427+
ts, _ = add_matching_results(
428+
samples=reconsidered_samples,
429+
ts=ts,
430+
date=date,
431+
num_mismatches=num_mismatches,
432+
max_hmm_cost=None,
433+
min_group_size=min_group_size,
389434
show_progress=show_progress,
390435
)
391436

@@ -439,7 +484,13 @@ def match_path_ts(samples, ts, path, reversions):
439484

440485

441486
def add_matching_results(
442-
samples, ts, date, num_mismatches, max_hmm_cost, show_progress=False
487+
samples,
488+
ts,
489+
date,
490+
num_mismatches,
491+
max_hmm_cost,
492+
min_group_size,
493+
show_progress=False,
443494
):
444495
if num_mismatches is None:
445496
# Note that this is the default assigned in match_tsinfer.
@@ -449,6 +500,9 @@ def add_matching_results(
449500
# By default, arbitraily high.
450501
max_hmm_cost = 1e6
451502

503+
if min_group_size is None:
504+
min_group_size = 1
505+
452506
# Group matches by path and set of immediate reversions.
453507
grouped_matches = collections.defaultdict(list)
454508
excluded_samples = []
@@ -478,6 +532,9 @@ def add_matching_results(
478532
disable=not show_progress,
479533
) as bar:
480534
for (path, reversions), match_samples in bar:
535+
if len(match_samples) < min_group_size:
536+
continue
537+
481538
# print(path, reversions, len(match_samples))
482539
# Delete the reversions from these samples so that we don't
483540
# build them into the trees
@@ -540,6 +597,29 @@ def add_matching_results(
540597
return ts, excluded_samples
541598

542599

600+
def fetch_samples_from_pickle_file(date, num_past_days=None, in_dir=None):
    """Load the samples that were excluded on each of the *num_past_days*
    days preceding *date*, reading one pickle file per day.

    Days with no pickle file on disk are silently skipped.

    :param date: day whose preceding days are scanned (a ``datetime.date``).
    :param num_past_days: how many days to look back; ``None`` (the default)
        is treated as 0, i.e. nothing is read.
    :param in_dir: directory containing the per-day pickle files; if ``None``,
        no lookup is performed and an empty list is returned.
    :return: concatenated list of Sample objects, oldest day first.
    """
    if in_dir is None:
        return []
    if num_past_days is None:
        num_past_days = 0
    file_suffix = ".excluded_samples.pickle"
    samples = []
    # Iterate oldest-to-newest so the returned list is in chronological order.
    for i in range(num_past_days, 0, -1):
        past_date = date - datetime.timedelta(days=i)
        # os.path.join is portable across platforms, unlike hard-coding "/".
        pickle_file = os.path.join(
            in_dir, past_date.strftime("%Y-%m-%d") + file_suffix
        )
        if os.path.exists(pickle_file):
            samples += parse_pickle_file(pickle_file)
    return samples
614+
615+
616+
def parse_pickle_file(pickle_file):
    """Unpickle *pickle_file* and return its contents (a list of Sample objects)."""
    # NOTE(review): pickle.load must only be used on trusted files; these are
    # files this pipeline wrote itself.
    with open(pickle_file, "rb") as pickled:
        return pickle.load(pickled)
621+
622+
543623
def solve_num_mismatches(ts, k):
544624
"""
545625
Return the low-level LS parameters corresponding to accepting

0 commit comments

Comments
 (0)