Skip to content

Commit 86b8008

Browse files
Work in progress
1 parent 375e202 commit 86b8008

File tree

1 file changed

+35
-53
lines changed

1 file changed

+35
-53
lines changed

sc2ts/inference.py

Lines changed: 35 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -398,72 +398,54 @@ def match_samples(
398398
num_threads=None,
399399
):
400400
# Initial passes: compute the matches at low precision (p=0, then p=1).
401-
# precision = 0
402-
# match_tsinfer(
403-
# samples=samples,
404-
# ts=base_ts,
405-
# num_mismatches=num_mismatches,
406-
# precision=precision,
407-
# num_threads=num_threads,
408-
# show_progress=show_progress,
409-
# )
410-
411-
# cost_threshold = 1
412-
# rerun_batch = []
413-
# for sample in samples:
414-
# cost = sample.get_hmm_cost(num_mismatches)
415-
# logger.debug(
416-
# f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}"
417-
# )
418-
# if cost > cost_threshold:
419-
# sample.path.clear()
420-
# sample.mutations.clear()
421-
# rerun_batch.append(sample)
422-
423-
rerun_batch = samples
401+
run_batch = samples
402+
403+
# WIP
404+
for precision, cost_threshold in [(0, 0), (1, 1)]: # , (2, 2)]:
405+
logger.info(f"Running batch of {len(run_batch)} at p={precision}")
406+
match_tsinfer(
407+
samples=run_batch,
408+
ts=base_ts,
409+
num_mismatches=num_mismatches,
410+
precision=precision,
411+
num_threads=num_threads,
412+
show_progress=show_progress,
413+
)
414+
415+
exceeding_threshold = []
416+
for sample in run_batch:
417+
cost = sample.get_hmm_cost(num_mismatches)
418+
logger.debug(
419+
f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}"
420+
)
421+
if cost > cost_threshold:
422+
sample.path.clear()
423+
sample.mutations.clear()
424+
exceeding_threshold.append(sample)
425+
426+
num_matches_found = len(run_batch) - len(exceeding_threshold)
427+
logger.info(
428+
f"{num_matches_found} final matches for found p={precision}; "
429+
f"{len(exceeding_threshold)} remain"
430+
)
431+
run_batch = exceeding_threshold
432+
424433
precision = 6
425-
logger.info(f"Rerunning batch of {len(rerun_batch)} at p={precision}")
434+
logger.info(f"Running final batch of {len(run_batch)} at p={precision}")
426435
match_tsinfer(
427-
samples=rerun_batch,
436+
samples=run_batch,
428437
ts=base_ts,
429438
num_mismatches=num_mismatches,
430439
precision=precision,
431440
num_threads=num_threads,
432441
show_progress=show_progress,
433442
)
434-
for sample in rerun_batch:
443+
for sample in run_batch:
435444
hmm_cost = sample.get_hmm_cost(num_mismatches)
436445
# print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
437446
logger.debug(
438447
f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
439448
)
440-
441-
# remaining_samples = samples
442-
# for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]:
443-
# match_tsinfer(
444-
# samples=remaining_samples,
445-
# ts=base_ts,
446-
# num_mismatches=num_mismatches,
447-
# precision=precision,
448-
# num_threads=num_threads,
449-
# show_progress=show_progress,
450-
# mirror_coordinates=mirror_coordinates,
451-
# )
452-
# samples_to_rerun = []
453-
# for sample in remaining_samples:
454-
# hmm_cost = sample.get_hmm_cost(num_mismatches)
455-
# # print(f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}")
456-
# logger.debug(
457-
# f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
458-
# )
459-
# if hmm_cost > cost:
460-
# sample.path.clear()
461-
# sample.mutations.clear()
462-
# samples_to_rerun.append(sample)
463-
# remaining_samples = samples_to_rerun
464-
465-
# Return in sorted order so that results are deterministic
466-
# return sorted(samples, key=lambda s: s.strain)
467449
return samples
468450

469451

0 commit comments

Comments
 (0)