55from operator import attrgetter
66from typing import Any , BinaryIO , Iterable , Optional , TYPE_CHECKING , Sequence , Literal
77
8+ from mavedb .models .mapped_variant import MappedVariant
89import numpy as np
910import pandas as pd
1011from pandas .testing import assert_index_equal
11- from sqlalchemy import Integer , cast , func , or_ , select
12+ from sqlalchemy import Integer , and_ , cast , func , or_ , select
1213from sqlalchemy .orm import Session , aliased , contains_eager , joinedload , selectinload
1314
1415from mavedb .lib .exceptions import ValidationError
1718 HGVS_NT_COLUMN ,
1819 HGVS_PRO_COLUMN ,
1920 HGVS_SPLICE_COLUMN ,
21+ REQUIRED_SCORE_COLUMN ,
2022 VARIANT_COUNT_DATA ,
2123 VARIANT_SCORE_DATA ,
2224)
5759
logger = logging.getLogger(__name__)

# Patterns used to classify an HGVS string by its coordinate-type prefix. Each one matches the prefix ("g." or "p.")
# only when it appears at the very start of the string or immediately after a ":" (e.g. following an accession).
HGVS_G_REGEX = re.compile(r"(^|:)g\.")
HGVS_P_REGEX = re.compile(r"(^|:)p\.")
64+
6065
6166class HGVSColumns :
6267 NUCLEOTIDE : str = "hgvs_nt" # dataset.constants.hgvs_nt_column
@@ -402,26 +407,90 @@ def get_score_set_variants_as_csv(
402407 start : Optional [int ] = None ,
403408 limit : Optional [int ] = None ,
404409 drop_na_columns : Optional [bool ] = None ,
410+ include_custom_columns : bool = True ,
411+ include_post_mapped_hgvs : bool = False ,
405412) -> str :
413+ """
414+ Get the variant data from a score set as a CSV string.
415+
416+ Parameters
417+ __________
418+ db : Session
419+ The database session to use.
420+ score_set : ScoreSet
421+ The score set to get the variants from.
422+ data_type : {'scores', 'counts'}
423+ The type of data to get. Either 'scores' or 'counts'.
424+ start : int, optional
425+ The index to start from. If None, starts from the beginning.
426+ limit : int, optional
427+ The maximum number of variants to return. If None, returns all variants.
428+ drop_na_columns : bool, optional
429+ Whether to drop columns that contain only NA values. Defaults to False.
430+ include_custom_columns : bool, optional
431+ Whether to include custom columns defined in the score set. Defaults to True.
432+ include_post_mapped_hgvs : bool, optional
433+ Whether to include post-mapped HGVS notations in the output. Defaults to False. If True, the output will include
434+ columns for both post-mapped HGVS genomic (g.) and protein (p.) notations.
435+
436+ Returns
437+ _______
438+ str
439+ The CSV string containing the variant data.
440+ """
406441 assert type (score_set .dataset_columns ) is dict
407- dataset_cols = "score_columns" if data_type == "scores" else "count_columns"
442+ custom_columns_set = "score_columns" if data_type == "scores" else "count_columns"
408443 type_column = "score_data" if data_type == "scores" else "count_data"
409444
410- count_columns = [str (x ) for x in list (score_set .dataset_columns .get (dataset_cols , []))]
411- columns = ["accession" , "hgvs_nt" , "hgvs_splice" , "hgvs_pro" ] + count_columns
412-
413- variants_query = (
414- select (Variant )
415- .where (Variant .score_set_id == score_set .id )
416- .order_by (cast (func .split_part (Variant .urn , "#" , 2 ), Integer ))
417- )
418- if start :
419- variants_query = variants_query .offset (start )
420- if limit :
421- variants_query = variants_query .limit (limit )
422- variants = db .scalars (variants_query ).all ()
445+ columns = ["accession" , "hgvs_nt" , "hgvs_splice" , "hgvs_pro" ]
446+ if include_post_mapped_hgvs :
447+ columns .append ("post_mapped_hgvs_g" )
448+ columns .append ("post_mapped_hgvs_p" )
449+
450+ if include_custom_columns :
451+ custom_columns = [str (x ) for x in list (score_set .dataset_columns .get (custom_columns_set , []))]
452+ columns += custom_columns
453+ elif data_type == "scores" :
454+ columns .append (REQUIRED_SCORE_COLUMN )
455+
456+ variants : Sequence [Variant ] = []
457+ mappings : Optional [list [Optional [MappedVariant ]]] = None
458+
459+ if include_post_mapped_hgvs :
460+ variants_and_mappings_query = (
461+ select (Variant , MappedVariant )
462+ .join (
463+ MappedVariant ,
464+ and_ (Variant .id == MappedVariant .variant_id , MappedVariant .current .is_ (True )),
465+ isouter = True ,
466+ )
467+ .where (Variant .score_set_id == score_set .id )
468+ .order_by (cast (func .split_part (Variant .urn , "#" , 2 ), Integer ))
469+ )
470+ if start :
471+ variants_and_mappings_query = variants_and_mappings_query .offset (start )
472+ if limit :
473+ variants_and_mappings_query = variants_and_mappings_query .limit (limit )
474+ variants_and_mappings = db .execute (variants_and_mappings_query ).all ()
475+
476+ variants = []
477+ mappings = []
478+ for variant , mapping in variants_and_mappings :
479+ variants .append (variant )
480+ mappings .append (mapping )
481+ else :
482+ variants_query = (
483+ select (Variant )
484+ .where (Variant .score_set_id == score_set .id )
485+ .order_by (cast (func .split_part (Variant .urn , "#" , 2 ), Integer ))
486+ )
487+ if start :
488+ variants_query = variants_query .offset (start )
489+ if limit :
490+ variants_query = variants_query .limit (limit )
491+ variants = db .scalars (variants_query ).all ()
423492
424- rows_data = variants_to_csv_rows (variants , columns = columns , dtype = type_column ) # type: ignore
493+ rows_data = variants_to_csv_rows (variants , columns = columns , dtype = type_column , mappings = mappings ) # type: ignore
425494 if drop_na_columns :
426495 rows_data , columns = drop_na_columns_from_csv_file_rows (rows_data , columns )
427496
@@ -462,7 +531,63 @@ def is_null(value):
462531 return null_values_re .fullmatch (value ) or not value
463532
464533
465- def variant_to_csv_row (variant : Variant , columns : list [str ], dtype : str , na_rep = "NA" ) -> dict [str , Any ]:
def hgvs_from_vrs_allele(allele: dict) -> str:
    """
    Extract the HGVS notation from a VRS allele.

    Only the VRS 2.X layout is supported, in which the allele carries a top-level ``expressions`` list. The ``value``
    of the first expression in that list is returned.

    Parameters
    ----------
    allele : dict
        The VRS allele to read the HGVS expression from.

    Returns
    -------
    str
        The ``value`` of the first entry in ``allele["expressions"]``.

    Raises
    ------
    ValueError
        If ``expressions`` (or the ``value`` key of its first entry) is missing. The original ``KeyError`` is attached
        as the cause.
    IndexError
        If ``expressions`` is present but empty; this case is not translated.
    """
    try:
        return allele["expressions"][0]["value"]
    except KeyError as err:
        # VRS 1.X nests the expressions under allele["variation"] instead. That layout is deliberately rejected
        # rather than read, so surface it as an error and keep the lookup failure as the explicit cause.
        raise ValueError("VRS 1.X format not supported.") from err
545+
546+
def get_hgvs_from_mapped_variant(post_mapped_vrs: Any) -> Optional[str]:
    """
    Return the single HGVS string described by a post-mapped VRS object, if there is exactly one.

    Parameters
    ----------
    post_mapped_vrs : Any
        The post-mapped VRS object. Its ``type`` entry selects how alleles are collected: ``"Haplotype"`` and
        ``"CisPhasedBlock"`` objects contribute every entry of ``members``, while an ``"Allele"`` contributes itself.

    Returns
    -------
    Optional[str]
        The HGVS string when exactly one allele was collected. ``None`` when the type is not one of the three above,
        or when zero or several alleles were collected.
    """
    vrs_type = post_mapped_vrs["type"]  # type: ignore

    if vrs_type in ("Haplotype", "CisPhasedBlock"):
        alleles = post_mapped_vrs["members"]
    elif vrs_type == "Allele":
        alleles = [post_mapped_vrs]
    else:
        return None

    # Every allele is converted before the count is checked, so a malformed member still raises even when the
    # overall result would have been None.
    hgvs_strings = [hgvs_from_vrs_allele(allele) for allele in alleles]

    # Zero or multiple variations are reported as "no HGVS" rather than raised as errors.
    return hgvs_strings[0] if len(hgvs_strings) == 1 else None
565+
566+
567+ # TODO (https://github.com/VariantEffect/mavedb-api/issues/440) Temporarily, we are using these functions to distinguish
568+ # genomic and protein HGVS strings produced by the mapper. Using hgvs.parser.Parser is too slow, and we won't need to do
569+ # this once the mapper extracts separate g., c., and p. post-mapped HGVS strings.
def is_hgvs_g(hgvs: str) -> bool:
    """
    Report whether an HGVS string is in genomic (g.) notation.

    The decision is made by searching the string with ``HGVS_G_REGEX``; no full HGVS parsing is performed.
    """
    match = HGVS_G_REGEX.search(hgvs)
    return match is not None
575+
576+
def is_hgvs_p(hgvs: str) -> bool:
    """
    Report whether an HGVS string is in protein (p.) notation.

    The decision is made by searching the string with ``HGVS_P_REGEX``; no full HGVS parsing is performed.
    """
    match = HGVS_P_REGEX.search(hgvs)
    return match is not None
582+
583+
584+ def variant_to_csv_row (
585+ variant : Variant ,
586+ columns : list [str ],
587+ dtype : str ,
588+ mapping : Optional [MappedVariant ] = None ,
589+ na_rep = "NA" ,
590+ ) -> dict [str , Any ]:
466591 """
467592 Format a variant into a containing the keys specified in `columns`.
468593
@@ -491,6 +616,18 @@ def variant_to_csv_row(variant: Variant, columns: list[str], dtype: str, na_rep=
491616 value = str (variant .hgvs_splice )
492617 elif column_key == "accession" :
493618 value = str (variant .urn )
619+ elif column_key == "post_mapped_hgvs_g" :
620+ hgvs_str = get_hgvs_from_mapped_variant (mapping .post_mapped ) if mapping and mapping .post_mapped else None
621+ if hgvs_str is not None and is_hgvs_g (hgvs_str ):
622+ value = hgvs_str
623+ else :
624+ value = ""
625+ elif column_key == "post_mapped_hgvs_p" :
626+ hgvs_str = get_hgvs_from_mapped_variant (mapping .post_mapped ) if mapping and mapping .post_mapped else None
627+ if hgvs_str is not None and is_hgvs_p (hgvs_str ):
628+ value = hgvs_str
629+ else :
630+ value = ""
494631 else :
495632 parent = variant .data .get (dtype ) if variant .data else None
496633 value = str (parent .get (column_key )) if parent else na_rep
@@ -502,7 +639,11 @@ def variant_to_csv_row(variant: Variant, columns: list[str], dtype: str, na_rep=
502639
503640
504641def variants_to_csv_rows (
505- variants : Sequence [Variant ], columns : list [str ], dtype : str , na_rep = "NA"
642+ variants : Sequence [Variant ],
643+ columns : list [str ],
644+ dtype : str ,
645+ mappings : Optional [Sequence [Optional [MappedVariant ]]] = None ,
646+ na_rep = "NA" ,
506647) -> Iterable [dict [str , Any ]]:
507648 """
508649 Format each variant into a dictionary row containing the keys specified in `columns`.
@@ -522,7 +663,12 @@ def variants_to_csv_rows(
522663 -------
523664 list[dict[str, Any]]
524665 """
525- return map (lambda v : variant_to_csv_row (v , columns , dtype , na_rep ), variants )
666+ if mappings is not None :
667+ return map (
668+ lambda pair : variant_to_csv_row (pair [0 ], columns , dtype , mapping = pair [1 ], na_rep = na_rep ),
669+ zip (variants , mappings ),
670+ )
671+ return map (lambda v : variant_to_csv_row (v , columns , dtype , na_rep = na_rep ), variants )
526672
527673
528674def find_meta_analyses_for_score_sets (db : Session , urns : list [str ]) -> list [ScoreSet ]:
0 commit comments