@@ -392,22 +392,51 @@ def match_samples(
392
392
date ,
393
393
samples ,
394
394
* ,
395
- match_db ,
396
395
base_ts ,
397
396
num_mismatches = None ,
398
397
show_progress = False ,
399
398
num_threads = None ,
400
- precision = None ,
401
- mirror_coordinates = False ,
402
399
):
403
- if num_mismatches is None :
404
- # Default to no recombination
405
- num_mismatches = 1000
406
-
407
- # FIXME Something wrong here, we don't seem to get precisely the same
408
- # ARG for some reason. Need to track it down
409
- # Also: should only run the things at low precision that have that HMM cost.
410
- # Start out by setting everything to have 0 mutations and work up from there.
400
+ # First pass, compute the matches at precision=0.
401
+ # precision = 0
402
+ # match_tsinfer(
403
+ # samples=samples,
404
+ # ts=base_ts,
405
+ # num_mismatches=num_mismatches,
406
+ # precision=precision,
407
+ # num_threads=num_threads,
408
+ # show_progress=show_progress,
409
+ # )
410
+
411
+ # cost_threshold = 1
412
+ # rerun_batch = []
413
+ # for sample in samples:
414
+ # cost = sample.get_hmm_cost(num_mismatches)
415
+ # logger.debug(
416
+ # f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}"
417
+ # )
418
+ # if cost > cost_threshold:
419
+ # sample.path.clear()
420
+ # sample.mutations.clear()
421
+ # rerun_batch.append(sample)
422
+
423
+ rerun_batch = samples
424
+ precision = 12
425
+ logger .info (f"Rerunning batch of { len (rerun_batch )} at p={ precision } " )
426
+ match_tsinfer (
427
+ samples = rerun_batch ,
428
+ ts = base_ts ,
429
+ num_mismatches = num_mismatches ,
430
+ precision = 12 ,
431
+ num_threads = num_threads ,
432
+ show_progress = show_progress ,
433
+ )
434
+ # for sample in samples_to_rerun:
435
+ # hmm_cost = sample.get_hmm_cost(num_mismatches)
436
+ # # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
437
+ # logger.debug(
438
+ # f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
439
+ # )
411
440
412
441
# remaining_samples = samples
413
442
# for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]:
@@ -433,24 +462,8 @@ def match_samples(
433
462
# samples_to_rerun.append(sample)
434
463
# remaining_samples = samples_to_rerun
435
464
436
- samples_to_rerun = samples
437
- match_tsinfer (
438
- samples = samples_to_rerun ,
439
- ts = base_ts ,
440
- num_mismatches = num_mismatches ,
441
- precision = 12 ,
442
- num_threads = num_threads ,
443
- show_progress = show_progress ,
444
- mirror_coordinates = mirror_coordinates ,
445
- )
446
- for sample in samples_to_rerun :
447
- hmm_cost = sample .get_hmm_cost (num_mismatches )
448
- # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
449
- logger .debug (
450
- f"Final HMM pass:{ sample .strain } hmm_cost={ hmm_cost } path={ sample .path } "
451
- )
452
-
453
- match_db .add (samples , date , num_mismatches )
465
+ # Return in sorted order so that results are deterministic
466
+ return sorted (samples , key = lambda s : s .strain )
454
467
455
468
456
469
def check_base_ts (ts ):
@@ -526,7 +539,6 @@ def extend(
526
539
min_group_size = 10
527
540
528
541
# TMP
529
- precision = 6
530
542
check_base_ts (base_ts )
531
543
logger .info (
532
544
f"Extend { date } ; ts:nodes={ base_ts .num_nodes } ;samples={ base_ts .num_samples } ;"
@@ -549,17 +561,16 @@ def extend(
549
561
f"Got alignments for { len (samples )} of { len (metadata_matches )} in metadata"
550
562
)
551
563
552
- match_samples (
564
+ samples = match_samples (
553
565
date ,
554
566
samples ,
555
567
base_ts = base_ts ,
556
- match_db = match_db ,
557
568
num_mismatches = num_mismatches ,
558
569
show_progress = show_progress ,
559
570
num_threads = num_threads ,
560
- precision = precision ,
561
571
)
562
572
573
+ match_db .add (samples , date , num_mismatches )
563
574
match_db .create_mask_table (base_ts )
564
575
ts = increment_time (date , base_ts )
565
576
@@ -810,23 +821,21 @@ def solve_num_mismatches(ts, k):
810
821
NOTE! This is NOT taking into account the spatial distance along
811
822
the genome, and so is not a very good model in some ways.
812
823
"""
824
+ # We can match against any node in tsinfer
813
825
m = ts .num_sites
814
- n = ts .num_nodes # We can match against any node in tsinfer
815
- if k == 0 :
816
- # Pathological things happen when k=0
817
- r = 1e-3
818
- mu = 1e-20
819
- else :
820
- # NOTE: the magnitude of mu matters because it puts a limit
821
- # on how low we can push the HMM precision. We should be able to solve
822
- # for the optimal value of this parameter such that the magnitude of the
823
- # values within the HMM are as large as possible (so that we can truncate
824
- # usefully).
825
- mu = 1e-2
826
- denom = (1 - mu ) ** k + (n - 1 ) * mu ** k
827
- r = n * mu ** k / denom
828
- assert mu < 0.5
829
- assert r < 0.5
826
+ n = ts .num_nodes
827
+ # values of k <= 1 are not relevant for SC2 and lead to awkward corner cases
828
+ assert k > 1
829
+
830
+ # NOTE: the magnitude of mu matters because it puts a limit
831
+ # on how low we can push the HMM precision. We should be able to solve
832
+ # for the optimal value of this parameter such that the magnitude of the
833
+ # values within the HMM are as large as possible (so that we can truncate
834
+ # usefully).
835
+ # mu = 1e-2
836
+ mu = 0.125
837
+ denom = (1 - mu ) ** k + (n - 1 ) * mu ** k
838
+ r = n * mu ** k / denom
830
839
831
840
# Add a little bit of extra mass for recombination so that we deterministically
832
841
# chose to recombine over k mutations
0 commit comments