@@ -4,6 +4,8 @@
 import dataclasses
 import collections
 import io
+import pickle
+import os
 
 import tqdm
 import tskit
@@ -110,11 +112,13 @@ def last_date(ts):
         # Special case for the initial ts which contains the
         # reference but not as a sample
         u = ts.num_nodes - 1
+        node = ts.node(u)
+        assert node.time == 0
+        return parse_date(node.metadata["date"])
     else:
-        u = ts.samples()[-1]
-        node = ts.node(u)
-        assert node.time == 0
-        return parse_date(node.metadata["date"])
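+        # Several samples can sit at time zero, and the newest one need
+        # not be last in sample order, so take the maximum of their dates.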
+        samples = ts.samples()
+        samples_t0 = samples[ts.nodes_time[samples] == 0]
+        return max(parse_date(ts.node(u).metadata["date"]) for u in samples_t0)
 
 
 def increment_time(date, ts):
@@ -253,14 +257,31 @@ def daily_extend(
     base_ts,
     num_mismatches=None,
     max_hmm_cost=None,
+    min_group_size=None,
+    num_past_days=None,
     show_progress=False,
     max_submission_delay=None,
     max_daily_samples=None,
     num_threads=None,
     precision=None,
     rng=None,
+    excluded_sample_dir=None,
 ):
     start_day = last_date(base_ts)
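+
+    # Queue of previously excluded samples to be reconsidered on later
+    # days. When resuming from a base tree sequence, reload the samples
+    # pickled over the past num_past_days days from excluded_sample_dir.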
+    reconsidered_samples = collections.deque()
+    earliest_date = start_day - datetime.timedelta(days=1)
+    if base_ts is not None:
+        next_day = start_day + datetime.timedelta(days=1)
+        reconsidered_samples.extend(
+            fetch_samples_from_pickle_file(
+                date=next_day,
+                num_past_days=num_past_days,
+                in_dir=excluded_sample_dir,
+            )
+        )
+        earliest_date = next_day - datetime.timedelta(days=num_past_days or 0)
+
     last_ts = base_ts
     for date in metadata_db.get_days(start_day):
         ts, excluded_samples = extend(
@@ -270,14 +291,25 @@ def daily_extend(
             base_ts=last_ts,
             num_mismatches=num_mismatches,
             max_hmm_cost=max_hmm_cost,
+            min_group_size=min_group_size,
             show_progress=show_progress,
             max_submission_delay=max_submission_delay,
             max_daily_samples=max_daily_samples,
             num_threads=num_threads,
             precision=precision,
             rng=rng,
+            reconsidered_samples=reconsidered_samples,
         )
         yield ts, excluded_samples, date
+
+        # Update the queue of reconsidered samples: drop those that have
+        # aged out of the window, then append today's excluded samples.
+        while (
+            reconsidered_samples
+            and reconsidered_samples[0].date == earliest_date
+        ):
+            reconsidered_samples.popleft()
+        reconsidered_samples.extend(excluded_samples)
+
+        earliest_date += datetime.timedelta(days=1)
+
         last_ts = ts
 
 
@@ -348,12 +380,14 @@ def extend(
     base_ts,
     num_mismatches=None,
     max_hmm_cost=None,
+    min_group_size=None,
     show_progress=False,
     max_submission_delay=None,
     max_daily_samples=None,
     num_threads=None,
     precision=None,
     rng=None,
+    reconsidered_samples=None,
 ):
     date_samples = [Sample(md) for md in metadata_db.get(date)]
     samples = filter_samples(date_samples, alignment_store, max_submission_delay)
@@ -386,6 +420,17 @@ def extend(
         date=date,
         num_mismatches=num_mismatches,
         max_hmm_cost=max_hmm_cost,
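+        # Today's samples are added with no minimum group size.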
+        min_group_size=None,
+        show_progress=show_progress,
+    )
+
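+    # Re-match the previously excluded samples. The HMM cost cutoff is
+    # left at its (arbitrarily high) default; instead, only groups of at
+    # least min_group_size identically matching samples are added.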
+    ts, _ = add_matching_results(
+        samples=reconsidered_samples,
+        ts=ts,
+        date=date,
+        num_mismatches=num_mismatches,
+        max_hmm_cost=None,
+        min_group_size=min_group_size,
         show_progress=show_progress,
     )
@@ -439,7 +484,13 @@ def match_path_ts(samples, ts, path, reversions):
 
 
 def add_matching_results(
-    samples, ts, date, num_mismatches, max_hmm_cost, show_progress=False
+    samples,
+    ts,
+    date,
+    num_mismatches,
+    max_hmm_cost,
+    min_group_size,
+    show_progress=False,
 ):
     if num_mismatches is None:
         # Note that this is the default assigned in match_tsinfer.
@@ -449,6 +500,9 @@ def add_matching_results(
         # By default, arbitrarily high.
         max_hmm_cost = 1e6
 
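+    # By default, groups of any size, including singletons, are added.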
+    if min_group_size is None:
+        min_group_size = 1
+
     # Group matches by path and set of immediate reversions.
     grouped_matches = collections.defaultdict(list)
     excluded_samples = []
@@ -478,6 +532,9 @@ def add_matching_results(
         disable=not show_progress,
     ) as bar:
         for (path, reversions), match_samples in bar:
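+            # Skip groups smaller than the required minimum size.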
+            if len(match_samples) < min_group_size:
+                continue
+
             # print(path, reversions, len(match_samples))
             # Delete the reversions from these samples so that we don't
             # build them into the trees
@@ -540,6 +597,29 @@ def add_matching_results(
     return ts, excluded_samples
 
 
+def fetch_samples_from_pickle_file(date, num_past_days=None, in_dir=None):
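+    """
+    Return the samples pickled over the num_past_days days before date,
+    read from files named YYYY-MM-DD.excluded_samples.pickle in in_dir.
+    """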
+    if in_dir is None:
+        return []
+    if num_past_days is None:
+        num_past_days = 0
+    file_suffix = ".excluded_samples.pickle"
+    samples = []
+    for i in range(num_past_days, 0, -1):
+        past_date = date - datetime.timedelta(days=i)
+        pickle_file = os.path.join(
+            in_dir, past_date.strftime("%Y-%m-%d") + file_suffix
+        )
+        if os.path.exists(pickle_file):
+            samples += parse_pickle_file(pickle_file)
+    return samples
+
+
+def parse_pickle_file(pickle_file):
+    """Return a list of Sample objects."""
+    with open(pickle_file, "rb") as f:
+        samples = pickle.load(f)
+    return samples
+
+
 def solve_num_mismatches(ts, k):
     """
     Return the low-level LS parameters corresponding to accepting