Skip to content

Commit 1a1812d

Browse files
Merge pull request #235 from jeromekelleher/rescale-ls-params
Rescale ls params and adaptive precision
2 parents a3d3ba8 + d8d9101 commit 1a1812d

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

sc2ts/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ def setup_logging(verbosity, log_file=None):
7676
# to go to the log to see the traceback, but for production it may
7777
# be better to let daiquiri record the errors as well.
7878
daiquiri.setup(outputs=outputs, set_excepthook=False)
79-
# Only show stuff coming from sc2ts. Sometimes it's handy to look
80-
# at the tsinfer logs too, so we could add an option to set its
81-
# levels
79+
# Only show stuff coming from sc2ts and the relevant bits of tsinfer.
8280
logger = logging.getLogger("sc2ts")
8381
logger.setLevel(log_level)
82+
logger = logging.getLogger("tsinfer.inference")
83+
logger.setLevel(log_level)
8484

8585

8686
# TODO add options to list keys, dump specific alignments etc

sc2ts/inference.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def preprocess(
443443
logger.warn(f"Zero metadata matches for {date}")
444444
return []
445445

446-
if date.endswith("01-01"):
446+
if date.endswith("12-31"):
447447
logger.warning(f"Skipping {len(metadata_matches)} samples for {date}")
448448
return []
449449

@@ -506,11 +506,37 @@ def match_samples(
506506
samples=samples,
507507
ts=base_ts,
508508
num_mismatches=num_mismatches,
509-
precision=precision,
509+
precision=2,
510510
num_threads=num_threads,
511511
show_progress=show_progress,
512512
mirror_coordinates=mirror_coordinates,
513513
)
514+
samples_to_rerun = []
515+
for sample in samples:
516+
hmm_cost = sample.get_hmm_cost(num_mismatches)
517+
logger.debug(
518+
f"First sketch: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
519+
)
520+
if hmm_cost >= 2:
521+
sample.path.clear()
522+
sample.mutations.clear()
523+
samples_to_rerun.append(sample)
524+
525+
if len(samples_to_rerun) > 0:
526+
match_tsinfer(
527+
samples=samples_to_rerun,
528+
ts=base_ts,
529+
num_mismatches=num_mismatches,
530+
precision=precision,
531+
num_threads=num_threads,
532+
show_progress=show_progress,
533+
mirror_coordinates=mirror_coordinates,
534+
)
535+
for sample in samples_to_rerun:
536+
hmm_cost = sample.get_hmm_cost(num_mismatches)
537+
logger.debug(
538+
f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
539+
)
514540

515541
match_db.add(samples, date, num_mismatches)
516542

@@ -801,14 +827,20 @@ def solve_num_mismatches(ts, k):
801827
r = 1e-3
802828
mu = 1e-20
803829
else:
804-
mu = 1e-6
830+
# NOTE: the magnitude of mu matters because it puts a limit
831+
# on how low we can push the HMM precision. We should be able to solve
832+
# for the optimal value of this parameter such that the magnitude of the
833+
# values within the HMM are as large as possible (so that we can truncate
834+
# usefully).
835+
mu = 1e-3
805836
denom = (1 - mu) ** k + (n - 1) * mu**k
806837
r = n * mu**k / denom
807838
assert mu < 0.5
808839
assert r < 0.5
809840

810-
# Add a tiny bit of extra mass for recombination so that we deterministically
841+
# Add a little bit of extra mass for recombination so that we deterministically
811842
# choose to recombine over k mutations
843+
# NOTE: the magnitude of this value will depend also on mu, see above.
812844
r += r * 0.01
813845
ls_recomb = np.full(m - 1, r)
814846
ls_mismatch = np.full(m, mu)

0 commit comments

Comments
 (0)