4
4
import datetime
5
5
import dataclasses
6
6
import collections
7
- import json
8
7
import pickle
9
8
import os
10
9
import sqlite3
@@ -77,7 +76,6 @@ def add(self, samples, date, num_mismatches):
77
76
data = []
78
77
hmm_cost = np .zeros (len (samples ))
79
78
for j , sample in enumerate (samples ):
80
- d = sample .asdict ()
81
79
assert sample .date == date
82
80
# FIXME we want to be more selective about what we're storing
83
81
# here, as we're including the alignment too.
@@ -110,12 +108,7 @@ def create_mask_table(self, ts):
110
108
# the rows in the DB that *are* in the ts, as a separate
111
109
# transaction once we know that the trees have been saved to disk.
112
110
logger .info ("Loading used samples into DB" )
113
- # TODO this is inefficient - need some logging to see how much time
114
- # we're spending here.
115
- # One thing we can do is to store the list of strain IDs in the
116
- # tree sequence top-level metadata, which we could even store using
117
- # some numpy tricks to make it fast.
118
- samples = [(ts .node (u ).metadata ["strain" ],) for u in ts .samples ()]
111
+ samples = [(strain ,) for strain in ts .metadata ["sc2ts" ]["samples_strain" ]]
119
112
logger .debug (f"Got { len (samples )} from ts" )
120
113
with self .conn :
121
114
self .conn .execute ("DROP TABLE IF EXISTS used_samples" )
@@ -212,26 +205,35 @@ def mirror_ts_coordinates(ts):
212
205
return tables .tree_sequence ()
213
206
214
207
215
- def initial_ts ():
208
+ def initial_ts (additional_problematic_sites = list () ):
216
209
reference = core .get_reference_sequence ()
217
210
L = core .REFERENCE_SEQUENCE_LENGTH
218
211
assert L == len (reference )
219
- problematic_sites = set (core .get_problematic_sites ())
212
+ problematic_sites = set (core .get_problematic_sites ()) | set ( additional_problematic_sites )
220
213
221
214
tables = tskit .TableCollection (L )
222
215
tables .time_units = core .TIME_UNITS
216
+
217
+ # TODO add known fields to the schemas and document them.
218
+
223
219
base_schema = tskit .MetadataSchema .permissive_json ().schema
224
220
tables .reference_sequence .metadata_schema = tskit .MetadataSchema (base_schema )
225
221
tables .reference_sequence .metadata = {
226
222
"genbank_id" : core .REFERENCE_GENBANK ,
227
- "notes" : "X prepended to alignment to map from 1-based to 0-based coordinates"
223
+ "notes" : "X prepended to alignment to map from 1-based to 0-based coordinates" ,
228
224
}
229
225
tables .reference_sequence .data = reference
230
226
231
227
tables .metadata_schema = tskit .MetadataSchema (base_schema )
232
-
233
228
# TODO gene annotations to top level
234
- # TODO add known fields to the schemas and document them.
229
+ tables .metadata = {
230
+ "sc2ts" : {
231
+ "date" : core .REFERENCE_DATE ,
232
+ "samples_strain" : [core .REFERENCE_STRAIN ],
233
+ "additional_problematic_sites" : additional_problematic_sites ,
234
+ }
235
+ }
236
+
235
237
tables .nodes .metadata_schema = tskit .MetadataSchema (base_schema )
236
238
tables .sites .metadata_schema = tskit .MetadataSchema (base_schema )
237
239
tables .mutations .metadata_schema = tskit .MetadataSchema (base_schema )
@@ -245,7 +247,9 @@ def initial_ts():
245
247
# in later versions when we remove the dependence on tsinfer.
246
248
tables .nodes .add_row (time = 1 , metadata = {"strain" : "Vestigial_ignore" })
247
249
tables .nodes .add_row (
248
- flags = tskit .NODE_IS_SAMPLE , time = 0 , metadata = {"strain" : core .REFERENCE_STRAIN , "date" : core .REFERENCE_DATE }
250
+ flags = tskit .NODE_IS_SAMPLE ,
251
+ time = 0 ,
252
+ metadata = {"strain" : core .REFERENCE_STRAIN , "date" : core .REFERENCE_DATE },
249
253
)
250
254
tables .edges .add_row (0 , L , 0 , 1 )
251
255
return tables .tree_sequence ()
@@ -255,42 +259,8 @@ def parse_date(date):
255
259
return datetime .datetime .fromisoformat (date )
256
260
257
261
258
- def filter_samples (samples , alignment_store , max_submission_delay = None ):
259
- if max_submission_delay is None :
260
- max_submission_delay = 10 ** 8 # Arbitrary large number of days.
261
- not_in_store = 0
262
- num_filtered = 0
263
- ret = []
264
- for sample in samples :
265
- if sample .strain not in alignment_store :
266
- logger .warn (f"{ sample .strain } not in alignment store" )
267
- not_in_store += 1
268
- continue
269
- if sample .submission_delay < max_submission_delay :
270
- ret .append (sample )
271
- else :
272
- num_filtered += 1
273
- if not_in_store == len (samples ):
274
- raise ValueError ("All samples for day missing" )
275
- logger .info (
276
- f"Filtered { num_filtered } samples with "
277
- f"max_submission_delay >= { max_submission_delay } "
278
- )
279
- return ret
280
-
281
-
282
262
def last_date(ts):
    """
    Return the date stored under the "sc2ts" entry of the top-level
    metadata of the specified tree sequence, as parsed by parse_date.
    """
    sc2ts_md = ts.metadata["sc2ts"]
    return parse_date(sc2ts_md["date"])
294
264
295
265
296
266
def increment_time (date , ts ):
@@ -343,7 +313,7 @@ def validate(ts, alignment_store, show_progress=False):
343
313
Check that all the samples in the specified tree sequence are correctly
344
314
representing the original alignments.
345
315
"""
346
- samples = ts .samples ()
316
+ samples = ts .samples ()[ 1 :]
347
317
chunk_size = 10 ** 3
348
318
offset = 0
349
319
num_chunks = ts .num_samples // chunk_size
@@ -562,6 +532,14 @@ def match_samples(
562
532
match_db .add (samples , date , num_mismatches )
563
533
564
534
535
def check_base_ts(ts):
    """
    Check that the specified tree sequence carries the top-level "sc2ts"
    metadata that the rest of this module reads.

    The top-level metadata must contain an "sc2ts" entry, which in turn
    must contain a "date" key and a "samples_strain" sequence with exactly
    one entry per sample in the tree sequence.

    Raises AssertionError if "sc2ts" or "date" is missing, or if the length
    of "samples_strain" differs from ts.num_samples. Raises KeyError if
    "samples_strain" itself is missing.
    """
    md = ts.metadata
    assert "sc2ts" in md, "Base ts has no top-level 'sc2ts' metadata"
    sc2ts_md = md["sc2ts"]
    assert "date" in sc2ts_md, "Base ts 'sc2ts' metadata has no 'date' entry"
    num_strains = len(sc2ts_md["samples_strain"])
    # The strain list is maintained one-per-sample; a length mismatch means
    # the metadata and the tree sequence have drifted apart.
    assert num_strains == ts.num_samples, (
        f"'samples_strain' has {num_strains} entries but the base ts has "
        f"{ts.num_samples} samples"
    )
541
+
542
+
565
543
def extend (
566
544
* ,
567
545
alignment_store ,
@@ -579,9 +557,10 @@ def extend(
579
557
precision = None ,
580
558
rng = None ,
581
559
):
560
+ check_base_ts (base_ts )
582
561
logger .info (
583
- f"Extend { date } ; ts:nodes={ base_ts .num_nodes } ;edges ={ base_ts .num_edges } ;"
584
- f"mutations={ base_ts .num_mutations } "
562
+ f"Extend { date } ; ts:nodes={ base_ts .num_nodes } ;samples ={ base_ts .num_samples } ;"
563
+ f"mutations={ base_ts .num_mutations } ;date= { base_ts . metadata [ 'sc2ts' ][ 'date' ] } "
585
564
)
586
565
# TODO not sure whether we'll keep these params. Making sure they're not
587
566
# used for now
@@ -640,7 +619,21 @@ def extend(
640
619
min_group_size = min_group_size ,
641
620
show_progress = show_progress ,
642
621
)
643
- return ts
622
+ return update_top_level_metadata (ts , date )
623
+
624
+
625
def update_top_level_metadata(ts, date):
    """
    Return a tree sequence built from the tables of the specified tree
    sequence, with the top-level "sc2ts" metadata updated so that "date"
    is the specified date and "samples_strain" lists a strain for every
    sample.

    The existing "samples_strain" list is taken to describe the first
    len(samples_strain) entries of ts.samples(); only the remaining samples
    are looked up, and their node metadata "strain" values are appended in
    the order returned by ts.samples().
    """
    tables = ts.dump_tables()
    md = tables.metadata
    sc2ts_md = md["sc2ts"]
    sc2ts_md["date"] = date
    samples_strain = sc2ts_md["samples_strain"]
    # Skip the samples whose strains are already recorded.
    new_samples = ts.samples()[len(samples_strain):]
    samples_strain.extend(ts.node(u).metadata["strain"] for u in new_samples)
    # Write the list back explicitly rather than relying on aliasing.
    sc2ts_md["samples_strain"] = samples_strain
    tables.metadata = md
    return tables.tree_sequence()
644
637
645
638
646
639
def match_path_ts (samples , ts , path , reversions ):
0 commit comments