99
1010from mavedb .lib .acmg import find_or_create_acmg_classification
1111from mavedb .lib .identifiers import find_or_create_publication_identifier
12- from mavedb .lib .validation .constants .general import calibration_class_column_name , calibration_variant_column_name
12+ from mavedb .lib .types .score_calibrations import ClassificationDict
13+ from mavedb .lib .validation .constants .general import (
14+ calibration_class_column_name ,
15+ calibration_variant_column_name ,
16+ hgvs_nt_column ,
17+ hgvs_pro_column ,
18+ )
1319from mavedb .lib .validation .utilities import inf_or_float
1420from mavedb .models .enums .score_calibration_relation import ScoreCalibrationRelation
1521from mavedb .models .score_calibration import ScoreCalibration
@@ -27,7 +33,7 @@ def create_functional_classification(
2733 score_calibration .FunctionalClassificationCreate , score_calibration .FunctionalClassificationModify
2834 ],
2935 containing_calibration : ScoreCalibration ,
30- variant_classes : Optional [dict [ str , list [ str ]] ] = None ,
36+ variant_classes : Optional [ClassificationDict ] = None ,
3137) -> ScoreCalibrationFunctionalClassification :
3238 """
3339 Create a functional classification entity for score calibration.
@@ -42,7 +48,7 @@ def create_functional_classification(
4248 description, range bounds, inclusivity flags, and optional ACMG
4349 classification information.
4450 containing_calibration (ScoreCalibration): The ScoreCalibration instance.
45- variant_classes (Optional[dict[str, list[str]] ]): Optional dictionary mapping variant classes
51+ variant_classes (Optional[ClassificationDict ]): Optional dictionary mapping variant classes
4652 to their corresponding variant identifiers.
4753
4854 Returns:
@@ -92,7 +98,7 @@ async def _create_score_calibration(
9298 db : Session ,
9399 calibration_create : score_calibration .ScoreCalibrationCreate ,
94100 user : User ,
95- variant_classes : Optional [dict [ str , list [ str ]] ] = None ,
101+ variant_classes : Optional [ClassificationDict ] = None ,
96102 containing_score_set : Optional [ScoreSet ] = None ,
97103) -> ScoreCalibration :
98104 """
@@ -125,6 +131,10 @@ async def _create_score_calibration(
125131 optional lists of publication source identifiers grouped by relation type.
126132 user : User
127133 Authenticated user context; the user to be recorded for audit
134+ variant_classes (Optional[ClassificationDict]):
135+ Optional dictionary mapping variant classes to their corresponding variant identifiers.
136+ containing_score_set : Optional[ScoreSet]
137+ If provided, the ScoreSet instance to which the new calibration will belong.
128138
129139 Returns
130140 -------
@@ -201,7 +211,7 @@ async def create_score_calibration_in_score_set(
201211 db : Session ,
202212 calibration_create : score_calibration .ScoreCalibrationCreate ,
203213 user : User ,
204- variant_classes : Optional [dict [ str , list [ str ]] ] = None ,
214+ variant_classes : Optional [ClassificationDict ] = None ,
205215) -> ScoreCalibration :
206216 """
207217 Create a new score calibration and associate it with an existing score set.
@@ -217,7 +227,7 @@ async def create_score_calibration_in_score_set(
217227 object containing the fields required to create a score calibration. Must include
218228 a non-empty score_set_urn.
219229 user (User): Authenticated user information used for auditing
220- variant_classes (Optional[dict[str, list[str]] ]): Optional dictionary mapping variant classes
230+ variant_classes (Optional[ClassificationDict ]): Optional dictionary mapping variant classes
221231 to their corresponding variant identifiers.
222232
223233 Returns:
@@ -259,7 +269,7 @@ async def create_score_calibration(
259269 db : Session ,
260270 calibration_create : score_calibration .ScoreCalibrationCreate ,
261271 user : User ,
262- variant_classes : Optional [dict [ str , list [ str ]] ] = None ,
272+ variant_classes : Optional [ClassificationDict ] = None ,
263273) -> ScoreCalibration :
264274 """
265275 Asynchronously create and persist a new ScoreCalibration record.
@@ -277,7 +287,7 @@ async def create_score_calibration(
277287 score set identifiers).
278288 user : User
279289 Authenticated user context; the user to be recorded for audit
280- variant_classes (Optional[dict[str, list[str]] ]): Optional dictionary mapping variant classes
290+ variant_classes (Optional[ClassificationDict ]): Optional dictionary mapping variant classes
281291 to their corresponding variant identifiers.
282292
283293 Returns
@@ -323,7 +333,7 @@ async def modify_score_calibration(
323333 calibration : ScoreCalibration ,
324334 calibration_update : score_calibration .ScoreCalibrationModify ,
325335 user : User ,
326- variant_classes : Optional [dict [ str , list [ str ]] ] = None ,
336+ variant_classes : Optional [ClassificationDict ] = None ,
327337) -> ScoreCalibration :
328338 """
329339 Asynchronously modify an existing ScoreCalibration record and its related publication
@@ -360,7 +370,7 @@ async def modify_score_calibration(
360370 - Additional mutable calibration attributes.
361371 user : User
362372 Context for the authenticated user; the user to be recorded for audit.
363- variant_classes (Optional[dict[str, list[str]] ]): Optional dictionary mapping variant classes
373+ variant_classes (Optional[ClassificationDict ]): Optional dictionary mapping variant classes
364374 to their corresponding variant identifiers.
365375
366376 Returns
@@ -645,7 +655,7 @@ def delete_score_calibration(db: Session, calibration: ScoreCalibration) -> None
645655def variants_for_functional_classification (
646656 db : Session ,
647657 functional_classification : ScoreCalibrationFunctionalClassification ,
648- variant_classes : Optional [dict [ str , list [ str ]] ] = None ,
658+ variant_classes : Optional [ClassificationDict ] = None ,
649659 use_sql : bool = False ,
650660) -> list [Variant ]:
651661 """
@@ -664,7 +674,7 @@ def variants_for_functional_classification(
664674 Active SQLAlchemy session.
665675 functional_classification : ScoreCalibrationFunctionalClassification
666676 The ORM row defining the interval to test against.
667- variant_classes : Optional[dict[str, list[str]] ]
677+ variant_classes : Optional[ClassificationDict ]
668678 If provided, a dictionary mapping variant classes to their corresponding variant identifiers
669679 to use for classification rather than the range property of the functional_classification.
670680 use_sql : bool
@@ -688,15 +698,31 @@ def variants_for_functional_classification(
688698 """
689699 # Resolve score set id from attached calibration (relationship may be lazy)
690700 score_set_id = functional_classification .calibration .score_set_id # type: ignore[attr-defined]
701+
702+ if variant_classes and variant_classes ["indexed_by" ] not in [
703+ hgvs_nt_column ,
704+ hgvs_pro_column ,
705+ calibration_variant_column_name ,
706+ ]:
707+ raise ValueError (f"Unsupported index column `{ variant_classes ['indexed_by' ]} ` for variant classification." )
708+
691709 if use_sql :
692710 try :
693711 # Build score extraction expression: data['score_data']['score']::text::float
694712 score_expr = Variant .data ["score_data" ]["score" ].astext .cast (Float )
695713
696714 conditions = [Variant .score_set_id == score_set_id ]
697715 if variant_classes is not None and functional_classification .class_ is not None :
698- variant_urns = variant_classes .get (functional_classification .class_ , [])
699- conditions .append (Variant .urn .in_ (variant_urns ))
716+ index_element = variant_classes ["classifications" ].get (functional_classification .class_ , set ())
717+
718+ if variant_classes ["indexed_by" ] == hgvs_nt_column :
719+ conditions .append (Variant .hgvs_nt .in_ (index_element ))
720+ elif variant_classes ["indexed_by" ] == hgvs_pro_column :
721+ conditions .append (Variant .hgvs_pro .in_ (index_element ))
722+ elif variant_classes ["indexed_by" ] == calibration_variant_column_name :
723+ conditions .append (Variant .urn .in_ (index_element ))
724+ else : # pragma: no cover
725+ return []
700726
701727 elif functional_classification .range is not None and len (functional_classification .range ) == 2 :
702728 lower_raw , upper_raw = functional_classification .range
@@ -732,9 +758,19 @@ def variants_for_functional_classification(
732758 matches : list [Variant ] = []
733759 for v in variants :
734760 if variant_classes is not None and functional_classification .class_ is not None :
735- variant_urns = variant_classes .get (functional_classification .class_ , [])
736- if v .urn in variant_urns :
737- matches .append (v )
761+ index_element = variant_classes ["classifications" ].get (functional_classification .class_ , set ())
762+
763+ if variant_classes ["indexed_by" ] == hgvs_nt_column :
764+ if v .hgvs_nt in index_element :
765+ matches .append (v )
766+ elif variant_classes ["indexed_by" ] == hgvs_pro_column :
767+ if v .hgvs_pro in index_element :
768+ matches .append (v )
769+ elif variant_classes ["indexed_by" ] == calibration_variant_column_name :
770+ if v .urn in index_element :
771+ matches .append (v )
772+ else : # pragma: no cover
773+ continue
738774
739775 elif functional_classification .range is not None and len (functional_classification .range ) == 2 :
740776 try :
@@ -759,7 +795,8 @@ def variants_for_functional_classification(
759795
760796def variant_classification_df_to_dict (
761797 df : pd .DataFrame ,
762- ) -> dict [str , list [str ]]:
798+ index_column : str ,
799+ ) -> ClassificationDict :
763800 """
764801 Convert a DataFrame of variant classifications into a dictionary mapping
765802 functional class labels to lists of distinct variant URNs.
@@ -776,18 +813,19 @@ def variant_classification_df_to_dict(
776813
777814 Returns
778815 -------
779- dict[str, list[str]]
780- A dictionary where keys are functional class labels and values are lists
781- of distinct variant URNs belonging to each class.
816+ ClassificationDict
817+ A dictionary with two keys: 'indexed_by' indicating the index column name,
818+ and 'classifications' mapping each functional class label to a list of
819+ distinct variant URNs.
782820 """
783- classification_dict : dict [str , list [str ]] = {}
821+ classifications : dict [str , set [str ]] = {}
784822 for _ , row in df .iterrows ():
785- variant_urn = row [calibration_variant_column_name ]
823+ index_element = row [index_column ]
786824 functional_class = row [calibration_class_column_name ]
787825
788- if functional_class not in classification_dict :
789- classification_dict [functional_class ] = []
826+ if functional_class not in classifications :
827+ classifications [functional_class ] = set ()
790828
791- classification_dict [functional_class ].append ( variant_urn )
829+ classifications [functional_class ].add ( index_element )
792830
793- return {k : list ( set ( v )) for k , v in classification_dict . items () }
831+ return {"indexed_by" : index_column , "classifications" : classifications }
0 commit comments