import logging
import re
from operator import attrgetter
- from typing import Any, BinaryIO, Iterable, Optional, TYPE_CHECKING, Sequence, Literal
+ from typing import Any, BinaryIO, Iterable, List, Optional, TYPE_CHECKING, Sequence, Literal

from mavedb.models.mapped_variant import MappedVariant
import numpy as np
@@ -401,12 +401,12 @@ def find_publish_or_private_superseded_score_set_tail(
def get_score_set_variants_as_csv(
    db: Session,
    score_set: ScoreSet,
-    data_type: Literal["scores", "counts"],
+    data_types: List[Literal["scores", "counts", "clinVar", "gnomAD"]],
    start: Optional[int] = None,
    limit: Optional[int] = None,
    drop_na_columns: Optional[bool] = None,
-    include_custom_columns: bool = True,
-    include_post_mapped_hgvs: bool = False,
+    include_custom_columns: Optional[bool] = True,
+    include_post_mapped_hgvs: Optional[bool] = False,
) -> str:
    """
    Get the variant data from a score set as a CSV string.
@@ -417,8 +417,8 @@ def get_score_set_variants_as_csv(
        The database session to use.
    score_set : ScoreSet
        The score set to get the variants from.
-    data_type : {'scores', 'counts'}
-        The type of data to get. Either 'scores' or 'counts'.
+    data_types : List[Literal["scores", "counts", "clinVar", "gnomAD"]]
+        The data types to get. Any combination of 'scores', 'counts', 'clinVar', and 'gnomAD'.
    start : int, optional
        The index to start from. If None, starts from the beginning.
    limit : int, optional
@@ -437,18 +437,33 @@ def get_score_set_variants_as_csv(
        The CSV string containing the variant data.
    """
    assert type(score_set.dataset_columns) is dict
-    custom_columns_set = "score_columns" if data_type == "scores" else "count_columns"
-    type_column = "score_data" if data_type == "scores" else "count_data"
-
+    custom_columns = {
+        "scores": "score_columns",
+        "counts": "count_columns",
+    }
+    custom_columns_set = [custom_columns[dt] for dt in data_types if dt in custom_columns]
+    type_to_column = {
+        "scores": "score_data",
+        "counts": "count_data"
+    }
+    type_columns = [type_to_column[dt] for dt in data_types if dt in type_to_column]
    columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"]
    if include_post_mapped_hgvs:
        columns.append("post_mapped_hgvs_g")
        columns.append("post_mapped_hgvs_p")

    if include_custom_columns:
-        custom_columns = [str(x) for x in list(score_set.dataset_columns.get(custom_columns_set, []))]
-        columns += custom_columns
-    elif data_type == "scores":
+        for column in custom_columns_set:
+            dataset_columns = [str(x) for x in list(score_set.dataset_columns.get(column, []))]
+            if column == "score_columns":
+                for c in dataset_columns:
+                    prefixed = "scores." + c
+                    columns.append(prefixed)
+            elif column == "count_columns":
+                for c in dataset_columns:
+                    prefixed = "counts." + c
+                    columns.append(prefixed)
+    elif len(data_types) == 1 and data_types[0] == "scores":
        columns.append(REQUIRED_SCORE_COLUMN)

    variants: Sequence[Variant] = []
@@ -488,7 +503,35 @@ def get_score_set_variants_as_csv(
        variants_query = variants_query.limit(limit)
    variants = db.scalars(variants_query).all()

-    rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column, mappings=mappings)  # type: ignore
+    rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_columns, mappings=mappings)  # type: ignore
+
+    # TODO: handle the case where len(data_types) == 1, neither "scores" nor "counts" is in data_types,
+    # and include_post_mapped_hgvs is set, once clinVar and gnomAD data are supported
+    if len(data_types) > 1 and include_post_mapped_hgvs:
+        rename_map = {}
+        rename_map["post_mapped_hgvs_g"] = "mavedb.post_mapped_hgvs_g"
+        rename_map["post_mapped_hgvs_p"] = "mavedb.post_mapped_hgvs_p"
+
+        # Update column order list (preserve original order)
+        columns = [rename_map.get(col, col) for col in columns]
+
+        # Rename keys in each row
+        renamed_rows_data = []
+        for row in rows_data:
+            renamed_row = {rename_map.get(k, k): v for k, v in row.items()}
+            renamed_rows_data.append(renamed_row)
+
+        rows_data = renamed_rows_data
+    elif len(data_types) == 1:
+        prefix = f"{data_types[0]}."
+        columns = [col[len(prefix):] if col.startswith(prefix) else col for col in columns]
+
+        # Remove the same prefix from each row's keys
+        renamed_rows_data = []
+        for row in rows_data:
+            renamed_row = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in row.items()}
+            renamed_rows_data.append(renamed_row)
+        rows_data = renamed_rows_data
    if drop_na_columns:
        rows_data, columns = drop_na_columns_from_csv_file_rows(rows_data, columns)
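# A minimal, self-contained sketch (not part of this diff) of the column-naming rule the
# hunk above implements: custom columns are namespaced per data type, and the namespace is
# stripped again when only a single data type is exported. The helper name
# `build_csv_columns` and the argument names are assumptions made for illustration.
from typing import List

def build_csv_columns(data_types: List[str], score_cols: List[str], count_cols: List[str]) -> List[str]:
    # Fixed identifier columns always come first.
    columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"]
    # Custom columns are prefixed with their data type.
    if "scores" in data_types:
        columns += [f"scores.{c}" for c in score_cols]
    if "counts" in data_types:
        columns += [f"counts.{c}" for c in count_cols]
    # With a single data type the prefix is redundant, so it is removed again.
    if len(data_types) == 1:
        prefix = f"{data_types[0]}."
        columns = [c[len(prefix):] if c.startswith(prefix) else c for c in columns]
    return columns

# build_csv_columns(["scores"], ["score", "sd"], ["c_0"])
#   -> ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro", "score", "sd"]
# build_csv_columns(["scores", "counts"], ["score"], ["c_0"])
#   -> [..., "scores.score", "counts.c_0"]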
@@ -532,7 +575,7 @@ def is_null(value):
def variant_to_csv_row(
    variant: Variant,
    columns: list[str],
-    dtype: str,
+    dtype: list[str],
    mapping: Optional[MappedVariant] = None,
    na_rep="NA",
) -> dict[str, Any]:
@@ -546,7 +589,7 @@ def variant_to_csv_row(
    columns : list[str]
        Columns to serialize.
    dtype : str, {'scores', 'counts'}
-        The type of data requested. Either the 'score_data' or 'count_data'.
+        The types of data requested: ['score_data'], ['count_data'], or ['score_data', 'count_data'].
    na_rep : str
        String to represent null values.
@@ -577,8 +620,18 @@ def variant_to_csv_row(
            else:
                value = ""
        else:
-            parent = variant.data.get(dtype) if variant.data else None
-            value = str(parent.get(column_key)) if parent else na_rep
+            value = na_rep  # default when no requested data type provides this column
+            for dt in dtype:
+                parent = variant.data.get(dt) if variant.data else None
+                if column_key.startswith("scores."):
+                    inner_key = column_key.replace("scores.", "")
+                elif column_key.startswith("counts."):
+                    inner_key = column_key.replace("counts.", "")
+                else:
+                    # fallback for non-prefixed columns
+                    inner_key = column_key
+                if parent and inner_key in parent:
+                    value = str(parent[inner_key])
+                    break
        if is_null(value):
            value = na_rep
        row[column_key] = value
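# A hypothetical, simplified illustration (not part of this diff) of how a prefixed column
# key such as "counts.c_0" is looked up in the nested variant.data payload
# ({"score_data": {...}, "count_data": {...}}), mirroring the loop added above. The helper
# name `resolve_cell` and the sample payload are assumptions made for illustration.
def resolve_cell(data: dict, column_key: str, dtypes: list[str], na_rep: str = "NA") -> str:
    value = na_rep
    for dt in dtypes:  # dt is "score_data" or "count_data"
        parent = data.get(dt) or {}
        # Strip the "scores."/"counts." namespace to recover the stored key.
        inner_key = column_key.split(".", 1)[1] if "." in column_key else column_key
        if inner_key in parent:
            value = str(parent[inner_key])
            break
    return value

# resolve_cell({"score_data": {"score": 1.2}, "count_data": {"c_0": 10}},
#              "counts.c_0", ["score_data", "count_data"])  -> "10"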
@@ -589,7 +642,7 @@ def variant_to_csv_row(
def variants_to_csv_rows(
    variants: Sequence[Variant],
    columns: list[str],
-    dtype: str,
+    dtype: List[str],
    mappings: Optional[Sequence[Optional[MappedVariant]]] = None,
    na_rep="NA",
) -> Iterable[dict[str, Any]]:
@@ -602,8 +655,8 @@ def variants_to_csv_rows(
        List of variants.
    columns : list[str]
        Columns to serialize.
-    dtype : str, {'scores', 'counts'}
-        The type of data requested. Either the 'score_data' or 'count_data'.
+    dtype : list, {'scores', 'counts'}
+        The types of data requested: ['score_data'], ['count_data'], or ['score_data', 'count_data'].
    na_rep : str
        String to represent null values.
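# A possible call site under the new signature. This is an illustrative sketch, not part of
# the diff; `db` and `score_set` are assumed to come from the surrounding application code.
csv_text = get_score_set_variants_as_csv(
    db,
    score_set,
    data_types=["scores", "counts"],  # custom columns come back as "scores.*" / "counts.*"
    drop_na_columns=True,
    include_post_mapped_hgvs=True,  # emitted as "mavedb.post_mapped_hgvs_g"/"_p" when several types are requested
)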