@@ -398,72 +398,54 @@ def match_samples(
398
398
num_threads = None ,
399
399
):
400
400
# First pass, compute the matches at precision=0.
401
- # precision = 0
402
- # match_tsinfer(
403
- # samples=samples,
404
- # ts=base_ts,
405
- # num_mismatches=num_mismatches,
406
- # precision=precision,
407
- # num_threads=num_threads,
408
- # show_progress=show_progress,
409
- # )
410
-
411
- # cost_threshold = 1
412
- # rerun_batch = []
413
- # for sample in samples:
414
- # cost = sample.get_hmm_cost(num_mismatches)
415
- # logger.debug(
416
- # f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}"
417
- # )
418
- # if cost > cost_threshold:
419
- # sample.path.clear()
420
- # sample.mutations.clear()
421
- # rerun_batch.append(sample)
422
-
423
- rerun_batch = samples
401
+ run_batch = samples
402
+
403
+ # WIP
404
+ for precision , cost_threshold in [(0 , 0 ), (1 , 1 )]: # , (2, 2)]:
405
+ logger .info (f"Running batch of { len (run_batch )} at p={ precision } " )
406
+ match_tsinfer (
407
+ samples = run_batch ,
408
+ ts = base_ts ,
409
+ num_mismatches = num_mismatches ,
410
+ precision = precision ,
411
+ num_threads = num_threads ,
412
+ show_progress = show_progress ,
413
+ )
414
+
415
+ exceeding_threshold = []
416
+ for sample in run_batch :
417
+ cost = sample .get_hmm_cost (num_mismatches )
418
+ logger .debug (
419
+ f"HMM@p={ precision } : { sample .strain } hmm_cost={ cost } path={ sample .path } "
420
+ )
421
+ if cost > cost_threshold :
422
+ sample .path .clear ()
423
+ sample .mutations .clear ()
424
+ exceeding_threshold .append (sample )
425
+
426
+ num_matches_found = len (run_batch ) - len (exceeding_threshold )
427
+ logger .info (
428
+ f"{ num_matches_found } final matches for found p={ precision } ; "
429
+ f"{ len (exceeding_threshold )} remain"
430
+ )
431
+ run_batch = exceeding_threshold
432
+
424
433
precision = 6
425
- logger .info (f"Rerunning batch of { len (rerun_batch )} at p={ precision } " )
434
+ logger .info (f"Running final batch of { len (run_batch )} at p={ precision } " )
426
435
match_tsinfer (
427
- samples = rerun_batch ,
436
+ samples = run_batch ,
428
437
ts = base_ts ,
429
438
num_mismatches = num_mismatches ,
430
439
precision = precision ,
431
440
num_threads = num_threads ,
432
441
show_progress = show_progress ,
433
442
)
434
- for sample in rerun_batch :
443
+ for sample in run_batch :
435
444
hmm_cost = sample .get_hmm_cost (num_mismatches )
436
445
# print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
437
446
logger .debug (
438
447
f"Final HMM pass:{ sample .strain } hmm_cost={ hmm_cost } path={ sample .path } "
439
448
)
440
-
441
- # remaining_samples = samples
442
- # for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]:
443
- # match_tsinfer(
444
- # samples=remaining_samples,
445
- # ts=base_ts,
446
- # num_mismatches=num_mismatches,
447
- # precision=precision,
448
- # num_threads=num_threads,
449
- # show_progress=show_progress,
450
- # mirror_coordinates=mirror_coordinates,
451
- # )
452
- # samples_to_rerun = []
453
- # for sample in remaining_samples:
454
- # hmm_cost = sample.get_hmm_cost(num_mismatches)
455
- # # print(f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}")
456
- # logger.debug(
457
- # f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
458
- # )
459
- # if hmm_cost > cost:
460
- # sample.path.clear()
461
- # sample.mutations.clear()
462
- # samples_to_rerun.append(sample)
463
- # remaining_samples = samples_to_rerun
464
-
465
- # Return in sorted order so that results are deterministic
466
- # return sorted(samples, key=lambda s: s.strain)
467
449
return samples
468
450
469
451
0 commit comments