77from os .path import join , basename , isfile
88from collections import Counter
99import random
10- from typing import Optional , Sequence
10+ from typing import Optional , Sequence , Union
1111
1212import pandas as pd
1313from Bio .SeqIO .PdbIO import AtomIterator
@@ -204,7 +204,16 @@ def load_db_variants(db_fn: str, pdb_fn: str) -> set:
204204 return db
205205
206206
207- def gen_subvariants_vlist (seq , target_num , min_num_subs , max_num_subs , chars , seq_idxs , rng , pdb_fn , db_fn = None ):
207+ def gen_subvariants_vlist (seq : str ,
208+ target_num : int ,
209+ min_num_subs : int ,
210+ max_num_subs : int ,
211+ chars : Union [list [str ], tuple [str , ...]],
212+ seq_idxs : Sequence [int ],
213+ rng : np .random .Generator ,
214+ db_pdb_fn : str ,
215+ db_fn : Optional [str ] = None ):
216+
208217 # max_num_subs determines the maximum number of substitutions for the main variants
209218 # min_num_subs determines the minimum number of substitutions for subvariants
210219 # so for example, if min_num_subs is 2, then this function won't generate subvariants with 1 substitution
@@ -215,7 +224,7 @@ def gen_subvariants_vlist(seq, target_num, min_num_subs, max_num_subs, chars, se
215224 # then this will still return the ones that aren't in the DB.
216225 db = None
217226 if db_fn is not None :
218- db = load_db_variants (db_fn , pdb_fn )
227+ db = load_db_variants (db_fn , db_pdb_fn )
219228
220229 # using a set and a list to maintain the order
221230 # this is slower and uses 2x the memory, but the final variant list will be ordered
@@ -258,7 +267,7 @@ def gen_subvariants_vlist(seq, target_num, min_num_subs, max_num_subs, chars, se
258267
259268
260269def gen_subvariants_sample (db_fn : str ,
261- pdb_fn : str ,
270+ db_pdb_fn : str ,
262271 target_num : int ,
263272 min_num_subs : int ,
264273 max_num_subs : int ,
@@ -271,17 +280,31 @@ def gen_subvariants_sample(db_fn: str,
271280 """
272281
273282 # load all the variants for the given pdb_fn from the database
274- variants = list (load_db_variants (db_fn , pdb_fn ))
275- df = pd .DataFrame ({"variant" : variants , "num_mutations" : [len (v .split ("," )) for v in variants ]})
283+ db_variants_set = load_db_variants (db_fn , db_pdb_fn )
284+ db_variants = list (db_variants_set )
285+
286+ df = pd .DataFrame ({"variant" : db_variants , "num_mutations" : [len (v .split ("," )) for v in db_variants ]})
276287
277288 # iteratively sample variants with max_num_subs
278289 df_max_subs = df [df ["num_mutations" ] == max_num_subs ]
279290
291+ # using a set and a list to maintain the order
280292 variants_set = set ()
281293 variants_list = []
294+
295+ # create the list of max_subs variants to sample from, basically just shuffle df_max_subs
296+ df_max_subs = df_max_subs .sample (frac = 1 , random_state = rng .bit_generator ).reset_index (drop = True )
297+ main_v_index = 0
298+
282299 while len (variants_list ) < target_num :
283- # sample a random variant from df_max_subs dataframe using rng
284- main_v = df_max_subs .sample (n = 1 , random_state = rng )[0 ]["variant" ]
300+
301+ # sample the next variant from df_max_subs
302+ if main_v_index > len (df_max_subs ) - 1 :
303+ # there aren't enough df_max_subs variants to sample from to put together the target number of variants
304+ raise ValueError ("Not enough {}-variants to sample from" .format (max_num_subs ))
305+
306+ main_v = df_max_subs .iloc [main_v_index ]["variant" ]
307+ main_v_index += 1
285308
286309 # generate all subvariants for this variant
287310 av = [main_v ]
@@ -299,7 +322,7 @@ def gen_subvariants_sample(db_fn: str,
299322 print ("Generated variant already in set: {}" .format (v ))
300323
301324 variant_in_db = True
302- if v not in df [ "variant" ] :
325+ if v not in db_variants_set :
303326 variant_in_db = False
304327 print ("Generated subvariant does NOT exist in database, skipping: {}" .format (v ))
305328
@@ -352,6 +375,7 @@ def gen_all_main(pdb_fn: str,
352375 out_dir : Optional [str ],
353376 db_fn : Optional [str ] = None ,
354377 db_mode : Optional [str ] = None ,
378+ db_pdb_fn : Optional [str ] = None ,
355379 ignore_existing_out_file : bool = False ):
356380 """
357381 Generate all variants for a single PDB file
@@ -366,16 +390,27 @@ def gen_all_main(pdb_fn: str,
366390 if db_mode not in [None , "filter" , "sample" ]:
367391 raise ValueError ("db_mode must be None, 'filter' or 'sample'" )
368392
393+ # if db_pdb_fn is None, set it equal to pdb_fn
394+ # note db_pdb_fn will only be used if db_mode is "filter" or "sample"
395+ if db_pdb_fn is None :
396+ db_pdb_fn = pdb_fn
397+
369398 # if db_fn is specified, we need to have a hash of the database in the filename
370399 db_hash = hash_db (db_fn )
371400
372401 # determine the output filename
373402 if db_mode == "sample" :
374403 # only sampling variants from the given database
375- out_fn = "{}_all_NS-{}_sampled-DB-{}.txt" .format (basename (pdb_fn )[:- 4 ], "," .join (map (str , num_subs_list )), db_hash )
404+ out_fn = "{}_all_NS-{}_sampled-DB-{}-{}.txt" .format (basename (pdb_fn )[:- 4 ],
405+ "," .join (map (str , num_subs_list )),
406+ db_hash ,
407+ basename (db_pdb_fn )[:- 4 ])
376408 elif db_mode == "filter" :
377409 # excluding variants that are in the database
378- out_fn = "{}_all_NS-{}_filtered-DB-{}.txt" .format (basename (pdb_fn )[:- 4 ], "," .join (map (str , num_subs_list )), db_hash )
410+ out_fn = "{}_all_NS-{}_filtered-DB-{}-{}.txt" .format (basename (pdb_fn )[:- 4 ],
411+ "," .join (map (str , num_subs_list )),
412+ db_hash ,
413+ basename (db_pdb_fn )[:- 4 ])
379414 else :
380415 # no database specified, just generate all variants
381416 out_fn = "{}_all_NS-{}.txt" .format (basename (pdb_fn )[:- 4 ], "," .join (map (str , num_subs_list )))
@@ -404,13 +439,13 @@ def gen_all_main(pdb_fn: str,
404439
405440 if db_mode == "sample" :
406441 # database sample mode, only include variants that are in the database
407- db_variants = load_db_variants (db_fn , pdb_fn )
442+ db_variants = load_db_variants (db_fn , db_pdb_fn )
408443 variants = [v for v in variants if v in db_variants ]
409444
410445 # database filter mode, exclude any variants that are in the database
411446 elif db_mode == "filter" :
412447 # filter out variants already in the database if one is provided
413- db_variants = load_db_variants (db_fn , pdb_fn )
448+ db_variants = load_db_variants (db_fn , db_pdb_fn )
414449 variants = [v for v in variants if v not in db_variants ]
415450
416451 print_variant_info (variants )
@@ -444,26 +479,34 @@ def gen_subvariants_main(pdb_fn: str,
444479 seed : int ,
445480 out_dir : str ,
446481 db_fn : Optional [str ] = None ,
447- db_mode : Optional [str ] = None ):
482+ db_mode : Optional [str ] = None ,
483+ db_pdb_fn : Optional [str ] = None ):
448484
449485 if (db_mode is None ) ^ (db_fn is None ):
450486 raise ValueError ("Both db_fn and db_mode should be specified or left as None" )
451487
452488 if db_mode is not None and db_mode not in ["filter" , "sample" ]:
453489 raise ValueError ("db_mode must be None, 'filter' or 'sample'" )
454490
491+ # db_pdb_fn is used to query the database for database modes 'filter' and 'sample'
492+ # if None, then use the same PDB file for which we are generating variants
493+ if db_pdb_fn is None :
494+ db_pdb_fn = pdb_fn
495+
455496 # if db_fn is specified, we need to have a hash of the database in the filename
456497 db_hash = hash_db (db_fn )
457498
458499 # determine the output filename
459- out_fn_template = "{}_subvariants_TN-{}_MAXS-{}_MINS-{}_{}-DB-{}_RS-{}.txt"
500+ # todo: hard for the filename can't communicate all the provenance...maybe have additional metadata file?
501+ out_fn_template = "{}_subvariants_TN-{}_MAXS-{}_MINS-{}_{}-DB-{}-{}_RS-{}.txt"
460502 out_fn_template_args = [
461503 basename (pdb_fn ).rsplit ('.' , 1 )[0 ],
462504 human_format (target_num ),
463505 max_num_subs ,
464506 min_num_subs ,
465507 "sampled" if db_mode == "sample" else "filtered" ,
466508 db_hash ,
509+ basename (db_pdb_fn ).rsplit ('.' , 1 )[0 ],
467510 seed
468511 ]
469512 out_fn = out_fn_template .format (* out_fn_template_args )
@@ -481,10 +524,10 @@ def gen_subvariants_main(pdb_fn: str,
481524 # just a type hint because if db_mode is "sample" then the error checking ensures db_fn is str
482525 db_fn : str
483526 # sampling needs a special function that selects the main variant from the database
484- variants = gen_subvariants_sample (db_fn , pdb_fn , target_num , min_num_subs , max_num_subs , rng )
527+ variants = gen_subvariants_sample (db_fn , db_pdb_fn , target_num , min_num_subs , max_num_subs , rng )
485528 elif db_mode == "filter" or db_mode is None :
486529 # this can handle db_fn being None or db_mode being "filter"
487- variants = gen_subvariants_vlist (seq , target_num , min_num_subs , max_num_subs , chars , seq_idxs , rng , pdb_fn , db_fn )
530+ variants = gen_subvariants_vlist (seq , target_num , min_num_subs , max_num_subs , chars , seq_idxs , rng , db_pdb_fn , db_fn )
488531 else :
489532 raise ValueError ("db_mode must be None, 'filter' or 'sample'" )
490533
@@ -546,7 +589,8 @@ def main(args):
546589 seed = seed ,
547590 out_dir = args .out_dir ,
548591 db_fn = args .db_fn ,
549- db_mode = args .db_mode )
592+ db_mode = args .db_mode ,
593+ db_pdb_fn = args .db_pdb_fn )
550594
551595 elif args .method == "random" :
552596 gen_random_main (pdb_fn , seq , seq_idxs , chars ,
@@ -561,6 +605,7 @@ def main(args):
561605 out_dir = args .out_dir ,
562606 db_fn = args .db_fn ,
563607 db_mode = args .db_mode ,
608+ db_pdb_fn = args .db_pdb_fn ,
564609 ignore_existing_out_file = args .ignore_existing_out_file )
565610
566611
@@ -613,6 +658,11 @@ def main(args):
613658 "if 'sample', only include variants that are in the given database" ,
614659 default = None ,
615660 choices = ["filter" , "sample" ])
661+ parser .add_argument ("--db_pdb_fn" ,
662+ type = str ,
663+ help = "the PDB file to use for the database. if None, use the same PDB file as the one "
664+ "being used to generate variants" ,
665+ default = None )
616666 parser .add_argument ("--ignore_existing_out_file" ,
617667 action = "store_true" ,
618668 default = False ,
0 commit comments