Skip to content

Commit 968126a

Browse files
committed
feat: allow class-based calibration definitions from hgvs_nt and hgvs_pro
- Allow class-based calibration to be defined via hgvs strings - Introduced new test CSV files for calibration classes based on HGVS nucleotide, HGVS protein, and URN. - Enhanced test coverage for score calibration creation and updating, including scenarios for decoding errors and validation errors. - Refactored tests to utilize parameterization for different calibration class files. - Added validation checks for index column selection in calibration dataframes. - Improved error messages for missing or invalid calibration classes.
1 parent 126c591 commit 968126a

File tree

10 files changed

+1036
-234
lines changed

10 files changed

+1036
-234
lines changed

src/mavedb/lib/score_calibrations.py

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99

1010
from mavedb.lib.acmg import find_or_create_acmg_classification
1111
from mavedb.lib.identifiers import find_or_create_publication_identifier
12-
from mavedb.lib.validation.constants.general import calibration_class_column_name, calibration_variant_column_name
12+
from mavedb.lib.types.score_calibrations import ClassificationDict
13+
from mavedb.lib.validation.constants.general import (
14+
calibration_class_column_name,
15+
calibration_variant_column_name,
16+
hgvs_nt_column,
17+
hgvs_pro_column,
18+
)
1319
from mavedb.lib.validation.utilities import inf_or_float
1420
from mavedb.models.enums.score_calibration_relation import ScoreCalibrationRelation
1521
from mavedb.models.score_calibration import ScoreCalibration
@@ -27,7 +33,7 @@ def create_functional_classification(
2733
score_calibration.FunctionalClassificationCreate, score_calibration.FunctionalClassificationModify
2834
],
2935
containing_calibration: ScoreCalibration,
30-
variant_classes: Optional[dict[str, list[str]]] = None,
36+
variant_classes: Optional[ClassificationDict] = None,
3137
) -> ScoreCalibrationFunctionalClassification:
3238
"""
3339
Create a functional classification entity for score calibration.
@@ -42,7 +48,7 @@ def create_functional_classification(
4248
description, range bounds, inclusivity flags, and optional ACMG
4349
classification information.
4450
containing_calibration (ScoreCalibration): The ScoreCalibration instance.
45-
variant_classes (Optional[dict[str, list[str]]]): Optional dictionary mapping variant classes
51+
variant_classes (Optional[ClassificationDict]): Optional dictionary mapping variant classes
4652
to their corresponding variant identifiers.
4753
4854
Returns:
@@ -92,7 +98,7 @@ async def _create_score_calibration(
9298
db: Session,
9399
calibration_create: score_calibration.ScoreCalibrationCreate,
94100
user: User,
95-
variant_classes: Optional[dict[str, list[str]]] = None,
101+
variant_classes: Optional[ClassificationDict] = None,
96102
containing_score_set: Optional[ScoreSet] = None,
97103
) -> ScoreCalibration:
98104
"""
@@ -125,6 +131,10 @@ async def _create_score_calibration(
125131
optional lists of publication source identifiers grouped by relation type.
126132
user : User
127133
Authenticated user context; the user to be recorded for audit
134+
variant_classes (Optional[ClassificationDict]):
135+
Optional dictionary mapping variant classes to their corresponding variant identifiers.
136+
containing_score_set : Optional[ScoreSet]
137+
If provided, the ScoreSet instance to which the new calibration will belong.
128138
129139
Returns
130140
-------
@@ -201,7 +211,7 @@ async def create_score_calibration_in_score_set(
201211
db: Session,
202212
calibration_create: score_calibration.ScoreCalibrationCreate,
203213
user: User,
204-
variant_classes: Optional[dict[str, list[str]]] = None,
214+
variant_classes: Optional[ClassificationDict] = None,
205215
) -> ScoreCalibration:
206216
"""
207217
Create a new score calibration and associate it with an existing score set.
@@ -217,7 +227,7 @@ async def create_score_calibration_in_score_set(
217227
object containing the fields required to create a score calibration. Must include
218228
a non-empty score_set_urn.
219229
user (User): Authenticated user information used for auditing
220-
variant_classes (Optional[dict[str, list[str]]]): Optional dictionary mapping variant classes
230+
variant_classes (Optional[ClassificationDict]): Optional dictionary mapping variant classes
221231
to their corresponding variant identifiers.
222232
223233
Returns:
@@ -259,7 +269,7 @@ async def create_score_calibration(
259269
db: Session,
260270
calibration_create: score_calibration.ScoreCalibrationCreate,
261271
user: User,
262-
variant_classes: Optional[dict[str, list[str]]] = None,
272+
variant_classes: Optional[ClassificationDict] = None,
263273
) -> ScoreCalibration:
264274
"""
265275
Asynchronously create and persist a new ScoreCalibration record.
@@ -277,7 +287,7 @@ async def create_score_calibration(
277287
score set identifiers).
278288
user : User
279289
Authenticated user context; the user to be recorded for audit
280-
variant_classes (Optional[dict[str, list[str]]]): Optional dictionary mapping variant classes
290+
variant_classes (Optional[ClassificationDict]): Optional dictionary mapping variant classes
281291
to their corresponding variant identifiers.
282292
283293
Returns
@@ -323,7 +333,7 @@ async def modify_score_calibration(
323333
calibration: ScoreCalibration,
324334
calibration_update: score_calibration.ScoreCalibrationModify,
325335
user: User,
326-
variant_classes: Optional[dict[str, list[str]]] = None,
336+
variant_classes: Optional[ClassificationDict] = None,
327337
) -> ScoreCalibration:
328338
"""
329339
Asynchronously modify an existing ScoreCalibration record and its related publication
@@ -360,7 +370,7 @@ async def modify_score_calibration(
360370
- Additional mutable calibration attributes.
361371
user : User
362372
Context for the authenticated user; the user to be recorded for audit.
363-
variant_classes (Optional[dict[str, list[str]]]): Optional dictionary mapping variant classes
373+
variant_classes (Optional[ClassificationDict]): Optional dictionary mapping variant classes
364374
to their corresponding variant identifiers.
365375
366376
Returns
@@ -645,7 +655,7 @@ def delete_score_calibration(db: Session, calibration: ScoreCalibration) -> None
645655
def variants_for_functional_classification(
646656
db: Session,
647657
functional_classification: ScoreCalibrationFunctionalClassification,
648-
variant_classes: Optional[dict[str, list[str]]] = None,
658+
variant_classes: Optional[ClassificationDict] = None,
649659
use_sql: bool = False,
650660
) -> list[Variant]:
651661
"""
@@ -664,7 +674,7 @@ def variants_for_functional_classification(
664674
Active SQLAlchemy session.
665675
functional_classification : ScoreCalibrationFunctionalClassification
666676
The ORM row defining the interval to test against.
667-
variant_classes : Optional[dict[str, list[str]]]
677+
variant_classes : Optional[ClassificationDict]
668678
If provided, a dictionary mapping variant classes to their corresponding variant identifiers
669679
to use for classification rather than the range property of the functional_classification.
670680
use_sql : bool
@@ -688,15 +698,31 @@ def variants_for_functional_classification(
688698
"""
689699
# Resolve score set id from attached calibration (relationship may be lazy)
690700
score_set_id = functional_classification.calibration.score_set_id # type: ignore[attr-defined]
701+
702+
if variant_classes and variant_classes["indexed_by"] not in [
703+
hgvs_nt_column,
704+
hgvs_pro_column,
705+
calibration_variant_column_name,
706+
]:
707+
raise ValueError(f"Unsupported index column `{variant_classes['indexed_by']}` for variant classification.")
708+
691709
if use_sql:
692710
try:
693711
# Build score extraction expression: data['score_data']['score']::text::float
694712
score_expr = Variant.data["score_data"]["score"].astext.cast(Float)
695713

696714
conditions = [Variant.score_set_id == score_set_id]
697715
if variant_classes is not None and functional_classification.class_ is not None:
698-
variant_urns = variant_classes.get(functional_classification.class_, [])
699-
conditions.append(Variant.urn.in_(variant_urns))
716+
index_element = variant_classes["classifications"].get(functional_classification.class_, set())
717+
718+
if variant_classes["indexed_by"] == hgvs_nt_column:
719+
conditions.append(Variant.hgvs_nt.in_(index_element))
720+
elif variant_classes["indexed_by"] == hgvs_pro_column:
721+
conditions.append(Variant.hgvs_pro.in_(index_element))
722+
elif variant_classes["indexed_by"] == calibration_variant_column_name:
723+
conditions.append(Variant.urn.in_(index_element))
724+
else: # pragma: no cover
725+
return []
700726

701727
elif functional_classification.range is not None and len(functional_classification.range) == 2:
702728
lower_raw, upper_raw = functional_classification.range
@@ -732,9 +758,19 @@ def variants_for_functional_classification(
732758
matches: list[Variant] = []
733759
for v in variants:
734760
if variant_classes is not None and functional_classification.class_ is not None:
735-
variant_urns = variant_classes.get(functional_classification.class_, [])
736-
if v.urn in variant_urns:
737-
matches.append(v)
761+
index_element = variant_classes["classifications"].get(functional_classification.class_, set())
762+
763+
if variant_classes["indexed_by"] == hgvs_nt_column:
764+
if v.hgvs_nt in index_element:
765+
matches.append(v)
766+
elif variant_classes["indexed_by"] == hgvs_pro_column:
767+
if v.hgvs_pro in index_element:
768+
matches.append(v)
769+
elif variant_classes["indexed_by"] == calibration_variant_column_name:
770+
if v.urn in index_element:
771+
matches.append(v)
772+
else: # pragma: no cover
773+
continue
738774

739775
elif functional_classification.range is not None and len(functional_classification.range) == 2:
740776
try:
@@ -759,7 +795,8 @@ def variants_for_functional_classification(
759795

760796
def variant_classification_df_to_dict(
761797
df: pd.DataFrame,
762-
) -> dict[str, list[str]]:
798+
index_column: str,
799+
) -> ClassificationDict:
763800
"""
764801
Convert a DataFrame of variant classifications into a dictionary mapping
765802
functional class labels to lists of distinct variant URNs.
@@ -776,18 +813,19 @@ def variant_classification_df_to_dict(
776813
777814
Returns
778815
-------
779-
dict[str, list[str]]
780-
A dictionary where keys are functional class labels and values are lists
781-
of distinct variant URNs belonging to each class.
816+
ClassificationDict
817+
A dictionary with two keys: 'indexed_by' indicating the index column name,
818+
and 'classifications' mapping each functional class label to a list of
819+
distinct variant URNs.
782820
"""
783-
classification_dict: dict[str, list[str]] = {}
821+
classifications: dict[str, set[str]] = {}
784822
for _, row in df.iterrows():
785-
variant_urn = row[calibration_variant_column_name]
823+
index_element = row[index_column]
786824
functional_class = row[calibration_class_column_name]
787825

788-
if functional_class not in classification_dict:
789-
classification_dict[functional_class] = []
826+
if functional_class not in classifications:
827+
classifications[functional_class] = set()
790828

791-
classification_dict[functional_class].append(variant_urn)
829+
classifications[functional_class].add(index_element)
792830

793-
return {k: list(set(v)) for k, v in classification_dict.items()}
831+
return {"indexed_by": index_column, "classifications": classifications}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from typing import TypedDict
2+
3+
4+
class ClassificationDict(TypedDict):
5+
indexed_by: str
6+
classifications: dict[str, set[str]]

src/mavedb/lib/validation/dataframe/calibration.py

Lines changed: 79 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from mavedb.lib.validation.constants.general import (
66
calibration_class_column_name,
77
calibration_variant_column_name,
8+
hgvs_nt_column,
9+
hgvs_pro_column,
810
)
911
from mavedb.lib.validation.dataframe.column import validate_data_column, validate_variant_column
1012
from mavedb.lib.validation.dataframe.dataframe import standardize_dataframe, validate_no_null_rows
@@ -13,15 +15,20 @@
1315
from mavedb.models.variant import Variant
1416
from mavedb.view_models import score_calibration
1517

16-
STANDARD_CALIBRATION_COLUMNS = (calibration_variant_column_name, calibration_class_column_name)
18+
STANDARD_CALIBRATION_COLUMNS = (
19+
calibration_variant_column_name,
20+
calibration_class_column_name,
21+
hgvs_nt_column,
22+
hgvs_pro_column,
23+
)
1724

1825

1926
def validate_and_standardize_calibration_classes_dataframe(
2027
db: Session,
2128
score_set: ScoreSet,
2229
calibration: score_calibration.ScoreCalibrationCreate | score_calibration.ScoreCalibrationModify,
2330
classes_df: pd.DataFrame,
24-
) -> pd.DataFrame:
31+
) -> tuple[pd.DataFrame, str]:
2532
"""
2633
Validate and standardize a calibration classes dataframe for functional classification calibrations.
2734
@@ -59,22 +66,22 @@ def validate_and_standardize_calibration_classes_dataframe(
5966
validate_no_null_rows(standardized_classes_df)
6067

6168
column_mapping = {c.lower(): c for c in standardized_classes_df.columns}
62-
index_column = column_mapping[calibration_variant_column_name]
69+
index_column = choose_calibration_index_column(standardized_classes_df)
70+
71+
# Drop rows where the calibration class column is NA
72+
standardized_classes_df = standardized_classes_df.dropna(
73+
subset=[column_mapping[calibration_class_column_name]]
74+
).reset_index(drop=True)
6375

6476
for c in column_mapping:
65-
if c == calibration_variant_column_name:
77+
if c == index_column.lower():
6678
validate_variant_column(standardized_classes_df[c], column_mapping[c] == index_column)
67-
validate_calibration_variant_urns(db, score_set, standardized_classes_df[c])
79+
validate_index_existence_in_score_set(db, score_set, standardized_classes_df[c], index_column)
6880
elif c == calibration_class_column_name:
6981
validate_data_column(standardized_classes_df[c], force_numeric=False)
7082
validate_calibration_classes(calibration, standardized_classes_df[c])
7183

72-
# handle unexpected columns. These should have already been caught by
73-
# validate_calibration_df_column_names, but we include this for completeness.
74-
else: # pragma: no cover
75-
raise ValidationError(f"unexpected column in calibration classes file: '{c}'")
76-
77-
return standardized_classes_df
84+
return standardized_classes_df, index_column
7885

7986

8087
def validate_calibration_df_column_names(df: pd.DataFrame) -> None:
@@ -113,21 +120,20 @@ def validate_calibration_df_column_names(df: pd.DataFrame) -> None:
113120

114121
columns = [c.lower() for c in df.columns]
115122

116-
if calibration_variant_column_name not in columns:
117-
raise ValidationError(f"missing required column: '{calibration_variant_column_name}'")
118-
119123
if calibration_class_column_name not in columns:
120124
raise ValidationError(f"missing required column: '{calibration_class_column_name}'")
121125

122-
if set(STANDARD_CALIBRATION_COLUMNS) != set(columns):
126+
if set(columns).isdisjoint({hgvs_nt_column, hgvs_pro_column, calibration_variant_column_name}):
123127
raise ValidationError(
124-
f"unexpected column(s) in calibration classes file: {', '.join(sorted(set(columns) - set(STANDARD_CALIBRATION_COLUMNS)))}"
128+
f"at least one of {', '.join({hgvs_nt_column, hgvs_pro_column, calibration_variant_column_name})} must be present"
125129
)
126130

127131

128-
def validate_calibration_variant_urns(db: Session, score_set: ScoreSet, variant_urns: pd.Series) -> None:
132+
def validate_index_existence_in_score_set(
133+
db: Session, score_set: ScoreSet, index_column: pd.Series, index_column_name: str
134+
) -> None:
129135
"""
130-
Validate that all provided variant URNs exist in the given score set.
136+
Validate that all provided resources in the index column exist in the given score set.
131137
132138
Args:
133139
db (Session): Database session for querying variants.
@@ -140,19 +146,65 @@ def validate_calibration_variant_urns(db: Session, score_set: ScoreSet, variant_
140146
Returns:
141147
None: Function returns nothing if validation passes.
142148
"""
143-
existing_variant_urns = set(
144-
db.scalars(
145-
select(Variant.urn).where(Variant.score_set_id == score_set.id, Variant.urn.in_(variant_urns.tolist()))
146-
).all()
147-
)
148-
149-
missing_variant_urns = set(variant_urns.tolist()) - existing_variant_urns
150-
if missing_variant_urns:
149+
if index_column_name.lower() == calibration_variant_column_name:
150+
existing_resources = set(
151+
db.scalars(
152+
select(Variant.urn).where(Variant.score_set_id == score_set.id, Variant.urn.in_(index_column.tolist()))
153+
).all()
154+
)
155+
elif index_column_name.lower() == hgvs_nt_column:
156+
existing_resources = set(
157+
db.scalars(
158+
select(Variant.hgvs_nt).where(
159+
Variant.score_set_id == score_set.id, Variant.hgvs_nt.in_(index_column.tolist())
160+
)
161+
).all()
162+
)
163+
elif index_column_name.lower() == hgvs_pro_column:
164+
existing_resources = set(
165+
db.scalars(
166+
select(Variant.hgvs_pro).where(
167+
Variant.score_set_id == score_set.id, Variant.hgvs_pro.in_(index_column.tolist())
168+
)
169+
).all()
170+
)
171+
172+
missing_resources = set(index_column.tolist()) - existing_resources
173+
if missing_resources:
151174
raise ValidationError(
152-
f"The following variant URNs do not exist in the score set: {', '.join(sorted(missing_variant_urns))}"
175+
f"The following resources do not exist in the score set: {', '.join(sorted(missing_resources))}"
153176
)
154177

155178

179+
def choose_calibration_index_column(df: pd.DataFrame) -> str:
180+
"""
181+
Choose the appropriate index column for a calibration DataFrame.
182+
183+
This function selects the index column based on the presence of specific columns
184+
in the DataFrame. It prioritizes the calibration variant column, followed by
185+
HGVS notation columns.
186+
187+
Args:
188+
df (pd.DataFrame): The DataFrame from which to choose the index column.
189+
190+
Returns:
191+
str: The name of the chosen index column.
192+
193+
Raises:
194+
ValidationError: If no valid index column is found in the DataFrame.
195+
"""
196+
column_mapping = {c.lower(): c for c in df.columns if not df[c].isna().all()}
197+
198+
if calibration_variant_column_name in column_mapping:
199+
return column_mapping[calibration_variant_column_name]
200+
elif hgvs_nt_column in column_mapping:
201+
return column_mapping[hgvs_nt_column]
202+
elif hgvs_pro_column in column_mapping:
203+
return column_mapping[hgvs_pro_column]
204+
else:
205+
raise ValidationError("failed to find valid calibration index column")
206+
207+
156208
def validate_calibration_classes(
157209
calibration: score_calibration.ScoreCalibrationCreate | score_calibration.ScoreCalibrationModify, classes: pd.Series
158210
) -> None:

0 commit comments

Comments
 (0)