@@ -110,12 +110,7 @@ def create_mask_table(self, ts):
        # the rows in the DB that *are* in the ts, as a separate
        # transaction once we know that the trees have been saved to disk.
        logger.info("Loading used samples into DB")
-        # TODO this is inefficient - need some logging to see how much time
-        # we're spending here.
-        # One thing we can do is to store the list of strain IDs in the
-        # tree sequence top-level metadata, which we could even store using
-        # some numpy tricks to make it fast.
-        samples = [(ts.node(u).metadata["strain"],) for u in ts.samples()]
+        samples = [(strain,) for strain in ts.metadata["sc2ts"]["samples_strain"]]
        logger.debug(f"Got {len(samples)} from ts")
        with self.conn:
            self.conn.execute("DROP TABLE IF EXISTS used_samples")
@@ -224,11 +219,17 @@ def initial_ts():
    tables.reference_sequence.metadata_schema = tskit.MetadataSchema(base_schema)
    tables.reference_sequence.metadata = {
        "genbank_id": core.REFERENCE_GENBANK,
-        "notes": "X prepended to alignment to map from 1-based to 0-based coordinates"
+        "notes": "X prepended to alignment to map from 1-based to 0-based coordinates",
    }
    tables.reference_sequence.data = reference

    tables.metadata_schema = tskit.MetadataSchema(base_schema)
+    tables.metadata = {
+        "sc2ts": {
+            "date": core.REFERENCE_DATE,
+            "samples_strain": [core.REFERENCE_STRAIN],
+        }
+    }

    # TODO gene annotations to top level
    # TODO add known fields to the schemas and document them.
@@ -245,7 +246,9 @@ def initial_ts():
    # in later versions when we remove the dependence on tsinfer.
    tables.nodes.add_row(time=1, metadata={"strain": "Vestigial_ignore"})
    tables.nodes.add_row(
-        flags=tskit.NODE_IS_SAMPLE, time=0, metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE}
+        flags=tskit.NODE_IS_SAMPLE,
+        time=0,
+        metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE},
    )
    tables.edges.add_row(0, L, 0, 1)
    return tables.tree_sequence()
@@ -255,42 +258,8 @@ def parse_date(date):
    return datetime.datetime.fromisoformat(date)


-def filter_samples(samples, alignment_store, max_submission_delay=None):
-    if max_submission_delay is None:
-        max_submission_delay = 10**8  # Arbitrary large number of days.
-    not_in_store = 0
-    num_filtered = 0
-    ret = []
-    for sample in samples:
-        if sample.strain not in alignment_store:
-            logger.warn(f"{sample.strain} not in alignment store")
-            not_in_store += 1
-            continue
-        if sample.submission_delay < max_submission_delay:
-            ret.append(sample)
-        else:
-            num_filtered += 1
-    if not_in_store == len(samples):
-        raise ValueError("All samples for day missing")
-    logger.info(
-        f"Filtered {num_filtered} samples with "
-        f"max_submission_delay >= {max_submission_delay}"
-    )
-    return ret
-
-
def last_date(ts):
-    if ts.num_samples == 0:
-        # Special case for the initial ts which contains the
-        # reference but not as a sample
-        u = ts.num_nodes - 1
-        node = ts.node(u)
-        # assert node.time == 0
-        return parse_date(node.metadata["date"])
-    else:
-        samples = ts.samples()
-        samples_t0 = samples[ts.nodes_time[samples] == 0]
-        return max([parse_date(ts.node(u).metadata["date"]) for u in samples_t0])
+    return parse_date(ts.metadata["sc2ts"]["date"])


def increment_time(date, ts):
@@ -562,6 +531,14 @@ def match_samples(
    match_db.add(samples, date, num_mismatches)


+def check_base_ts(ts):
+    md = ts.metadata
+    assert "sc2ts" in md
+    sc2ts_md = md["sc2ts"]
+    assert "date" in sc2ts_md
+    assert len(sc2ts_md["samples_strain"]) == ts.num_samples
+
+
def extend(
    *,
    alignment_store,
@@ -579,9 +556,10 @@ def extend(
    precision=None,
    rng=None,
):
+    check_base_ts(base_ts)
    logger.info(
-        f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};"
-        f"mutations={base_ts.num_mutations}"
+        f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};"
+        f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}"
    )
    # TODO not sure whether we'll keep these params. Making sure they're not
    # used for now
@@ -640,7 +618,21 @@ def extend(
        min_group_size=min_group_size,
        show_progress=show_progress,
    )
-    return ts
+    return update_top_level_metadata(ts, date)
+
+
+def update_top_level_metadata(ts, date):
+    tables = ts.dump_tables()
+    md = tables.metadata
+    md["sc2ts"]["date"] = date
+    samples_strain = md["sc2ts"]["samples_strain"]
+    new_samples = ts.samples()[len(samples_strain) :]
+    for u in new_samples:
+        node = ts.node(u)
+        samples_strain.append(node.metadata["strain"])
+    md["sc2ts"]["samples_strain"] = samples_strain
+    tables.metadata = md
+    return tables.tree_sequence()


def match_path_ts(samples, ts, path, reversions):
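For context, a minimal sketch of how the new "sc2ts" top-level metadata written by update_top_level_metadata() might be read back downstream; the file name here is hypothetical, and tskit.load is the standard tskit loading entry point:

    # Hypothetical usage sketch (not part of this diff): inspect the
    # top-level metadata maintained by extend()/update_top_level_metadata().
    import tskit

    ts = tskit.load("sc2ts_base.trees")  # assumed filename
    md = ts.metadata["sc2ts"]
    print(md["date"])                 # date the ARG was last extended to
    print(len(md["samples_strain"]))  # one strain ID per sample node, in sample order
    assert len(md["samples_strain"]) == ts.num_samples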