@@ -222,6 +222,11 @@ def submission_date(self):
222
222
def submission_delay (self ):
223
223
return (self .submission_date - self .date ).days
224
224
225
+ def get_hmm_cost (self , num_mismatches ):
226
+ # Note that Recombinant objects have total_cost.
227
+ # This bit of code is sort of repeated.
228
+ return num_mismatches * (len (self .path ) - 1 ) + len (self .mutations )
229
+
225
230
def asdict (self ):
226
231
return {
227
232
"strain" : self .strain ,
@@ -247,6 +252,7 @@ def daily_extend(
247
252
metadata_db ,
248
253
base_ts ,
249
254
num_mismatches = None ,
255
+ max_hmm_cost = None ,
250
256
show_progress = False ,
251
257
max_submission_delay = None ,
252
258
max_daily_samples = None ,
@@ -263,6 +269,7 @@ def daily_extend(
263
269
date = date ,
264
270
base_ts = last_ts ,
265
271
num_mismatches = num_mismatches ,
272
+ max_hmm_cost = max_hmm_cost ,
266
273
show_progress = show_progress ,
267
274
max_submission_delay = max_submission_delay ,
268
275
max_daily_samples = max_daily_samples ,
@@ -340,14 +347,14 @@ def extend(
340
347
date ,
341
348
base_ts ,
342
349
num_mismatches = None ,
350
+ max_hmm_cost = None ,
343
351
show_progress = False ,
344
352
max_submission_delay = None ,
345
353
max_daily_samples = None ,
346
354
num_threads = None ,
347
355
precision = None ,
348
356
rng = None ,
349
357
):
350
-
351
358
date_samples = [Sample (md ) for md in metadata_db .get (date )]
352
359
samples = filter_samples (date_samples , alignment_store , max_submission_delay )
353
360
@@ -361,6 +368,7 @@ def extend(
361
368
362
369
logger .info (f"Got { len (samples )} samples" )
363
370
371
+ # Note num_mismatches is assigned a default value in match_tsinfer.
364
372
samples = match (
365
373
samples ,
366
374
alignment_store = alignment_store ,
@@ -371,7 +379,15 @@ def extend(
371
379
precision = precision ,
372
380
)
373
381
ts = increment_time (date , base_ts )
374
- return add_matching_results (samples , ts , date , show_progress )
382
+
383
+ return add_matching_results (
384
+ samples = samples ,
385
+ ts = ts ,
386
+ date = date ,
387
+ num_mismatches = num_mismatches ,
388
+ max_hmm_cost = max_hmm_cost ,
389
+ show_progress = show_progress ,
390
+ )
375
391
376
392
377
393
def match_path_ts (samples , ts , path , reversions ):
@@ -394,7 +410,7 @@ def match_path_ts(samples, ts, path, reversions):
394
410
"qc" : sample .alignment_qc ,
395
411
"path" : [x .asdict () for x in sample .path ],
396
412
"mutations" : [x .asdict () for x in sample .mutations ],
397
- }
413
+ },
398
414
}
399
415
node_id = tables .nodes .add_row (
400
416
flags = tskit .NODE_IS_SAMPLE , time = 0 , metadata = metadata
@@ -407,7 +423,7 @@ def match_path_ts(samples, ts, path, reversions):
407
423
408
424
# Now add the mutations
409
425
for node_id , sample in enumerate (samples , first_sample ):
410
- #metadata = {**sample.metadata, "sc2ts_qc": sample.alignment_qc}
426
+ # metadata = {**sample.metadata, "sc2ts_qc": sample.alignment_qc}
411
427
for mut in sample .mutations :
412
428
tables .mutations .add_row (
413
429
site = site_id_map [mut .site_id ],
@@ -420,7 +436,16 @@ def match_path_ts(samples, ts, path, reversions):
420
436
# print(tables)
421
437
422
438
423
- def add_matching_results (samples , ts , date , show_progress = False ):
439
+ def add_matching_results (
440
+ samples , ts , date , num_mismatches , max_hmm_cost , show_progress = False
441
+ ):
442
+ if num_mismatches is None :
443
+ # Note that this is the default assigned in match_tsinfer.
444
+ num_mismatches = 1e3
445
+
446
+ if max_hmm_cost is None :
447
+ # By default, arbitraily high.
448
+ max_hmm_cost = 1e6
424
449
425
450
# Group matches by path and set of reversion mutations
426
451
grouped_matches = collections .defaultdict (list )
@@ -435,6 +460,17 @@ def add_matching_results(samples, ts, date, show_progress=False):
435
460
)
436
461
grouped_matches [(path , reversions )].append (sample )
437
462
463
+ # Exclude single samples with "high-HMM cost" attachment paths.
464
+ tmp = {}
465
+ for k , v in grouped_matches .items ():
466
+ if len (v ) == 1 :
467
+ # Exclude sample if it's HMM cost exceeds a maximum.
468
+ sample = v [0 ]
469
+ if sample .get_hmm_cost (num_mismatches ) > max_hmm_cost :
470
+ continue
471
+ tmp [k ] = v
472
+ grouped_matches = tmp
473
+
438
474
tables = ts .dump_tables ()
439
475
logger .info (f"Got { len (grouped_matches )} distinct paths" )
440
476
@@ -981,6 +1017,7 @@ def match_tsinfer(
981
1017
show_progress = False ,
982
1018
mirror_coordinates = False ,
983
1019
):
1020
+ # TODO: Should this default be assigned elsewhere?
984
1021
if num_mismatches is None :
985
1022
# Default to no recombination
986
1023
num_mismatches = 1000
0 commit comments