diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml
index 571d21ee..5c674d96 100644
--- a/docker-compose-dev.yml
+++ b/docker-compose-dev.yml
@@ -48,6 +48,7 @@ services:
       - redis
 
   dcd-mapping:
+    build: ../dcd_mapping
     image: dcd-mapping:dev
     command: bash -c "uvicorn api.server_main:app --host 0.0.0.0 --port 8000 --reload"
     depends_on:
@@ -61,6 +62,7 @@ services:
       - mavedb-seqrepo-dev:/usr/local/share/seqrepo
 
   cdot-rest:
+    build: ../cdot_rest
    image: cdot-rest:dev
     command: bash -c "gunicorn cdot_rest.wsgi:application --bind 0.0.0.0:8000"
     env_file:
diff --git a/settings/.env.template b/settings/.env.template
index d2060ede..fbb5b861 100644
--- a/settings/.env.template
+++ b/settings/.env.template
@@ -67,8 +67,10 @@ DCD_MAPPING_URL=http://dcd-mapping:8000
 ####################################################################################################
 CDOT_URL=http://cdot-rest:8000
 
-REDIS_HOST=localhost
+REDIS_HOST=redis
+REDIS_IP=redis
 REDIS_PORT=6379
+REDIS_SSL=false
 
 ####################################################################################################
 # Environment variables for ClinGen
diff --git a/src/mavedb/lib/target_genes.py b/src/mavedb/lib/target_genes.py
index da114584..61f20653 100644
--- a/src/mavedb/lib/target_genes.py
+++ b/src/mavedb/lib/target_genes.py
@@ -1,19 +1,154 @@
 import logging
 from typing import Optional
 
-from sqlalchemy import func, or_
+from sqlalchemy import and_, func, or_
 from sqlalchemy.orm import Session
 
 from mavedb.lib.logging.context import logging_context, save_to_logging_context
 from mavedb.models.contributor import Contributor
 from mavedb.models.score_set import ScoreSet
+from mavedb.models.target_accession import TargetAccession
 from mavedb.models.target_gene import TargetGene
+from mavedb.models.target_sequence import TargetSequence
+from mavedb.models.taxonomy import Taxonomy
 from mavedb.models.user import User
 from mavedb.view_models.search import TextSearch
 
 logger = logging.getLogger(__name__)
 
 
+def find_or_create_target_gene_by_accession(
+    db: Session,
+    score_set_id: int,
+    tg: dict,
+    tg_accession: dict,
+) -> TargetGene:
+    """
+    Find or create a target gene for a score set by accession. If the existing target gene or related accession record is modified,
+    this function creates a new target gene so that its id can be used to determine if a score set has changed in a way
+    that requires the create variants job to be re-run.
+
+    : param db: Database session
+    : param score_set_id: ID of the score set to associate the target gene with
+    : param tg: Dictionary with target gene details (name, category, etc.)
+    : param tg_accession: Dictionary with target accession details (accession, assembly, gene, etc.)
+    : return: The found or newly created TargetGene instance
+    """
+    target_gene = None
+    logger.info(
+        msg=f"Searching for existing target gene by accession within score set {score_set_id}.",
+        extra=logging_context(),
+    )
+    if tg_accession is not None and tg_accession.get("accession"):
+        target_gene = (
+            db.query(TargetGene)
+            .filter(
+                and_(
+                    TargetGene.target_accession.has(
+                        and_(
+                            TargetAccession.accession == tg_accession["accession"],
+                            TargetAccession.assembly == tg_accession["assembly"],
+                            TargetAccession.gene == tg_accession["gene"],
+                            TargetAccession.is_base_editor == tg_accession.get("is_base_editor", False),
+                        )
+                    ),
+                    TargetGene.name == tg["name"],
+                    TargetGene.category == tg["category"],
+                    TargetGene.score_set_id == score_set_id,
+                )
+            )
+            .first()
+        )
+
+    if target_gene is None:
+        target_accession = TargetAccession(**tg_accession)
+        target_gene = TargetGene(
+            **tg,
+            score_set_id=score_set_id,
+            target_accession=target_accession,
+        )
+        db.add(target_gene)
+        db.commit()
+        db.refresh(target_gene)
+        logger.info(
+            msg=f"Created new target gene '{target_gene.name}' with ID {target_gene.id}.",
+            extra=logging_context(),
+        )
+    else:
+        logger.info(
+            msg=f"Found existing target gene '{target_gene.name}' with ID {target_gene.id}.",
+            extra=logging_context(),
+        )
+
+    return target_gene
+
+
+def find_or_create_target_gene_by_sequence(
+    db: Session,
+    score_set_id: int,
+    tg: dict,
+    tg_sequence: dict,
+) -> TargetGene:
+    """
+    Find or create a target gene for a score set by sequence. If the existing target gene or related sequence record is modified,
+    this function creates a new target gene so that its id can be used to determine if a score set has changed in a way
+    that requires the create variants job to be re-run.
+
+    : param db: Database session
+    : param score_set_id: ID of the score set to associate the target gene with
+    : param tg: Dictionary with target gene details (name, category, etc.)
+    : param tg_sequence: Dictionary with target sequence details (sequence, sequence_type, taxonomy, label, etc.)
+ : return: The found or newly created TargetGene instance + """ + target_gene = None + logger.info( + msg=f"Searching for existing target gene by sequence within score set {score_set_id}.", + extra=logging_context(), + ) + if tg_sequence is not None and tg_sequence.get("sequence"): + target_gene = ( + db.query(TargetGene) + .filter( + and_( + TargetGene.target_sequence.has( + and_( + TargetSequence.sequence == tg_sequence["sequence"], + TargetSequence.sequence_type == tg_sequence["sequence_type"], + TargetSequence.taxonomy.has(Taxonomy.id == tg_sequence["taxonomy"].id), + TargetSequence.label == tg_sequence["label"], + ) + ), + TargetGene.name == tg["name"], + TargetGene.category == tg["category"], + TargetGene.score_set_id == score_set_id, + ) + ) + .first() + ) + + if target_gene is None: + target_sequence = TargetSequence(**tg_sequence) + target_gene = TargetGene( + **tg, + score_set_id=score_set_id, + target_sequence=target_sequence, + ) + db.add(target_gene) + db.commit() + db.refresh(target_gene) + logger.info( + msg=f"Created new target gene '{target_gene.name}' with ID {target_gene.id}.", + extra=logging_context(), + ) + else: + logger.info( + msg=f"Found existing target gene '{target_gene.name}' with ID {target_gene.id}.", + extra=logging_context(), + ) + + return target_gene + + def search_target_genes( db: Session, owner_or_contributor: Optional[User], diff --git a/src/mavedb/lib/validation/dataframe/dataframe.py b/src/mavedb/lib/validation/dataframe/dataframe.py index b8bfb6d1..75a07db6 100644 --- a/src/mavedb/lib/validation/dataframe/dataframe.py +++ b/src/mavedb/lib/validation/dataframe/dataframe.py @@ -1,25 +1,26 @@ -from typing import Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional, Tuple import numpy as np import pandas as pd from mavedb.lib.exceptions import MixedTargetError from mavedb.lib.validation.constants.general import ( + guide_sequence_column, hgvs_nt_column, hgvs_pro_column, hgvs_splice_column, - guide_sequence_column, required_score_column, ) -from mavedb.lib.validation.exceptions import ValidationError -from mavedb.models.target_gene import TargetGene from mavedb.lib.validation.dataframe.column import validate_data_column from mavedb.lib.validation.dataframe.variant import ( - validate_hgvs_transgenic_column, - validate_hgvs_genomic_column, validate_guide_sequence_column, + validate_hgvs_genomic_column, validate_hgvs_prefix_combinations, + validate_hgvs_transgenic_column, ) +from mavedb.lib.validation.exceptions import ValidationError +from mavedb.models.target_gene import TargetGene +from mavedb.view_models.score_set_dataset_columns import DatasetColumnMetadata if TYPE_CHECKING: from cdot.hgvs.dataproviders import RESTDataProvider @@ -28,12 +29,28 @@ STANDARD_COLUMNS = (hgvs_nt_column, hgvs_splice_column, hgvs_pro_column, required_score_column, guide_sequence_column) +def clean_col_name(col: str) -> str: + col = col.strip() + # Only remove quotes if the column name is fully quoted + if (col.startswith('"') and col.endswith('"')) or (col.startswith("'") and col.endswith("'")): + col = col[1:-1] + + return col.strip() + + def validate_and_standardize_dataframe_pair( scores_df: pd.DataFrame, counts_df: Optional[pd.DataFrame], + score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]], + count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]], targets: list[TargetGene], hdp: Optional["RESTDataProvider"], -) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: +) -> Tuple[ + pd.DataFrame, + Optional[pd.DataFrame], + 
Optional[dict[str, DatasetColumnMetadata]], + Optional[dict[str, DatasetColumnMetadata]], +]: """ Perform validation and standardization on a pair of score and count dataframes. @@ -43,6 +60,10 @@ def validate_and_standardize_dataframe_pair( The scores dataframe counts_df : Optional[pandas.DataFrame] The counts dataframe, can be None if not present + score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] + The scores column metadata, can be None if not present + count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] + The counts column metadata, can be None if not present targets : str The target genes on which to validate dataframes hdp : RESTDataProvider @@ -50,8 +71,8 @@ def validate_and_standardize_dataframe_pair( Returns ------- - Tuple[pd.DataFrame, Optional[pd.DataFrame]] - The standardized score and count dataframes, or score and None if no count dataframe was provided + Tuple[pd.DataFrame, Optional[pd.DataFrame], Optional[dict[str, DatasetColumnMetadata]], Optional[dict[str, DatasetColumnMetadata]]] + The standardized score and count dataframes, plus score column metadata and counts column metadata dictionaries. Counts dataframe and column metadata dictionaries can be None if not provided. Raises ------ @@ -65,11 +86,32 @@ def validate_and_standardize_dataframe_pair( standardized_counts_df = standardize_dataframe(counts_df) if counts_df is not None else None validate_dataframe(standardized_scores_df, "scores", targets, hdp) + + if score_columns_metadata is not None: + standardized_score_columns_metadata = standardize_dict_keys(score_columns_metadata) + validate_df_column_metadata_match(standardized_scores_df, standardized_score_columns_metadata) + else: + standardized_score_columns_metadata = None + if standardized_counts_df is not None: validate_dataframe(standardized_counts_df, "counts", targets, hdp) validate_variant_columns_match(standardized_scores_df, standardized_counts_df) - - return standardized_scores_df, standardized_counts_df + if count_columns_metadata is not None: + standardized_count_columns_metadata = standardize_dict_keys(count_columns_metadata) + validate_df_column_metadata_match(standardized_counts_df, standardized_count_columns_metadata) + else: + standardized_count_columns_metadata = None + else: + if count_columns_metadata is not None and len(count_columns_metadata.keys()) > 0: + raise ValidationError("Counts column metadata provided without counts dataframe") + standardized_count_columns_metadata = None + + return ( + standardized_scores_df, + standardized_counts_df, + standardized_score_columns_metadata, + standardized_count_columns_metadata, + ) def validate_dataframe( @@ -163,6 +205,25 @@ def validate_dataframe( ) +def standardize_dict_keys(d: dict[str, Any]) -> dict[str, Any]: + """ + Standardize the keys of a dictionary by stripping leading and trailing whitespace + and removing any quoted strings from the keys. + + Parameters + ---------- + d : dict[str, DatasetColumnMetadata] + The dictionary to standardize + + Returns + ------- + dict[str, DatasetColumnMetadata] + The standardized dictionary + """ + + return {clean_col_name(k): v for k, v in d.items()} + + def standardize_dataframe(df: pd.DataFrame) -> pd.DataFrame: """Standardize a dataframe by sorting the columns and changing the standard column names to lowercase. Also strips leading and trailing whitespace from column names and removes any quoted strings from column names. 
@@ -186,15 +247,7 @@ def standardize_dataframe(df: pd.DataFrame) -> pd.DataFrame: The standardized dataframe """ - def clean_column(col: str) -> str: - col = col.strip() - # Only remove quotes if the column name is fully quoted - if (col.startswith('"') and col.endswith('"')) or (col.startswith("'") and col.endswith("'")): - col = col[1:-1] - - return col.strip() - - cleaned_columns = {c: clean_column(c) for c in df.columns} + cleaned_columns = {c: clean_col_name(c) for c in df.columns} df.rename(columns=cleaned_columns, inplace=True) column_mapper = {x: x.lower() for x in df.columns if x.lower() in STANDARD_COLUMNS} @@ -368,6 +421,32 @@ def validate_variant_consistency(df: pd.DataFrame) -> None: pass +def validate_df_column_metadata_match(df: pd.DataFrame, columnMetadata: dict[str, DatasetColumnMetadata]): + """ + Checks that metadata keys match the dataframe column names and exclude standard column names. + + Parameters + ---------- + df1 : pandas.DataFrame + Dataframe parsed from an uploaded scores file + columnMetadata : dict[str, DatasetColumnMetadata] + Metadata for the scores columns + + Raises + ------ + ValidationError + If any metadata keys do not match dataframe column names + ValidationError + If any metadata keys match standard columns + + """ + for key in columnMetadata.keys(): + if key.lower() in STANDARD_COLUMNS: + raise ValidationError(f"standard column '{key}' cannot have metadata defined") + elif key not in df.columns: + raise ValidationError(f"column metadata key '{key}' does not match any dataframe column names") + + def validate_variant_columns_match(df1: pd.DataFrame, df2: pd.DataFrame): """ Checks if two dataframes have matching HGVS columns. diff --git a/src/mavedb/routers/experiments.py b/src/mavedb/routers/experiments.py index 5e49b017..a2682edd 100644 --- a/src/mavedb/routers/experiments.py +++ b/src/mavedb/routers/experiments.py @@ -5,15 +5,16 @@ import requests from fastapi import APIRouter, Depends, HTTPException from fastapi.encoders import jsonable_encoder -from sqlalchemy.orm import Session from sqlalchemy import or_ +from sqlalchemy.orm import Session from mavedb import deps from mavedb.lib.authentication import UserData, get_current_user from mavedb.lib.authorization import require_current_user, require_current_user_with_email from mavedb.lib.contributors import find_or_create_contributor from mavedb.lib.exceptions import NonexistentOrcidUserError -from mavedb.lib.experiments import search_experiments as _search_experiments, enrich_experiment_with_num_score_sets +from mavedb.lib.experiments import enrich_experiment_with_num_score_sets +from mavedb.lib.experiments import search_experiments as _search_experiments from mavedb.lib.identifiers import ( find_or_create_doi_identifier, find_or_create_publication_identifier, @@ -368,7 +369,7 @@ async def update_experiment( ] } for var, value in pairs.items(): # vars(item_update).items(): - setattr(item, var, value) if value else None + setattr(item, var, value) if value is not None else None try: item.contributors = [ diff --git a/src/mavedb/routers/score_sets.py b/src/mavedb/routers/score_sets.py index 49d0e084..2e1aa87b 100644 --- a/src/mavedb/routers/score_sets.py +++ b/src/mavedb/routers/score_sets.py @@ -1,26 +1,28 @@ +import json import logging from datetime import date -from typing import Any, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, TypedDict, Union import pandas as pd from arq import ArqRedis -from fastapi import APIRouter, Depends, File, Query, UploadFile, 
status +from fastapi import APIRouter, Depends, File, Query, Request, UploadFile, status from fastapi.encoders import jsonable_encoder -from fastapi.exceptions import HTTPException +from fastapi.exceptions import HTTPException, RequestValidationError from fastapi.responses import StreamingResponse from ga4gh.va_spec.acmg_2015 import VariantPathogenicityEvidenceLine -from ga4gh.va_spec.base.core import Statement, ExperimentalVariantFunctionalImpactStudyResult +from ga4gh.va_spec.base.core import ExperimentalVariantFunctionalImpactStudyResult, Statement +from pydantic import ValidationError from sqlalchemy import null, or_, select from sqlalchemy.exc import MultipleResultsFound, NoResultFound -from sqlalchemy.orm import contains_eager, Session +from sqlalchemy.orm import Session, contains_eager from mavedb import deps -from mavedb.lib.annotation.exceptions import MappingDataDoesntExistException from mavedb.lib.annotation.annotate import ( - variant_pathogenicity_evidence, variant_functional_impact_statement, + variant_pathogenicity_evidence, variant_study_result, ) +from mavedb.lib.annotation.exceptions import MappingDataDoesntExistException from mavedb.lib.authentication import UserData from mavedb.lib.authorization import ( RoleRequirer, @@ -48,12 +50,13 @@ fetch_score_set_search_filter_options, find_meta_analyses_for_experiment_sets, get_score_set_variants_as_csv, + refresh_variant_urns, variants_to_csv_rows, ) from mavedb.lib.score_sets import ( search_score_sets as _search_score_sets, - refresh_variant_urns, ) +from mavedb.lib.target_genes import find_or_create_target_gene_by_accession, find_or_create_target_gene_by_sequence from mavedb.lib.taxonomies import find_or_create_taxonomy from mavedb.lib.urns import ( generate_experiment_set_urn, @@ -73,8 +76,13 @@ from mavedb.models.target_gene import TargetGene from mavedb.models.target_sequence import TargetSequence from mavedb.models.variant import Variant -from mavedb.view_models import mapped_variant, score_set, clinical_control, score_range, gnomad_variant +from mavedb.view_models import clinical_control, gnomad_variant, mapped_variant, score_range, score_set +from mavedb.view_models.contributor import ContributorCreate +from mavedb.view_models.doi_identifier import DoiIdentifierCreate +from mavedb.view_models.publication_identifier import PublicationIdentifierCreate +from mavedb.view_models.score_set_dataset_columns import DatasetColumnMetadata from mavedb.view_models.search import ScoreSetsSearch, ScoreSetsSearchFilterOptionsResponse, ScoreSetsSearchResponse +from mavedb.view_models.target_gene import TargetGeneCreate logger = logging.getLogger(__name__) @@ -82,6 +90,347 @@ SCORE_SET_SEARCH_MAX_PUBLICATION_IDENTIFIERS = 40 +async def enqueue_variant_creation( + *, + item: ScoreSet, + user_data: UserData, + new_scores_df: Optional[pd.DataFrame] = None, + new_counts_df: Optional[pd.DataFrame] = None, + new_score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, + new_count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, + worker: ArqRedis, +) -> None: + assert item.dataset_columns is not None + + # create CSV from existing variants on the score set if no new dataframe provided + existing_scores_df = None + if new_scores_df is None: + score_columns = [ + "hgvs_nt", + "hgvs_splice", + "hgvs_pro", + ] + item.dataset_columns.get("score_columns", []) + existing_scores_df = pd.DataFrame( + variants_to_csv_rows(item.variants, columns=score_columns, dtype="score_data") + ).replace("NA", pd.NA) + + # 
create CSV from existing variants on the score set if no new dataframe provided + existing_counts_df = None + if new_counts_df is None and item.dataset_columns.get("count_columns"): + count_columns = [ + "hgvs_nt", + "hgvs_splice", + "hgvs_pro", + ] + item.dataset_columns["count_columns"] + existing_counts_df = pd.DataFrame( + variants_to_csv_rows(item.variants, columns=count_columns, dtype="count_data") + ).replace("NA", pd.NA) + + # Await the insertion of this job into the worker queue, not the job itself. + # Uses provided score and counts dataframes and metadata files, or falls back to existing data on the score set if not provided. + job = await worker.enqueue_job( + "create_variants_for_score_set", + correlation_id_for_context(), + item.id, + user_data.user.id, + existing_scores_df if new_scores_df is None else new_scores_df, + existing_counts_df if new_counts_df is None else new_counts_df, + item.dataset_columns.get("score_columns_metadata") + if new_score_columns_metadata is None + else new_score_columns_metadata, + item.dataset_columns.get("count_columns_metadata") + if new_count_columns_metadata is None + else new_count_columns_metadata, + ) + if job is not None: + save_to_logging_context({"worker_job_id": job.job_id}) + logger.info(msg="Enqueued variant creation job.", extra=logging_context()) + + +class ScoreSetUpdateResult(TypedDict): + item: ScoreSet + should_create_variants: bool + + +async def score_set_update( + *, + db: Session, + urn: str, + item_update: score_set.ScoreSetUpdateAllOptional, + exclude_unset: bool = False, + user_data: UserData, + existing_item: Optional[ScoreSet] = None, +) -> ScoreSetUpdateResult: + logger.info(msg="Updating score set.", extra=logging_context()) + + should_create_variants = False + item_update_dict: dict[str, Any] = item_update.model_dump(exclude_unset=exclude_unset) + + item = existing_item or db.query(ScoreSet).filter(ScoreSet.urn == urn).one_or_none() + if not item or item.id is None: + logger.info(msg="Failed to update score set; The requested score set does not exist.", extra=logging_context()) + raise HTTPException(status_code=404, detail=f"score set with URN '{urn}' not found") + + assert_permission(user_data, item, Action.UPDATE) + + for var, value in item_update_dict.items(): + if var not in [ + "contributors", + "score_ranges", + "doi_identifiers", + "experiment_urn", + "license_id", + "secondary_publication_identifiers", + "primary_publication_identifiers", + "target_genes", + "dataset_columns", + ]: + setattr(item, var, value) + + item_update_license_id = item_update_dict.get("license_id") + if item_update_license_id is not None: + save_to_logging_context({"license": item_update_license_id}) + license_ = db.query(License).filter(License.id == item_update_license_id).one_or_none() + + if not license_: + logger.info( + msg="Failed to update score set; The requested license does not exist.", extra=logging_context() + ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unknown license") + + # Allow in-active licenses to be retained on update if they already exist on the item. 
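+        # A score set that already references a retired license may keep it; only switching *to* an
+        # inactive license is rejected below.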
+ elif not license_.active and item.license.id != item_update_license_id: + logger.info( + msg="Failed to update score set license; The requested license is no longer active.", + extra=logging_context(), + ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid license") + + item.license = license_ + + if "doi_identifiers" in item_update_dict: + doi_identifiers_list = [ + DoiIdentifierCreate(**identifier) for identifier in item_update_dict.get("doi_identifiers") or [] + ] + item.doi_identifiers = [ + await find_or_create_doi_identifier(db, identifier.identifier) for identifier in doi_identifiers_list + ] + + if any(key in item_update_dict for key in ["primary_publication_identifiers", "secondary_publication_identifiers"]): + if "primary_publication_identifiers" in item_update_dict: + primary_publication_identifiers_list = [ + PublicationIdentifierCreate(**identifier) + for identifier in item_update_dict.get("primary_publication_identifiers") or [] + ] + primary_publication_identifiers = [ + await find_or_create_publication_identifier(db, identifier.identifier, identifier.db_name) + for identifier in primary_publication_identifiers_list + ] + else: + # set to existing primary publication identifiers if not provided in update + primary_publication_identifiers = [p for p in item.publication_identifiers if getattr(p, "primary", False)] + + if "secondary_publication_identifiers" in item_update_dict: + secondary_publication_identifiers_list = [ + PublicationIdentifierCreate(**identifier) + for identifier in item_update_dict.get("secondary_publication_identifiers") or [] + ] + secondary_publication_identifiers = [ + await find_or_create_publication_identifier(db, identifier.identifier, identifier.db_name) + for identifier in secondary_publication_identifiers_list + ] + else: + # set to existing secondary publication identifiers if not provided in update + secondary_publication_identifiers = [ + p for p in item.publication_identifiers if not getattr(p, "primary", False) + ] + + publication_identifiers = primary_publication_identifiers + secondary_publication_identifiers + + # create a temporary `primary` attribute on each of our publications that indicates + # to our association proxy whether it is a primary publication or not + primary_identifiers = [p.identifier for p in primary_publication_identifiers] + for publication in publication_identifiers: + setattr(publication, "primary", publication.identifier in primary_identifiers) + + item.publication_identifiers = publication_identifiers + + if "contributors" in item_update_dict: + try: + contributors = [ + ContributorCreate(**contributor) for contributor in item_update_dict.get("contributors") or [] + ] + item.contributors = [ + await find_or_create_contributor(db, contributor.orcid_id) for contributor in contributors + ] + except NonexistentOrcidUserError as e: + logger.error(msg="Could not find ORCID user with the provided user ID.", extra=logging_context()) + raise HTTPException(status_code=422, detail=str(e)) + + # Score set has not been published and attributes affecting scores may still be edited. 
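+    # Changed targets are detected below by comparing target gene ids before and after the update; the
+    # find_or_create_target_gene_by_* helpers create a new TargetGene row whenever a target or its
+    # accession/sequence record changes, which is what makes that id comparison meaningful.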
+    if item.private:
+        if "score_ranges" in item_update_dict:
+            item.score_ranges = item_update_dict.get("score_ranges", null())
+
+        if "target_genes" in item_update_dict:
+            # stash existing target gene ids to compare after update, to determine if variants need to be re-created
+            assert all(tg.id is not None for tg in item.target_genes)
+            existing_target_ids: list[int] = [tg.id for tg in item.target_genes if tg.id is not None]
+
+            targets: List[TargetGene] = []
+            accessions = False
+
+            for tg in item_update_dict.get("target_genes", []):
+                gene = TargetGeneCreate(**tg)
+                if gene.target_sequence:
+                    if accessions and len(targets) > 0:
+                        logger.info(
+                            msg="Failed to update score set; Both a sequence and accession based target were detected.",
+                            extra=logging_context(),
+                        )
+
+                        raise MixedTargetError(
+                            "MaveDB does not support score-sets with both sequence and accession based targets. Please re-submit this scoreset using only one type of target."
+                        )
+
+                    upload_taxonomy = gene.target_sequence.taxonomy
+                    save_to_logging_context({"requested_taxonomy": gene.target_sequence.taxonomy.code})
+                    taxonomy = await find_or_create_taxonomy(db, upload_taxonomy)
+
+                    if not taxonomy:
+                        logger.info(
+                            msg="Failed to update score set; The requested taxonomy does not exist.",
+                            extra=logging_context(),
+                        )
+                        raise HTTPException(
+                            status_code=status.HTTP_400_BAD_REQUEST,
+                            detail=f"Unknown taxonomy {gene.target_sequence.taxonomy.code}",
+                        )
+
+                    # If the target sequence has a label, use it. Otherwise, use the name from the target gene as the label.
+                    # View model validation rules enforce that sequences must have a label defined if there are more than one
+                    # targets defined on a score set.
+                    seq_label = gene.target_sequence.label if gene.target_sequence.label is not None else gene.name
+
+                    target_gene = find_or_create_target_gene_by_sequence(
+                        db,
+                        score_set_id=item.id,
+                        tg=jsonable_encoder(
+                            gene,
+                            by_alias=False,
+                            exclude={
+                                "external_identifiers",
+                                "target_sequence",
+                                "target_accession",
+                            },
+                        ),
+                        tg_sequence={
+                            **jsonable_encoder(gene.target_sequence, by_alias=False, exclude={"taxonomy", "label"}),
+                            "taxonomy": taxonomy,
+                            "label": seq_label,
+                        },
+                    )
+
+                elif gene.target_accession:
+                    if not accessions and len(targets) > 0:
+                        logger.info(
+                            msg="Failed to update score set; Both a sequence and accession based target were detected.",
+                            extra=logging_context(),
+                        )
+                        raise MixedTargetError(
+                            "MaveDB does not support score-sets with both sequence and accession based targets. Please re-submit this scoreset using only one type of target."
+ ) + accessions = True + + target_gene = find_or_create_target_gene_by_accession( + db, + score_set_id=item.id, + tg=jsonable_encoder( + gene, + by_alias=False, + exclude={ + "external_identifiers", + "target_sequence", + "target_accession", + }, + ), + tg_accession=jsonable_encoder(gene.target_accession, by_alias=False), + ) + else: + save_to_logging_context({"failing_target": gene}) + logger.info(msg="Failed to create score set; Could not infer target type.", extra=logging_context()) + raise ValueError("One of either `target_accession` or `target_gene` should be present") + + for external_gene_identifier_offset_create in gene.external_identifiers: + offset = external_gene_identifier_offset_create.offset + identifier_create = external_gene_identifier_offset_create.identifier + await create_external_gene_identifier_offset( + db, + target_gene, + identifier_create.db_name, + identifier_create.identifier, + offset, + ) + + targets.append(target_gene) + + item.target_genes = targets + + assert all(tg.id is not None for tg in item.target_genes) + current_target_ids: list[int] = [tg.id for tg in item.target_genes if tg.id is not None] + + if sorted(existing_target_ids) != sorted(current_target_ids): + logger.info(msg=f"Target genes have changed for score set {item.id}", extra=logging_context()) + should_create_variants = True if item.variants else False + + else: + logger.debug(msg="Skipped score range and target gene update. Score set is published.", extra=logging_context()) + + db.add(item) + db.commit() + db.refresh(item) + + save_to_logging_context({"updated_resource": item.urn}) + return {"item": item, "should_create_variants": should_create_variants} + + +class ParseScoreSetUpdate(TypedDict): + scores_df: Optional[pd.DataFrame] + counts_df: Optional[pd.DataFrame] + + +async def parse_score_set_variants_uploads( + scores_file: Optional[UploadFile] = File(None), + counts_file: Optional[UploadFile] = File(None), +) -> ParseScoreSetUpdate: + if scores_file and scores_file.file: + try: + scores_df = csv_data_to_df(scores_file.file) + # Handle non-utf8 file problem. + except UnicodeDecodeError as e: + raise HTTPException( + status_code=400, detail=f"Error decoding file: {e}. Ensure the file has correct values." + ) + else: + scores_df = None + + if counts_file and counts_file.file: + try: + counts_df = csv_data_to_df(counts_file.file) + # Handle non-utf8 file problem. + except UnicodeDecodeError as e: + raise HTTPException( + status_code=400, detail=f"Error decoding file: {e}. Ensure the file has correct values." + ) + else: + counts_df = None + + return { + "scores_df": scores_df, + "counts_df": counts_df, + } + + async def fetch_score_set_by_urn( db, urn: str, user: Optional[UserData], owner_or_contributor: Optional[UserData], only_published: bool ) -> ScoreSet: @@ -873,6 +1222,7 @@ async def create_score_set( # View model validation rules enforce that sequences must have a label defined if there are more than one # targets defined on a score set. 
seq_label = gene.target_sequence.label if gene.target_sequence.label is not None else gene.name + target_sequence = TargetSequence( **jsonable_encoder(gene.target_sequence, by_alias=False, exclude={"taxonomy", "label"}), taxonomy=taxonomy, @@ -985,8 +1335,11 @@ async def create_score_set( async def upload_score_set_variant_data( *, urn: str, + data: Request, counts_file: Optional[UploadFile] = File(None), - scores_file: UploadFile = File(...), + scores_file: Optional[UploadFile] = File(None), + # count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, + # score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, db: Session = Depends(deps.get_db), user_data: UserData = Depends(require_current_user_with_email), worker: ArqRedis = Depends(deps.get_worker), @@ -997,6 +1350,19 @@ async def upload_score_set_variant_data( """ save_to_logging_context({"requested_resource": urn, "resource_property": "variants"}) + try: + score_set_variants_data = await parse_score_set_variants_uploads(scores_file, counts_file) + + form_data = await data.form() + # Parse variants dataset column metadata JSON strings + dataset_column_metadata = { + key: json.loads(str(value)) + for key, value in form_data.items() + if key in ["count_columns_metadata", "score_columns_metadata"] + } + except Exception as e: + raise HTTPException(status_code=422, detail=str(e)) + # item = db.query(ScoreSet).filter(ScoreSet.urn == urn).filter(ScoreSet.private.is_(False)).one_or_none() item = db.query(ScoreSet).filter(ScoreSet.urn == urn).one_or_none() if not item or not item.urn: @@ -1006,37 +1372,27 @@ async def upload_score_set_variant_data( assert_permission(user_data, item, Action.UPDATE) assert_permission(user_data, item, Action.SET_SCORES) - try: - scores_df = csv_data_to_df(scores_file.file) - counts_df = None - if counts_file and counts_file.filename: - counts_df = csv_data_to_df(counts_file.file) - # Handle non-utf8 file problem. - except UnicodeDecodeError as e: - raise HTTPException(status_code=400, detail=f"Error decoding file: {e}. Ensure the file has correct values.") - - if scores_file: - # Although this is also updated within the variant creation job, update it here - # as well so that we can display the proper UI components (queue invocation delay - # races the score set GET request). - item.processing_state = ProcessingState.processing - - # await the insertion of this job into the worker queue, not the job itself. - job = await worker.enqueue_job( - "create_variants_for_score_set", - correlation_id_for_context(), - item.id, - user_data.user.id, - scores_df, - counts_df, - ) - if job is not None: - save_to_logging_context({"worker_job_id": job.job_id}) - logger.info(msg="Enqueud variant creation job.", extra=logging_context()) + # Although this is also updated within the variant creation job, update it here + # as well so that we can display the proper UI components (queue invocation delay + # races the score set GET request). 
+ item.processing_state = ProcessingState.processing + + logger.info(msg="Enqueuing variant creation job.", extra=logging_context()) + + await enqueue_variant_creation( + item=item, + user_data=user_data, + new_scores_df=score_set_variants_data["scores_df"], + new_counts_df=score_set_variants_data["counts_df"], + new_score_columns_metadata=dataset_column_metadata.get("score_columns_metadata", {}), + new_count_columns_metadata=dataset_column_metadata.get("count_columns_metadata", {}), + worker=worker, + ) db.add(item) db.commit() db.refresh(item) + enriched_experiment = enrich_experiment_with_num_score_sets(item.experiment, user_data) return score_set.ScoreSet.model_validate(item).copy(update={"experiment": enriched_experiment}) @@ -1077,259 +1433,170 @@ async def update_score_set_range_data( return score_set.ScoreSet.model_validate(item).copy(update={"experiment": enriched_experiment}) -@router.put( - "/score-sets/{urn}", response_model=score_set.ScoreSet, responses={422: {}}, response_model_exclude_none=True +@router.patch( + "/score-sets-with-variants/{urn}", + response_model=score_set.ScoreSet, + responses={422: {}}, + response_model_exclude_none=True, ) -async def update_score_set( +async def update_score_set_with_variants( *, urn: str, - item_update: score_set.ScoreSetUpdate, + request: Request, + # Variants data files + counts_file: Optional[UploadFile] = File(None), + scores_file: Optional[UploadFile] = File(None), db: Session = Depends(deps.get_db), user_data: UserData = Depends(require_current_user_with_email), worker: ArqRedis = Depends(deps.get_worker), ) -> Any: """ - Update a score set. + Update a score set and variants. """ - save_to_logging_context({"requested_resource": urn}) - logger.debug(msg="Began score set update.", extra=logging_context()) + logger.info(msg="Began score set with variants update.", extra=logging_context()) - item = db.query(ScoreSet).filter(ScoreSet.urn == urn).one_or_none() - if not item: - logger.info(msg="Failed to update score set; The requested score set does not exist.", extra=logging_context()) - raise HTTPException(status_code=404, detail=f"score set with URN '{urn}' not found") - - assert_permission(user_data, item, Action.UPDATE) - - for var, value in vars(item_update).items(): - if var not in [ - "contributors", - "score_ranges", - "doi_identifiers", - "experiment_urn", - "license_id", - "secondary_publication_identifiers", - "primary_publication_identifiers", - "target_genes", - ]: - setattr(item, var, value) if value else None + try: + # Get all form data from the request + form_data = await request.form() + + # Convert form data to dictionary, excluding file and associated column metadata fields + form_dict = { + key: value + for key, value in form_data.items() + if key not in ["counts_file", "scores_file", "count_columns_metadata", "score_columns_metadata"] + } + # Create the update object using **kwargs in as_form + item_update_partial = score_set.ScoreSetUpdateAllOptional.as_form(**form_dict) - if item_update.license_id is not None: - save_to_logging_context({"license": item_update.license_id}) - license_ = db.query(License).filter(License.id == item_update.license_id).one_or_none() + # parse uploaded CSV files + score_set_variants_data = await parse_score_set_variants_uploads( + scores_file, + counts_file, + ) - if not license_: - logger.info( - msg="Failed to update score set; The requested license does not exist.", extra=logging_context() - ) - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unknown license") + 
# Parse variants dataset column metadata JSON strings + dataset_column_metadata = { + key: json.loads(str(value)) + for key, value in form_data.items() + if key in ["count_columns_metadata", "score_columns_metadata"] + } + except Exception as e: + raise HTTPException(status_code=422, detail=str(e)) - # Allow in-active licenses to be retained on update if they already exist on the item. - elif not license_.active and item.licence_id != item_update.license_id: - logger.info( - msg="Failed to update score set license; The requested license is no longer active.", - extra=logging_context(), - ) - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid license") + # get existing item from db + existing_item = db.query(ScoreSet).filter(ScoreSet.urn == urn).one_or_none() - item.license = license_ + # merge existing item data with item_update data to validate against ScoreSetUpdate - item.doi_identifiers = [ - await find_or_create_doi_identifier(db, identifier.identifier) - for identifier in item_update.doi_identifiers or [] - ] - primary_publication_identifiers = [ - await find_or_create_publication_identifier(db, identifier.identifier, identifier.db_name) - for identifier in item_update.primary_publication_identifiers or [] - ] - publication_identifiers = [ - await find_or_create_publication_identifier(db, identifier.identifier, identifier.db_name) - for identifier in item_update.secondary_publication_identifiers or [] - ] + primary_publication_identifiers + if existing_item: + existing_item_data = score_set.ScoreSet.model_validate(existing_item).model_dump() + updated_data = {**existing_item_data, **item_update_partial.model_dump(exclude_unset=True)} + try: + score_set.ScoreSetUpdate.model_validate(updated_data) + except ValidationError as e: + # format as fastapi request validation error + raise RequestValidationError(errors=e.errors()) + else: + logger.info(msg="Failed to update score set; The requested score set does not exist.", extra=logging_context()) + raise HTTPException(status_code=404, detail=f"score set with URN '{urn}' not found") - # create a temporary `primary` attribute on each of our publications that indicates - # to our association proxy whether it is a primary publication or not - primary_identifiers = [p.identifier for p in primary_publication_identifiers] - for publication in publication_identifiers: - setattr(publication, "primary", publication.identifier in primary_identifiers) + itemUpdateResult = await score_set_update( + db=db, + urn=urn, + item_update=item_update_partial, + exclude_unset=True, + user_data=user_data, + existing_item=existing_item, + ) + updatedItem = itemUpdateResult["item"] + should_create_variants = itemUpdateResult.get("should_create_variants", False) - item.publication_identifiers = publication_identifiers + existing_score_columns_metadata = (existing_item.dataset_columns or {}).get("score_columns_metadata", {}) + existing_count_columns_metadata = (existing_item.dataset_columns or {}).get("count_columns_metadata", {}) - try: - item.contributors = [ - await find_or_create_contributor(db, contributor.orcid_id) for contributor in item_update.contributors or [] - ] - except NonexistentOrcidUserError as e: - logger.error(msg="Could not find ORCID user with the provided user ID.", extra=logging_context()) - raise HTTPException(status_code=422, detail=str(e)) + did_score_columns_metadata_change = ( + dataset_column_metadata.get("score_columns_metadata", {}) != existing_score_columns_metadata + ) + did_count_columns_metadata_change = ( 
+ dataset_column_metadata.get("count_columns_metadata", {}) != existing_count_columns_metadata + ) - # Score set has not been published and attributes affecting scores may still be edited. - if item.private: - if item_update.score_ranges: - item.score_ranges = item_update.score_ranges.model_dump() - else: - item.score_ranges = null() - - # Delete the old target gene, WT sequence, and reference map. These will be deleted when we set the score set's - # target_gene to None, because we have set cascade='all,delete-orphan' on ScoreSet.target_gene. (Since the - # relationship is defined with the target gene as owner, this is actually set up in the backref attribute of - # TargetGene.score_set.) - # - # We must flush our database queries now so that the old target gene will be deleted before inserting a new one - # with the same score_set_id. - item.target_genes = [] - db.flush() - - targets: List[TargetGene] = [] - accessions = False - for gene in item_update.target_genes: - if gene.target_sequence: - if accessions and len(targets) > 0: - logger.info( - msg="Failed to update score set; Both a sequence and accession based target were detected.", - extra=logging_context(), - ) + # run variant creation job only if targets have changed (indicated by "should_create_variants"), new score + # or count files were uploaded, or dataset column metadata has changed + if ( + should_create_variants + or did_score_columns_metadata_change + or did_count_columns_metadata_change + or any([val is not None for val in score_set_variants_data.values()]) + ): + assert_permission(user_data, updatedItem, Action.SET_SCORES) + + updatedItem.processing_state = ProcessingState.processing + logger.info(msg="Enqueuing variant creation job.", extra=logging_context()) + + await enqueue_variant_creation( + item=updatedItem, + user_data=user_data, + worker=worker, + new_scores_df=score_set_variants_data["scores_df"], + new_counts_df=score_set_variants_data["counts_df"], + new_score_columns_metadata=dataset_column_metadata.get("score_columns_metadata") + if did_score_columns_metadata_change + else existing_score_columns_metadata, + new_count_columns_metadata=dataset_column_metadata.get("count_columns_metadata") + if did_count_columns_metadata_change + else existing_count_columns_metadata, + ) - raise MixedTargetError( - "MaveDB does not support score-sets with both sequence and accession based targets. Please re-submit this scoreset using only one type of target." - ) + db.add(updatedItem) + db.commit() + db.refresh(updatedItem) - upload_taxonomy = gene.target_sequence.taxonomy - save_to_logging_context({"requested_taxonomy": gene.target_sequence.taxonomy.code}) - taxonomy = await find_or_create_taxonomy(db, upload_taxonomy) + enriched_experiment = enrich_experiment_with_num_score_sets(updatedItem.experiment, user_data) + return score_set.ScoreSet.model_validate(updatedItem).copy(update={"experiment": enriched_experiment}) - if not taxonomy: - logger.info( - msg="Failed to create score set; The requested taxonomy does not exist.", - extra=logging_context(), - ) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Unknown taxonomy {gene.target_sequence.taxonomy.code}", - ) - # If the target sequence has a label, use it. Otherwise, use the name from the target gene as the label. - # View model validation rules enforce that sequences must have a label defined if there are more than one - # targets defined on a score set. 
- seq_label = gene.target_sequence.label if gene.target_sequence.label is not None else gene.name - target_sequence = TargetSequence( - **jsonable_encoder( - gene.target_sequence, - by_alias=False, - exclude={"taxonomy", "label"}, - ), - taxonomy=taxonomy, - label=seq_label, - ) - target_gene = TargetGene( - **jsonable_encoder( - gene, - by_alias=False, - exclude={ - "external_identifiers", - "target_sequence", - "target_accession", - }, - ), - target_sequence=target_sequence, - ) +@router.put( + "/score-sets/{urn}", response_model=score_set.ScoreSet, responses={422: {}}, response_model_exclude_none=True +) +async def update_score_set( + *, + urn: str, + item_update: score_set.ScoreSetUpdate, + db: Session = Depends(deps.get_db), + user_data: UserData = Depends(require_current_user_with_email), + worker: ArqRedis = Depends(deps.get_worker), +) -> Any: + """ + Update a score set. + """ + save_to_logging_context({"requested_resource": urn}) + logger.debug(msg="Began score set update.", extra=logging_context()) - elif gene.target_accession: - if not accessions and len(targets) > 0: - logger.info( - msg="Failed to create score set; Both a sequence and accession based target were detected.", - extra=logging_context(), - ) - raise MixedTargetError( - "MaveDB does not support score-sets with both sequence and accession based targets. Please re-submit this scoreset using only one type of target." - ) - accessions = True - target_accession = TargetAccession(**jsonable_encoder(gene.target_accession, by_alias=False)) - target_gene = TargetGene( - **jsonable_encoder( - gene, - by_alias=False, - exclude={ - "external_identifiers", - "target_sequence", - "target_accession", - }, - ), - target_accession=target_accession, - ) - else: - save_to_logging_context({"failing_target": gene}) - logger.info(msg="Failed to create score set; Could not infer target type.", extra=logging_context()) - raise ValueError("One of either `target_accession` or `target_gene` should be present") - - for external_gene_identifier_offset_create in gene.external_identifiers: - offset = external_gene_identifier_offset_create.offset - identifier_create = external_gene_identifier_offset_create.identifier - await create_external_gene_identifier_offset( - db, - target_gene, - identifier_create.db_name, - identifier_create.identifier, - offset, - ) + # this object will contain all required fields because item_update type is ScoreSetUpdate, but + # is converted to instance of ScoreSetUpdateAllOptional to match expected input of score_set_update function + score_set_update_item = score_set.ScoreSetUpdateAllOptional.model_validate(item_update.model_dump()) + itemUpdateResult = await score_set_update( + db=db, urn=urn, item_update=score_set_update_item, exclude_unset=False, user_data=user_data + ) + updatedItem = itemUpdateResult["item"] + should_create_variants = itemUpdateResult["should_create_variants"] - targets.append(target_gene) - - item.target_genes = targets - - # re-validate existing variants and clear them if they do not pass validation - if item.variants: - assert item.dataset_columns is not None - score_columns = [ - "hgvs_nt", - "hgvs_splice", - "hgvs_pro", - ] + item.dataset_columns["score_columns"] - count_columns = [ - "hgvs_nt", - "hgvs_splice", - "hgvs_pro", - ] + item.dataset_columns["count_columns"] - - scores_data = pd.DataFrame( - variants_to_csv_rows(item.variants, columns=score_columns, dtype="score_data") - ).replace("NA", pd.NA) - - if item.dataset_columns["count_columns"]: - count_data = pd.DataFrame( - 
variants_to_csv_rows(item.variants, columns=count_columns, dtype="count_data") - ).replace("NA", pd.NA) - else: - count_data = None - - # Although this is also updated within the variant creation job, update it here - # as well so that we can display the proper UI components (queue invocation delay - # races the score set GET request). - item.processing_state = ProcessingState.processing - - # await the insertion of this job into the worker queue, not the job itself. - job = await worker.enqueue_job( - "create_variants_for_score_set", - correlation_id_for_context(), - item.id, - user_data.user.id, - scores_data, - count_data, - ) - if job is not None: - save_to_logging_context({"worker_job_id": job.job_id}) - logger.info(msg="Enqueud variant creation job.", extra=logging_context()) - else: - logger.debug(msg="Skipped score range and target gene update. Score set is published.", extra=logging_context()) + if should_create_variants: + # Although this is also updated within the variant creation job, update it here + # as well so that we can display the proper UI components (queue invocation delay + # races the score set GET request). + updatedItem.processing_state = ProcessingState.processing - db.add(item) - db.commit() - db.refresh(item) + logger.info(msg="Enqueuing variant creation job.", extra=logging_context()) + await enqueue_variant_creation(item=updatedItem, user_data=user_data, worker=worker) - save_to_logging_context({"updated_resource": item.urn}) + db.add(updatedItem) + db.commit() + db.refresh(updatedItem) - enriched_experiment = enrich_experiment_with_num_score_sets(item.experiment, user_data) - return score_set.ScoreSet.model_validate(item).copy(update={"experiment": enriched_experiment}) + enriched_experiment = enrich_experiment_with_num_score_sets(updatedItem.experiment, user_data) + return score_set.ScoreSet.model_validate(updatedItem).copy(update={"experiment": enriched_experiment}) @router.delete("/score-sets/{urn}", responses={422: {}}) diff --git a/src/mavedb/view_models/score_set.py b/src/mavedb/view_models/score_set.py index 1dcb74d5..1b3328d5 100644 --- a/src/mavedb/view_models/score_set.py +++ b/src/mavedb/view_models/score_set.py @@ -1,25 +1,25 @@ # See https://pydantic-docs.helpmanual.io/usage/postponed_annotations/#self-referencing-models from __future__ import annotations +import json from datetime import date from typing import Any, Collection, Optional, Sequence, Union -from typing_extensions import Self -from humps import camelize from pydantic import field_validator, model_validator +from typing_extensions import Self from mavedb.lib.validation import urn_re from mavedb.lib.validation.exceptions import ValidationError +from mavedb.lib.validation.transform import ( + transform_publication_identifiers_to_primary_and_secondary, + transform_score_set_list_to_urn_list, +) from mavedb.lib.validation.utilities import is_null from mavedb.models.enums.mapping_state import MappingState from mavedb.models.enums.processing_state import ProcessingState from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.base.base import BaseModel from mavedb.view_models.contributor import Contributor, ContributorCreate -from mavedb.lib.validation.transform import ( - transform_score_set_list_to_urn_list, - transform_publication_identifiers_to_primary_and_secondary, -) from mavedb.view_models.doi_identifier import ( DoiIdentifier, DoiIdentifierCreate, @@ -31,7 +31,8 @@ PublicationIdentifierCreate, SavedPublicationIdentifier, ) -from 
mavedb.view_models.score_range import SavedScoreSetRanges, ScoreSetRangesCreate, ScoreSetRanges +from mavedb.view_models.score_range import SavedScoreSetRanges, ScoreSetRanges, ScoreSetRangesCreate +from mavedb.view_models.score_set_dataset_columns import DatasetColumns, SavedDatasetColumns from mavedb.view_models.target_gene import ( SavedTargetGene, ShortTargetGene, @@ -39,7 +40,7 @@ TargetGeneCreate, ) from mavedb.view_models.user import SavedUser, User - +from mavedb.view_models.utils import all_fields_optional_model UnboundedRange = tuple[Union[float, None], Union[float, None]] @@ -69,7 +70,7 @@ class ScoreSetBase(BaseModel): data_usage_policy: Optional[str] = None -class ScoreSetModify(ScoreSetBase): +class ScoreSetModifyBase(ScoreSetBase): contributors: Optional[list[ContributorCreate]] = None primary_publication_identifiers: Optional[list[PublicationIdentifierCreate]] = None secondary_publication_identifiers: Optional[list[PublicationIdentifierCreate]] = None @@ -77,6 +78,10 @@ class ScoreSetModify(ScoreSetBase): target_genes: list[TargetGeneCreate] score_ranges: Optional[ScoreSetRangesCreate] = None + +class ScoreSetModify(ScoreSetModifyBase): + """View model that adds custom validators to ScoreSetModifyBase.""" + @field_validator("title", "short_description", "abstract_text", "method_text") def validate_field_is_non_empty(cls, v: str) -> str: if is_null(v): @@ -87,7 +92,7 @@ def validate_field_is_non_empty(cls, v: str) -> str: def max_one_primary_publication_identifier( cls, v: list[PublicationIdentifierCreate] ) -> list[PublicationIdentifierCreate]: - if len(v) > 1: + if v is not None and len(v) > 1: raise ValidationError("Multiple primary publication identifiers are not allowed.") return v @@ -269,12 +274,57 @@ def validate_experiment_urn_required_except_for_meta_analyses(self) -> Self: return self +class ScoreSetUpdateBase(ScoreSetModifyBase): + """View model for updating a score set with no custom validators.""" + + license_id: Optional[int] = None + + class ScoreSetUpdate(ScoreSetModify): - """View model for updating a score set.""" + """View model for updating a score set that includes custom validators.""" license_id: Optional[int] = None +@all_fields_optional_model() +class ScoreSetUpdateAllOptional(ScoreSetUpdateBase): + @classmethod + def as_form(cls, **kwargs: Any) -> "ScoreSetUpdateAllOptional": + """Create ScoreSetUpdateAllOptional from form data.""" + + # Define which fields need special JSON parsing + json_fields = { + "contributors": lambda data: [ContributorCreate.model_validate(c) for c in data] if data else None, + "primary_publication_identifiers": lambda data: [ + PublicationIdentifierCreate.model_validate(p) for p in data + ] + if data + else None, + "secondary_publication_identifiers": lambda data: [ + PublicationIdentifierCreate.model_validate(s) for s in data + ] + if data + else None, + "doi_identifiers": lambda data: [DoiIdentifierCreate.model_validate(d) for d in data] if data else None, + "target_genes": lambda data: [TargetGeneCreate.model_validate(t) for t in data] if data else None, + "score_ranges": lambda data: ScoreSetRangesCreate.model_validate(data) if data else None, + "extra_metadata": lambda data: data, + } + + # Process all fields dynamically + processed_kwargs = {} + + for field_name, value in kwargs.items(): + if field_name in json_fields and value is not None and isinstance(value, str): + parsed_value = json.loads(value) + processed_kwargs[field_name] = json_fields[field_name](parsed_value) + else: + # All other fields pass through 
as-is + processed_kwargs[field_name] = value + + return cls(**processed_kwargs) + + class ShortScoreSet(BaseModel): """ Score set view model containing a smaller set of properties to return in list contexts. @@ -357,7 +407,7 @@ class SavedScoreSet(ScoreSetBase): created_by: Optional[SavedUser] = None modified_by: Optional[SavedUser] = None target_genes: Sequence[SavedTargetGene] - dataset_columns: dict + dataset_columns: Optional[SavedDatasetColumns] = None external_links: dict[str, ExternalLink] contributors: Sequence[Contributor] score_ranges: Optional[SavedScoreSetRanges] = None @@ -375,10 +425,6 @@ def publication_identifiers_validator(cls, value: Any) -> list[PublicationIdenti assert isinstance(value, Collection), "Publication identifier lists must be a collection" return list(value) # Re-cast into proper list-like type - @field_validator("dataset_columns") - def camelize_dataset_columns_keys(cls, value: dict) -> dict: - return camelize(value) - # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. @model_validator(mode="before") @@ -441,6 +487,7 @@ class ScoreSet(SavedScoreSet): mapping_state: Optional[MappingState] = None mapping_errors: Optional[dict] = None score_ranges: Optional[ScoreSetRanges] = None # type: ignore[assignment] + dataset_columns: Optional[DatasetColumns] = None # type: ignore[assignment] class ScoreSetWithVariants(ScoreSet): diff --git a/src/mavedb/view_models/score_set_dataset_columns.py b/src/mavedb/view_models/score_set_dataset_columns.py new file mode 100644 index 00000000..9435f581 --- /dev/null +++ b/src/mavedb/view_models/score_set_dataset_columns.py @@ -0,0 +1,69 @@ +from typing import Optional + +from pydantic import field_validator, model_validator +from typing_extensions import Self + +from mavedb.lib.validation.exceptions import ValidationError +from mavedb.view_models import record_type_validator, set_record_type +from mavedb.view_models.base.base import BaseModel + + +class DatasetColumnMetadata(BaseModel): + """Metadata for individual dataset columns.""" + + description: str + details: Optional[str] = None + + +class DatasetColumnsBase(BaseModel): + """Dataset columns view model representing the dataset columns property of a score set.""" + + score_columns: Optional[list[str]] = None + count_columns: Optional[list[str]] = None + score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None + count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None + + @field_validator("score_columns_metadata", "count_columns_metadata") + def validate_dataset_columns_metadata( + cls, v: Optional[dict[str, DatasetColumnMetadata]] + ) -> Optional[dict[str, DatasetColumnMetadata]]: + if not v: + return None + for val in v.values(): + DatasetColumnMetadata.model_validate(val) + return v + + @model_validator(mode="after") + def validate_dataset_columns_metadata_keys(self) -> Self: + if self.score_columns_metadata is not None and self.score_columns is None: + raise ValidationError("Score columns metadata cannot be provided without score columns.") + elif self.score_columns_metadata is not None and self.score_columns is not None: + for key in self.score_columns_metadata.keys(): + if key not in self.score_columns: + raise ValidationError(f"Score column metadata key '{key}' does not exist in score_columns list.") + + if self.count_columns_metadata is not None and 
self.count_columns is None: + raise ValidationError("Count columns metadata cannot be provided without count columns.") + elif self.count_columns_metadata is not None and self.count_columns is not None: + for key in self.count_columns_metadata.keys(): + if key not in self.count_columns: + raise ValidationError(f"Count column metadata key '{key}' does not exist in count_columns list.") + return self + + +class SavedDatasetColumns(DatasetColumnsBase): + record_type: str = None # type: ignore + + _record_type_factory = record_type_validator()(set_record_type) + + +class DatasetColumns(SavedDatasetColumns): + pass + + +class DatasetColumnsCreate(DatasetColumnsBase): + pass + + +class DatasetColumnsModify(DatasetColumnsBase): + pass diff --git a/src/mavedb/view_models/target_gene.py b/src/mavedb/view_models/target_gene.py index 10d6ed89..48396a98 100644 --- a/src/mavedb/view_models/target_gene.py +++ b/src/mavedb/view_models/target_gene.py @@ -1,8 +1,8 @@ from datetime import date from typing import Any, Optional, Sequence -from typing_extensions import Self from pydantic import Field, model_validator +from typing_extensions import Self from mavedb.lib.validation.exceptions import ValidationError from mavedb.lib.validation.transform import transform_external_identifier_offsets_to_list, transform_score_set_to_urn diff --git a/src/mavedb/view_models/utils.py b/src/mavedb/view_models/utils.py new file mode 100644 index 00000000..5a7f43da --- /dev/null +++ b/src/mavedb/view_models/utils.py @@ -0,0 +1,36 @@ +from copy import deepcopy +from typing import Any, Callable, Optional, Type, TypeVar + +from pydantic import create_model +from pydantic.fields import FieldInfo + +from mavedb.view_models.base.base import BaseModel + +Model = TypeVar("Model", bound=BaseModel) + + +def all_fields_optional_model() -> Callable[[Type[Model]], Type[Model]]: + """A decorator that creates a partial model. + + Every field of the decorated model is made optional and defaults to None, so the + resulting model can be used for PATCH-style partial updates. + + Returns: + Type[BaseModel]: A copy of the decorated model with all fields optional. + """ + + def wrapper(model: Type[Model]) -> Type[Model]: + def make_field_optional(field: FieldInfo, default: Any = None) -> tuple[Any, FieldInfo]: + new = deepcopy(field) + new.default = default + new.annotation = Optional[field.annotation] # type: ignore[assignment] + return new.annotation, new + + return create_model( + model.__name__, + __base__=model, + __module__=model.__module__, + **{field_name: make_field_optional(field_info) for field_name, field_info in model.model_fields.items()}, + ) # type: ignore[call-overload] + + return wrapper diff --git a/src/mavedb/worker/jobs.py b/src/mavedb/worker/jobs.py index c119a360..437caa64 100644 --- a/src/mavedb/worker/jobs.py +++ b/src/mavedb/worker/jobs.py @@ -63,6 +63,7 @@ from mavedb.models.score_set import ScoreSet from mavedb.models.user import User from mavedb.models.variant import Variant +from mavedb.view_models.score_set_dataset_columns import DatasetColumnMetadata logger = logging.getLogger(__name__) @@ -120,7 +121,14 @@ async def enqueue_job_with_backoff( async def create_variants_for_score_set( - ctx, correlation_id: str, score_set_id: int, updater_id: int, scores: pd.DataFrame, counts: pd.DataFrame + ctx, + correlation_id: str, + score_set_id: int, + updater_id: int, + scores: pd.DataFrame, + counts: pd.DataFrame, + score_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, + count_columns_metadata: Optional[dict[str, DatasetColumnMetadata]] = None, ): """ Create variants for a score set. Intended to be run within a worker. 
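
Aside: the all_fields_optional_model decorator added in view_models/utils.py above is what backs ScoreSetUpdateAllOptional. A minimal usage sketch, mirroring the pattern exercised by the new test module later in this diff; ExampleCreate and ExamplePatch are illustrative names and are not part of this changeset:

from typing import Optional

from mavedb.view_models.base.base import BaseModel
from mavedb.view_models.utils import all_fields_optional_model


class ExampleCreate(BaseModel):
    # Hypothetical source model; the real base in this changeset is ScoreSetUpdateBase.
    title: str
    notes: Optional[str] = None


@all_fields_optional_model()
class ExamplePatch(ExampleCreate):
    pass


# Every field is now optional and defaults to None, so a PATCH-style payload can
# carry only the fields the caller actually wants to change.
patch = ExamplePatch(title="Updated title")
assert patch.model_dump(exclude_unset=True) == {"title": "Updated title"}
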
@@ -156,13 +164,26 @@ async def create_variants_for_score_set( ) raise ValueError("Can't create variants when score set has no targets.") - validated_scores, validated_counts = validate_and_standardize_dataframe_pair( - scores, counts, score_set.target_genes, hdp + validated_scores, validated_counts, validated_score_columns_metadata, validated_count_columns_metadata = ( + validate_and_standardize_dataframe_pair( + scores_df=scores, + counts_df=counts, + score_columns_metadata=score_columns_metadata, + count_columns_metadata=count_columns_metadata, + targets=score_set.target_genes, + hdp=hdp, + ) ) score_set.dataset_columns = { "score_columns": columns_for_dataset(validated_scores), "count_columns": columns_for_dataset(validated_counts), + "score_columns_metadata": validated_score_columns_metadata + if validated_score_columns_metadata is not None + else {}, + "count_columns_metadata": validated_count_columns_metadata + if validated_count_columns_metadata is not None + else {}, } # Delete variants after validation occurs so we don't overwrite them in the case of a bad update. @@ -1275,9 +1296,9 @@ async def link_clingen_variants(ctx: dict, correlation_id: str, score_set_id: in logging_context["linkage_failures"] = num_linkage_failures logging_context["linkage_successes"] = num_variant_urns - num_linkage_failures - assert ( - len(linked_allele_ids) == num_variant_urns - ), f"{num_variant_urns - len(linked_allele_ids)} appear to not have been attempted to be linked." + assert len(linked_allele_ids) == num_variant_urns, ( + f"{num_variant_urns - len(linked_allele_ids)} appear to not have been attempted to be linked." + ) job_succeeded = False if not linkage_failures: @@ -1369,7 +1390,7 @@ async def link_clingen_variants(ctx: dict, correlation_id: str, score_set_id: in extra=logging_context, ) send_slack_message( - text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking*100}% of total mapped variants for {score_set.urn})." + text=f"Failed to link {len(linkage_failures)} ({ratio_failed_linking * 100}% of total mapped variants for {score_set.urn})." f"This job was successfully retried. This was attempt {attempt}. Retry will occur in {backoff_time} seconds. URNs failed to link: {', '.join(linkage_failures)}." 
) elif new_job_id is None and not max_retries_exceeded: diff --git a/tests/helpers/constants.py b/tests/helpers/constants.py index 26edfec4..78fc8d01 100644 --- a/tests/helpers/constants.py +++ b/tests/helpers/constants.py @@ -4,7 +4,6 @@ from mavedb.models.enums.processing_state import ProcessingState - VALID_EXPERIMENT_SET_URN = "urn:mavedb:01234567" VALID_EXPERIMENT_URN = f"{VALID_EXPERIMENT_SET_URN}-a" VALID_SCORE_SET_URN = f"{VALID_EXPERIMENT_URN}-1" @@ -693,6 +692,12 @@ "active": TEST_INACTIVE_LICENSE["active"], } +SAVED_MINIMAL_DATASET_COLUMNS = { + "recordType": "DatasetColumns", + "countColumns": [], + "scoreColumns": ["score", "s_0", "s_1"], +} + TEST_SEQ_SCORESET = { "title": "Test Score Set Title", "short_description": "Test score set", @@ -805,7 +810,9 @@ "doiIdentifiers": [], "primaryPublicationIdentifiers": [], "secondaryPublicationIdentifiers": [], - "datasetColumns": {}, + "datasetColumns": { + "recordType": "DatasetColumns", + }, "externalLinks": {}, "private": True, "experiment": TEST_MINIMAL_EXPERIMENT_RESPONSE, @@ -920,7 +927,9 @@ "doiIdentifiers": [], "primaryPublicationIdentifiers": [], "secondaryPublicationIdentifiers": [], - "datasetColumns": {}, + "datasetColumns": { + "recordType": "DatasetColumns", + }, "private": True, "experiment": TEST_MINIMAL_EXPERIMENT_RESPONSE, # keys to be set after receiving response @@ -1060,7 +1069,9 @@ "doiIdentifiers": [], "primaryPublicationIdentifiers": [], "secondaryPublicationIdentifiers": [], - "datasetColumns": {}, + "datasetColumns": { + "recordType": "DatasetColumns", + }, "externalLinks": {}, "private": True, "experiment": TEST_MINIMAL_EXPERIMENT_RESPONSE, @@ -1070,6 +1081,19 @@ "officialCollections": [], } +TEST_SCORE_SET_DATASET_COLUMNS = { + "score_columns": ["score", "s_0", "s_1"], + "count_columns": ["c_0", "c_1"], + "score_columns_metadata": { + "s_0": {"description": "s_0 description", "details": "s_0 details"}, + "s_1": {"description": "s_1 description", "details": "s_1 details"}, + }, + "count_columns_metadata": { + "c_0": {"description": "c_0 description", "details": "c_0 details"}, + "c_1": {"description": "c_1 description", "details": "c_1 details"}, + }, +} + TEST_NT_CDOT_TRANSCRIPT = { "start_codon": 0, "stop_codon": 18, diff --git a/tests/helpers/util/score_set.py b/tests/helpers/util/score_set.py index d60101ff..b2a8b2c6 100644 --- a/tests/helpers/util/score_set.py +++ b/tests/helpers/util/score_set.py @@ -1,10 +1,11 @@ -from datetime import date from copy import deepcopy -from unittest.mock import patch +from datetime import date from typing import Any, Dict, Optional +from unittest.mock import patch import cdot.hgvs.dataproviders import jsonschema +from fastapi.testclient import TestClient from sqlalchemy import select from mavedb.models.clinical_control import ClinicalControl as ClinicalControlDbModel @@ -13,11 +14,10 @@ from mavedb.models.score_set import ScoreSet as ScoreSetDbModel from mavedb.models.variant import Variant as VariantDbModel from mavedb.view_models.score_set import ScoreSet, ScoreSetCreate - from tests.helpers.constants import ( TEST_MINIMAL_ACC_SCORESET, - TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_MULTI_TARGET_SCORESET, + TEST_MINIMAL_SEQ_SCORESET, TEST_NT_CDOT_TRANSCRIPT, TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, TEST_VALID_POST_MAPPED_VRS_CIS_PHASED_BLOCK, @@ -25,7 +25,6 @@ TEST_VALID_PRE_MAPPED_VRS_CIS_PHASED_BLOCK, ) from tests.helpers.util.variant import mock_worker_variant_insertion -from fastapi.testclient import TestClient def create_seq_score_set( @@ -88,10 +87,26 @@ def 
create_multi_target_score_set( def create_seq_score_set_with_mapped_variants( - client, db, data_provider, experiment_urn, scores_csv_path, update=None, counts_csv_path=None + client, + db, + data_provider, + experiment_urn, + scores_csv_path, + update=None, + counts_csv_path=None, + score_columns_metadata_json_path=None, + count_columns_metadata_json_path=None, ): score_set = create_seq_score_set_with_variants( - client, db, data_provider, experiment_urn, scores_csv_path, update, counts_csv_path + client, + db, + data_provider, + experiment_urn, + scores_csv_path, + update, + counts_csv_path, + score_columns_metadata_json_path, + count_columns_metadata_json_path, ) score_set = mock_worker_vrs_mapping(client, db, score_set) @@ -100,10 +115,26 @@ def create_seq_score_set_with_mapped_variants( def create_acc_score_set_with_mapped_variants( - client, db, data_provider, experiment_urn, scores_csv_path, update=None, counts_csv_path=None + client, + db, + data_provider, + experiment_urn, + scores_csv_path, + update=None, + counts_csv_path=None, + score_columns_metadata_json_path=None, + count_columns_metadata_json_path=None, ): score_set = create_acc_score_set_with_variants( - client, db, data_provider, experiment_urn, scores_csv_path, update, counts_csv_path + client, + db, + data_provider, + experiment_urn, + scores_csv_path, + update, + counts_csv_path, + score_columns_metadata_json_path, + count_columns_metadata_json_path, ) score_set = mock_worker_vrs_mapping(client, db, score_set) @@ -112,28 +143,62 @@ def create_acc_score_set_with_mapped_variants( def create_seq_score_set_with_variants( - client, db, data_provider, experiment_urn, scores_csv_path, update=None, counts_csv_path=None + client, + db, + data_provider, + experiment_urn, + scores_csv_path, + update=None, + counts_csv_path=None, + score_columns_metadata_json_path=None, + count_columns_metadata_json_path=None, ): score_set = create_seq_score_set(client, experiment_urn, update) - score_set = mock_worker_variant_insertion(client, db, data_provider, score_set, scores_csv_path, counts_csv_path) + score_set = mock_worker_variant_insertion( + client, + db, + data_provider, + score_set, + scores_csv_path, + counts_csv_path, + score_columns_metadata_json_path, + count_columns_metadata_json_path, + ) - assert ( - score_set["numVariants"] == 3 - ), f"Could not create sequence based score set with variants within experiment {experiment_urn}" + assert score_set["numVariants"] == 3, ( + f"Could not create sequence based score set with variants within experiment {experiment_urn}" + ) jsonschema.validate(instance=score_set, schema=ScoreSet.model_json_schema()) return score_set def create_acc_score_set_with_variants( - client, db, data_provider, experiment_urn, scores_csv_path, update=None, counts_csv_path=None + client, + db, + data_provider, + experiment_urn, + scores_csv_path, + update=None, + counts_csv_path=None, + score_columns_metadata_json_path=None, + count_columns_metadata_json_path=None, ): score_set = create_acc_score_set(client, experiment_urn, update) - score_set = mock_worker_variant_insertion(client, db, data_provider, score_set, scores_csv_path, counts_csv_path) + score_set = mock_worker_variant_insertion( + client, + db, + data_provider, + score_set, + scores_csv_path, + counts_csv_path, + score_columns_metadata_json_path, + count_columns_metadata_json_path, + ) - assert ( - score_set["numVariants"] == 3 - ), f"Could not create sequence based score set with variants within experiment {experiment_urn}" + assert 
score_set["numVariants"] == 3, ( + f"Could not create sequence based score set with variants within experiment {experiment_urn}" + ) jsonschema.validate(instance=score_set, schema=ScoreSet.model_json_schema()) return score_set diff --git a/tests/helpers/util/variant.py b/tests/helpers/util/variant.py index 3772d2d2..5fcc05db 100644 --- a/tests/helpers/util/variant.py +++ b/tests/helpers/util/variant.py @@ -1,24 +1,25 @@ +import json from typing import Optional +from unittest.mock import patch from arq import ArqRedis from cdot.hgvs.dataproviders import RESTDataProvider from fastapi.testclient import TestClient -from sqlalchemy.orm import Session from sqlalchemy import select -from unittest.mock import patch +from sqlalchemy.orm import Session -from mavedb.lib.score_sets import create_variants, columns_for_dataset, create_variants_data, csv_data_to_df +from mavedb.lib.score_sets import columns_for_dataset, create_variants, create_variants_data, csv_data_to_df from mavedb.lib.validation.dataframe.dataframe import validate_and_standardize_dataframe_pair -from mavedb.models.enums.processing_state import ProcessingState from mavedb.models.enums.mapping_state import MappingState +from mavedb.models.enums.processing_state import ProcessingState from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet from mavedb.models.target_gene import TargetGene from mavedb.models.variant import Variant - +from mavedb.view_models.score_set_dataset_columns import DatasetColumnsCreate from tests.helpers.constants import ( - TEST_MINIMAL_PRE_MAPPED_METADATA, TEST_MINIMAL_POST_MAPPED_METADATA, + TEST_MINIMAL_PRE_MAPPED_METADATA, ) @@ -29,6 +30,8 @@ def mock_worker_variant_insertion( score_set: dict, scores_csv_path: str, counts_csv_path: Optional[str] = None, + score_columns_metadata_json_path: Optional[str] = None, + count_columns_metadata_json_path: Optional[str] = None, ) -> None: with ( open(scores_csv_path, "rb") as score_file, @@ -42,6 +45,26 @@ def mock_worker_variant_insertion( else: counts_file = None + if score_columns_metadata_json_path is not None: + score_columns_metadata_file = open(score_columns_metadata_json_path, "rb") + files["score_columns_metadata_file"] = ( + score_columns_metadata_json_path.name, + score_columns_metadata_file, + "rb", + ) + else: + score_columns_metadata_file = None + + if count_columns_metadata_json_path is not None: + count_columns_metadata_file = open(count_columns_metadata_json_path, "rb") + files["count_columns_metadata_file"] = ( + count_columns_metadata_json_path.name, + count_columns_metadata_file, + "rb", + ) + else: + count_columns_metadata_file = None + response = client.post(f"/api/v1/score-sets/{score_set['urn']}/variants/data", files=files) # Assert we have mocked a job being added to the queue, and that the request succeeded. The @@ -49,8 +72,9 @@ def mock_worker_variant_insertion( worker_queue.assert_called_once() assert response.status_code == 200 - if counts_file is not None: - counts_file.close() + for file in (counts_file, score_columns_metadata_file, count_columns_metadata_file): + if file is not None: + file.close() # Reopen files since their buffers are consumed while mocking the variant data post request. 
with open(scores_csv_path, "rb") as score_file: @@ -62,20 +86,36 @@ def mock_worker_variant_insertion( else: counts_df = None + if score_columns_metadata_json_path is not None: + with open(score_columns_metadata_json_path, "rb") as score_columns_metadata_file: + score_columns_metadata = json.load(score_columns_metadata_file) + else: + score_columns_metadata = None + + if count_columns_metadata_json_path is not None: + with open(count_columns_metadata_json_path, "rb") as count_columns_metadata_file: + count_columns_metadata = json.load(count_columns_metadata_file) + else: + count_columns_metadata = None + # Insert variant manually, worker jobs are tested elsewhere separately. item = db.scalars(select(ScoreSet).where(ScoreSet.urn == score_set["urn"])).one_or_none() assert item is not None - scores, counts = validate_and_standardize_dataframe_pair(score_df, counts_df, item.target_genes, data_provider) + scores, counts, score_columns_metadata, count_columns_metadata = validate_and_standardize_dataframe_pair( + score_df, counts_df, score_columns_metadata, count_columns_metadata, item.target_genes, data_provider + ) variants = create_variants_data(scores, counts, None) num_variants = create_variants(db, item, variants) assert num_variants == 3 item.processing_state = ProcessingState.success - item.dataset_columns = { - "score_columns": columns_for_dataset(scores), - "count_columns": columns_for_dataset(counts), - } + item.dataset_columns = DatasetColumnsCreate( + score_columns=columns_for_dataset(scores), + count_columns=columns_for_dataset(counts), + score_columns_metadata=score_columns_metadata if score_columns_metadata is not None else {}, + count_columns_metadata=count_columns_metadata if count_columns_metadata is not None else {}, + ).model_dump() db.add(item) db.commit() diff --git a/tests/routers/data/count_columns_metadata.json b/tests/routers/data/count_columns_metadata.json new file mode 100644 index 00000000..9aaaa355 --- /dev/null +++ b/tests/routers/data/count_columns_metadata.json @@ -0,0 +1,10 @@ +{ + "c_0": { + "description": "c_0 description", + "details": "c_0 details" + }, + "c_1": { + "description": "c_1 description", + "details": "c_1 details" + } +} diff --git a/tests/routers/data/score_columns_metadata.json b/tests/routers/data/score_columns_metadata.json new file mode 100644 index 00000000..a21bc31e --- /dev/null +++ b/tests/routers/data/score_columns_metadata.json @@ -0,0 +1,10 @@ +{ + "s_0": { + "description": "s_0 description", + "details": "s_0 details" + }, + "s_1": { + "description": "s_0 description", + "details": "s_0 details" + } +} diff --git a/tests/routers/data/scores.csv b/tests/routers/data/scores.csv index a2eb1377..a1f08563 100644 --- a/tests/routers/data/scores.csv +++ b/tests/routers/data/scores.csv @@ -1,4 +1,4 @@ -hgvs_nt,hgvs_pro,score -c.1A>T,p.Thr1Ser,0.3 -c.2C>T,p.Thr1Met,1.0 -c.6T>A,p.Phe2Leu,-1.65 +hgvs_nt,hgvs_pro,score,s_0,s_1 +c.1A>T,p.Thr1Ser,0.3,val1,val1 +c.2C>T,p.Thr1Met,1.0,val2,val2 +c.6T>A,p.Phe2Leu,-1.65,val3,val3 diff --git a/tests/routers/test_score_set.py b/tests/routers/test_score_set.py index 19ef8518..47e41f80 100644 --- a/tests/routers/test_score_set.py +++ b/tests/routers/test_score_set.py @@ -1,8 +1,9 @@ # ruff: noqa: E402 +import csv +import json import re from copy import deepcopy -import csv from datetime import date from io import StringIO from unittest.mock import patch @@ -16,8 +17,8 @@ cdot = pytest.importorskip("cdot") fastapi = pytest.importorskip("fastapi") -from mavedb.lib.validation.urn_re import MAVEDB_TMP_URN_RE, 
MAVEDB_SCORE_SET_URN_RE, MAVEDB_EXPERIMENT_URN_RE from mavedb.lib.exceptions import NonexistentOrcidUserError +from mavedb.lib.validation.urn_re import MAVEDB_EXPERIMENT_URN_RE, MAVEDB_SCORE_SET_URN_RE, MAVEDB_TMP_URN_RE from mavedb.models.enums.processing_state import ProcessingState from mavedb.models.enums.target_category import TargetCategory from mavedb.models.experiment import Experiment as ExperimentDbModel @@ -25,37 +26,37 @@ from mavedb.models.variant import Variant as VariantDbModel from mavedb.view_models.orcid import OrcidUser from mavedb.view_models.score_set import ScoreSet, ScoreSetCreate - from tests.helpers.constants import ( - EXTRA_USER, EXTRA_LICENSE, + EXTRA_USER, + SAVED_DOI_IDENTIFIER, + SAVED_EXTRA_CONTRIBUTOR, + SAVED_MINIMAL_DATASET_COLUMNS, + SAVED_PUBMED_PUBLICATION, + SAVED_SHORT_EXTRA_LICENSE, TEST_CROSSREF_IDENTIFIER, + TEST_GNOMAD_DATA_VERSION, + TEST_INACTIVE_LICENSE, TEST_MAPPED_VARIANT_WITH_HGVS_G_EXPRESSION, TEST_MAPPED_VARIANT_WITH_HGVS_P_EXPRESSION, TEST_MINIMAL_ACC_SCORESET, + TEST_MINIMAL_ACC_SCORESET_RESPONSE, TEST_MINIMAL_SEQ_SCORESET, TEST_MINIMAL_SEQ_SCORESET_RESPONSE, - TEST_PUBMED_IDENTIFIER, TEST_ORCID_ID, - TEST_MINIMAL_ACC_SCORESET_RESPONSE, - TEST_USER, - TEST_INACTIVE_LICENSE, - SAVED_DOI_IDENTIFIER, - SAVED_EXTRA_CONTRIBUTOR, - SAVED_PUBMED_PUBLICATION, - SAVED_SHORT_EXTRA_LICENSE, + TEST_PUBMED_IDENTIFIER, TEST_SAVED_CLINVAR_CONTROL, TEST_SAVED_GENERIC_CLINICAL_CONTROL, - TEST_SCORE_SET_RANGES_ONLY_INVESTIGATOR_PROVIDED, + TEST_SAVED_GNOMAD_VARIANT, + TEST_SAVED_SCORE_SET_RANGES_ALL_SCHEMAS_PRESENT, TEST_SAVED_SCORE_SET_RANGES_ONLY_INVESTIGATOR_PROVIDED, - TEST_SCORE_SET_RANGES_ONLY_ZEIBERG_CALIBRATION, - TEST_SAVED_SCORE_SET_RANGES_ONLY_ZEIBERG_CALIBRATION, - TEST_SCORE_SET_RANGES_ONLY_SCOTT, TEST_SAVED_SCORE_SET_RANGES_ONLY_SCOTT, + TEST_SAVED_SCORE_SET_RANGES_ONLY_ZEIBERG_CALIBRATION, TEST_SCORE_SET_RANGES_ALL_SCHEMAS_PRESENT, - TEST_SAVED_SCORE_SET_RANGES_ALL_SCHEMAS_PRESENT, - TEST_GNOMAD_DATA_VERSION, - TEST_SAVED_GNOMAD_VARIANT, + TEST_SCORE_SET_RANGES_ONLY_INVESTIGATOR_PROVIDED, + TEST_SCORE_SET_RANGES_ONLY_SCOTT, + TEST_SCORE_SET_RANGES_ONLY_ZEIBERG_CALIBRATION, + TEST_USER, ) from tests.helpers.dependency_overrider import DependencyOverrider from tests.helpers.util.common import update_expected_response_for_created_resources @@ -65,19 +66,18 @@ from tests.helpers.util.score_set import ( create_seq_score_set, create_seq_score_set_with_mapped_variants, + create_seq_score_set_with_variants, link_clinical_controls_to_mapped_variants, link_gnomad_variants_to_mapped_variants, publish_score_set, - create_seq_score_set_with_variants, ) from tests.helpers.util.user import change_ownership from tests.helpers.util.variant import ( + clear_first_mapped_variant_post_mapped, create_mapped_variants_for_score_set, mock_worker_variant_insertion, - clear_first_mapped_variant_post_mapped, ) - ######################################################################################################################## # Score set schemas ######################################################################################################################## @@ -414,6 +414,114 @@ def test_can_update_score_set_data_before_publication( ("secondary_publication_identifiers", [{"identifier": TEST_PUBMED_IDENTIFIER}], [SAVED_PUBMED_PUBLICATION]), ("doi_identifiers", [{"identifier": TEST_CROSSREF_IDENTIFIER}], [SAVED_DOI_IDENTIFIER]), ("license_id", EXTRA_LICENSE["id"], SAVED_SHORT_EXTRA_LICENSE), + ("target_genes", TEST_MINIMAL_ACC_SCORESET["targetGenes"], 
TEST_MINIMAL_ACC_SCORESET_RESPONSE["targetGenes"]), + ("score_ranges", TEST_SCORE_SET_RANGES_ALL_SCHEMAS_PRESENT, TEST_SAVED_SCORE_SET_RANGES_ALL_SCHEMAS_PRESENT), + ], +) +@pytest.mark.parametrize( + "mock_publication_fetch", + [({"dbName": "PubMed", "identifier": f"{TEST_PUBMED_IDENTIFIER}"})], + indirect=["mock_publication_fetch"], +) +def test_can_patch_score_set_data_before_publication( + client, setup_router_db, attribute, updated_data, expected_response_data, mock_publication_fetch +): + experiment = create_experiment(client) + score_set = create_seq_score_set(client, experiment["urn"]) + expected_response = update_expected_response_for_created_resources( + deepcopy(TEST_MINIMAL_SEQ_SCORESET_RESPONSE), experiment, score_set + ) + expected_response["experiment"].update({"numScoreSets": 1}) + + response = client.get(f"/api/v1/score-sets/{score_set['urn']}") + assert response.status_code == 200 + response_data = response.json() + + assert sorted(expected_response.keys()) == sorted(response_data.keys()) + for key in expected_response: + assert (key, expected_response[key]) == (key, response_data[key]) + + data = {} + if isinstance(updated_data, (dict, list)): + form_value = json.dumps(updated_data) + else: + form_value = str(updated_data) + data[attribute] = form_value + + # The score ranges attribute requires a publication identifier source + if attribute == "score_ranges": + data["secondary_publication_identifiers"] = json.dumps( + [{"identifier": TEST_PUBMED_IDENTIFIER, "dbName": "PubMed"}] + ) + + response = client.patch(f"/api/v1/score-sets-with-variants/{score_set['urn']}", data=data) + assert response.status_code == 200 + + response = client.get(f"/api/v1/score-sets/{score_set['urn']}") + assert response.status_code == 200 + response_data = response.json() + + # Although the client provides the license id, the response includes the full license. 
+ if attribute == "license_id": + attribute = "license" + + assert expected_response_data == response_data[camelize(attribute)] + + +@pytest.mark.parametrize( + "form_field,filename,mime_type", + [ + ("scores_file", "scores.csv", "text/csv"), + ("counts_file", "counts.csv", "text/csv"), + ("score_columns_metadata_file", "score_columns_metadata.json", "application/json"), + ("count_columns_metadata_file", "count_columns_metadata.json", "application/json"), + ], +) +@pytest.mark.parametrize( + "mock_publication_fetch", + [({"dbName": "PubMed", "identifier": f"{TEST_PUBMED_IDENTIFIER}"})], + indirect=["mock_publication_fetch"], +) +def test_can_patch_score_set_data_with_files_before_publication( + client, setup_router_db, form_field, filename, mime_type, data_files, mock_publication_fetch +): + experiment = create_experiment(client) + score_set = create_seq_score_set(client, experiment["urn"]) + expected_response = update_expected_response_for_created_resources( + deepcopy(TEST_MINIMAL_SEQ_SCORESET_RESPONSE), experiment, score_set + ) + expected_response["experiment"].update({"numScoreSets": 1}) + + if form_field == "counts_file" or form_field == "scores_file": + data_file_path = data_files / filename + files = {form_field: (filename, open(data_file_path, "rb"), mime_type)} + with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue: + response = client.patch(f"/api/v1/score-sets-with-variants/{score_set['urn']}", files=files) + worker_queue.assert_called_once() + assert response.status_code == 200 + elif form_field == "score_columns_metadata_file" or form_field == "count_columns_metadata_file": + data_file_path = data_files / filename + with open(data_file_path, "rb") as f: + data = json.load(f) + response = client.patch(f"/api/v1/score-sets-with-variants/{score_set['urn']}", data=data) + assert response.status_code == 200 + + +@pytest.mark.parametrize( + "attribute,updated_data,expected_response_data", + [ + ("title", "Updated Title", "Updated Title"), + ("method_text", "Updated Method Text", "Updated Method Text"), + ("abstract_text", "Updated Abstract Text", "Updated Abstract Text"), + ("short_description", "Updated Abstract Text", "Updated Abstract Text"), + ("extra_metadata", {"updated": "metadata"}, {"updated": "metadata"}), + ("data_usage_policy", "data_usage_policy", "data_usage_policy"), + ("contributors", [{"orcid_id": EXTRA_USER["username"]}], [SAVED_EXTRA_CONTRIBUTOR]), + ("primary_publication_identifiers", [{"identifier": TEST_PUBMED_IDENTIFIER}], [SAVED_PUBMED_PUBLICATION]), + ("secondary_publication_identifiers", [{"identifier": TEST_PUBMED_IDENTIFIER}], [SAVED_PUBMED_PUBLICATION]), + ("doi_identifiers", [{"identifier": TEST_CROSSREF_IDENTIFIER}], [SAVED_DOI_IDENTIFIER]), + ("license_id", EXTRA_LICENSE["id"], SAVED_SHORT_EXTRA_LICENSE), + ("dataset_columns", None, SAVED_MINIMAL_DATASET_COLUMNS), ], ) @pytest.mark.parametrize( @@ -455,7 +563,7 @@ def test_can_update_score_set_supporting_data_after_publication( "publishedDate": date.today().isoformat(), "numVariants": 3, "private": False, - "datasetColumns": {"countColumns": [], "scoreColumns": ["score"]}, + "datasetColumns": SAVED_MINIMAL_DATASET_COLUMNS, "processingState": ProcessingState.success.name, } ) @@ -490,6 +598,7 @@ def test_can_update_score_set_supporting_data_after_publication( TEST_SCORE_SET_RANGES_ALL_SCHEMAS_PRESENT, None, ), + ("dataset_columns", {"countColumns": [], "scoreColumns": ["score"]}, SAVED_MINIMAL_DATASET_COLUMNS), ], ) @pytest.mark.parametrize( @@ -531,7 +640,7 @@ def 
test_cannot_update_score_set_target_data_after_publication( "publishedDate": date.today().isoformat(), "numVariants": 3, "private": False, - "datasetColumns": {"countColumns": [], "scoreColumns": ["score"]}, + "datasetColumns": SAVED_MINIMAL_DATASET_COLUMNS, "processingState": ProcessingState.success.name, } ) @@ -781,6 +890,47 @@ def test_add_score_set_variants_scores_and_counts_endpoint(session, client, setu assert score_set == response_data +def test_add_score_set_variants_scores_counts_and_column_metadata_endpoint( + session, client, setup_router_db, data_files +): + experiment = create_experiment(client) + score_set = create_seq_score_set(client, experiment["urn"]) + scores_csv_path = data_files / "scores.csv" + counts_csv_path = data_files / "counts.csv" + score_columns_metadata_path = data_files / "score_columns_metadata.json" + count_columns_metadata_path = data_files / "count_columns_metadata.json" + with ( + open(scores_csv_path, "rb") as scores_file, + open(counts_csv_path, "rb") as counts_file, + open(score_columns_metadata_path, "rb") as score_columns_metadata_file, + open(count_columns_metadata_path, "rb") as count_columns_metadata_file, + patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, + ): + score_columns_metadata = json.load(score_columns_metadata_file) + count_columns_metadata = json.load(count_columns_metadata_file) + response = client.post( + f"/api/v1/score-sets/{score_set['urn']}/variants/data", + files={ + "scores_file": (scores_csv_path.name, scores_file, "text/csv"), + "counts_file": (counts_csv_path.name, counts_file, "text/csv"), + }, + data={ + "score_columns_metadata": json.dumps(score_columns_metadata), + "count_columns_metadata": json.dumps(count_columns_metadata), + }, + ) + queue.assert_called_once() + + assert response.status_code == 200 + response_data = response.json() + jsonschema.validate(instance=response_data, schema=ScoreSet.model_json_schema()) + + # We test the worker process that actually adds the variant data separately. Here, we take it as + # fact that it would have succeeded. 
+ score_set.update({"processingState": "processing"}) + assert score_set == response_data + + def test_add_score_set_variants_scores_only_endpoint_utf8_encoded(client, setup_router_db, data_files): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) @@ -1086,7 +1236,7 @@ def test_publish_score_set(session, data_provider, client, setup_router_db, data "publishedDate": date.today().isoformat(), "numVariants": 3, "private": False, - "datasetColumns": {"countColumns": [], "scoreColumns": ["score"]}, + "datasetColumns": SAVED_MINIMAL_DATASET_COLUMNS, "processingState": ProcessingState.success.name, } ) @@ -1221,7 +1371,7 @@ def test_contributor_can_publish_other_users_score_set(session, data_provider, c "publishedDate": date.today().isoformat(), "numVariants": 3, "private": False, - "datasetColumns": {"countColumns": [], "scoreColumns": ["score"]}, + "datasetColumns": SAVED_MINIMAL_DATASET_COLUMNS, "processingState": ProcessingState.success.name, } ) @@ -2515,7 +2665,7 @@ def test_score_set_not_found_for_non_existent_score_set_when_adding_score_calibr with DependencyOverrider(admin_app_overrides): response = client.post( - f"/api/v1/score-sets/{score_set['urn']+'xxx'}/ranges/data", + f"/api/v1/score-sets/{score_set['urn'] + 'xxx'}/ranges/data", json=range_payload, ) response_data = response.json() @@ -2539,7 +2689,7 @@ def test_upload_a_non_utf8_file(session, client, setup_router_db, data_files): f"/api/v1/score-sets/{score_set['urn']}/variants/data", files={"scores_file": (scores_csv_path.name, scores_file, "text/csv")}, ) - assert response.status_code == 400 + assert response.status_code == 422 response_data = response.json() assert ( "Error decoding file: 'utf-8' codec can't decode byte 0xdd in position 10: invalid continuation byte. 
" @@ -2712,11 +2862,11 @@ def test_cannot_fetch_clinical_controls_for_nonexistent_score_set( ) link_clinical_controls_to_mapped_variants(session, score_set) - response = client.get(f"/api/v1/score-sets/{score_set['urn']+'xxx'}/clinical-controls") + response = client.get(f"/api/v1/score-sets/{score_set['urn'] + 'xxx'}/clinical-controls") assert response.status_code == 404 response_data = response.json() - assert f"score set with URN '{score_set['urn']+'xxx'}' not found" in response_data["detail"] + assert f"score set with URN '{score_set['urn'] + 'xxx'}' not found" in response_data["detail"] def test_cannot_fetch_clinical_controls_for_score_set_when_none_exist( @@ -2777,11 +2927,11 @@ def test_cannot_get_annotated_variants_for_nonexistent_score_set(client, setup_r experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) - response = client.get(f"/api/v1/score-sets/{score_set['urn']+'xxx'}/annotated-variants/{annotation_type}") + response = client.get(f"/api/v1/score-sets/{score_set['urn'] + 'xxx'}/annotated-variants/{annotation_type}") response_data = response.json() assert response.status_code == 404 - assert f"score set with URN {score_set['urn']+'xxx'} not found" in response_data["detail"] + assert f"score set with URN {score_set['urn'] + 'xxx'} not found" in response_data["detail"] @pytest.mark.parametrize( @@ -3357,11 +3507,11 @@ def test_cannot_fetch_gnomad_variants_for_nonexistent_score_set( ) link_gnomad_variants_to_mapped_variants(session, score_set) - response = client.get(f"/api/v1/score-sets/{score_set['urn']+'xxx'}/gnomad-variants") + response = client.get(f"/api/v1/score-sets/{score_set['urn'] + 'xxx'}/gnomad-variants") assert response.status_code == 404 response_data = response.json() - assert f"score set with URN '{score_set['urn']+'xxx'}' not found" in response_data["detail"] + assert f"score set with URN '{score_set['urn'] + 'xxx'}' not found" in response_data["detail"] def test_cannot_fetch_gnomad_variants_for_score_set_when_none_exist( diff --git a/tests/validation/dataframe/test_dataframe.py b/tests/validation/dataframe/test_dataframe.py index 2becb745..4c8334de 100644 --- a/tests/validation/dataframe/test_dataframe.py +++ b/tests/validation/dataframe/test_dataframe.py @@ -6,10 +6,10 @@ import pytest from mavedb.lib.validation.constants.general import ( + guide_sequence_column, hgvs_nt_column, hgvs_pro_column, hgvs_splice_column, - guide_sequence_column, required_score_column, ) from mavedb.lib.validation.dataframe.dataframe import ( @@ -157,7 +157,12 @@ class TestValidateStandardizeDataFramePair(DfTestCase): def test_no_targets(self): with self.assertRaises(ValueError): validate_and_standardize_dataframe_pair( - self.dataframe, counts_df=None, targets=[], hdp=self.mocked_nt_human_data_provider + self.dataframe, + counts_df=None, + score_columns_metadata=None, + count_columns_metadata=None, + targets=[], + hdp=self.mocked_nt_human_data_provider, ) # TODO: Add additional DataFrames. 
Realistically, if other unit tests pass this function is ok diff --git a/tests/view_models/test_all_fields_optional_model.py b/tests/view_models/test_all_fields_optional_model.py new file mode 100644 index 00000000..2580b95f --- /dev/null +++ b/tests/view_models/test_all_fields_optional_model.py @@ -0,0 +1,186 @@ +from typing import Optional + +import pytest +from pydantic import Field + +from mavedb.view_models.base.base import BaseModel +from mavedb.view_models.utils import all_fields_optional_model + + +# Test models +class DummyModel(BaseModel): + required_string: str = Field(..., description="Required string field") + required_int: int + optional_with_default: str = "default_value" + optional_nullable: Optional[str] = None + field_with_constraints: int = Field(..., ge=0, le=100) + optional_boolean: bool = True + + +def test_all_fields_optional_model_basic(): + """Test that all fields become optional in the decorated model.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + # Should be able to create instance with no arguments + instance = OptionalDummyModel() + + assert instance.required_string is None + assert instance.required_int is None + assert instance.optional_with_default is None # Default overridden to None + assert instance.optional_nullable is None + assert instance.field_with_constraints is None + assert instance.optional_boolean is None + + +def test_all_fields_optional_model_partial_assignment(): + """Test that partial field assignment works correctly.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + instance = OptionalDummyModel(required_string="test", required_int=42) + + assert instance.required_string == "test" + assert instance.required_int == 42 + assert instance.optional_with_default is None + assert instance.optional_nullable is None + assert instance.field_with_constraints is None + assert instance.optional_boolean is None + + +def test_all_fields_optional_model_all_fields_provided(): + """Test that all fields can still be provided.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + instance = OptionalDummyModel( + required_string="test", + required_int=42, + optional_with_default="custom_value", + optional_nullable="not_null", + field_with_constraints=50, + optional_boolean=False, + ) + + assert instance.required_string == "test" + assert instance.required_int == 42 + assert instance.optional_with_default == "custom_value" + assert instance.optional_nullable == "not_null" + assert instance.field_with_constraints == 50 + assert instance.optional_boolean is False + + +def test_all_fields_optional_model_field_info_preserved(): + """Test that field constraints and metadata are preserved.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + # Check that field info is preserved + required_str_field = OptionalDummyModel.model_fields["required_string"] + assert required_str_field.description == "Required string field" + + # Field should now be optional + assert required_str_field.default is None + + +def test_all_fields_optional_model_validation_still_works(): + """Test that field validation still works when values are provided.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + # Should still validate constraints when value is provided + with pytest.raises(ValueError): + OptionalDummyModel(field_with_constraints=150) # Exceeds max value of 100 + + +def 
test_all_fields_optional_model_type_annotations(): + """Test that type annotations are correctly made optional.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + # Get field annotations + fields = OptionalDummyModel.model_fields + + # Check that previously required fields are now Optional + assert fields["required_string"].annotation == Optional[str] + assert fields["required_int"].annotation == Optional[int] + + # Check that already optional fields remain optional + assert fields["optional_nullable"].annotation == Optional[str] + assert fields["optional_boolean"].annotation == Optional[bool] + + +def test_all_fields_optional_model_serialization(): + """Test that the optional model serializes correctly.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + instance = OptionalDummyModel(required_string="test") + serialized = instance.model_dump() + + expected = { + "required_string": "test", + "required_int": None, + "optional_with_default": None, + "optional_nullable": None, + "field_with_constraints": None, + "optional_boolean": None, + } + + assert serialized == expected + + +def test_all_fields_optional_model_exclude_unset(): + """Test that model_dump with exclude_unset works correctly.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + instance = OptionalDummyModel(required_string="test") + serialized = instance.model_dump(exclude_unset=True) + + # Should only include explicitly set fields + assert serialized == {"required_string": "test"} + + +def test_all_fields_optional_model_inheritance(): + """Test that inheritance still works with the decorated model.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + # Should inherit from DummyModel + assert issubclass(OptionalDummyModel, DummyModel) + assert issubclass(OptionalDummyModel, BaseModel) + + +def test_all_fields_optional_model_field_defaults_overridden(): + """Test that original defaults are overridden to None.""" + + @all_fields_optional_model() + class OptionalDummyModel(DummyModel): + pass + + instance = OptionalDummyModel() + + # Originally had default True, should now be None + assert instance.optional_boolean is None + + # Originally had default None, should still be None + assert instance.optional_nullable is None diff --git a/tests/view_models/test_score_set.py b/tests/view_models/test_score_set.py index 5cae54e7..a74e4d79 100644 --- a/tests/view_models/test_score_set.py +++ b/tests/view_models/test_score_set.py @@ -1,22 +1,23 @@ -import pytest from copy import deepcopy -from humps import camelize +import pytest from mavedb.view_models.publication_identifier import PublicationIdentifier, PublicationIdentifierCreate -from mavedb.view_models.score_set import SavedScoreSet, ScoreSetCreate, ScoreSetModify +from mavedb.view_models.score_set import SavedScoreSet, ScoreSetCreate, ScoreSetModify, ScoreSetUpdateAllOptional from mavedb.view_models.target_gene import SavedTargetGene, TargetGeneCreate - from tests.helpers.constants import ( - TEST_PUBMED_IDENTIFIER, + EXTRA_LICENSE, + EXTRA_USER, + SAVED_PUBMED_PUBLICATION, + TEST_BIORXIV_IDENTIFIER, + TEST_CROSSREF_IDENTIFIER, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_SEQ_SCORESET, + TEST_MINIMAL_SEQ_SCORESET_RESPONSE, + TEST_PUBMED_IDENTIFIER, + TEST_SCORE_SET_RANGES_ALL_SCHEMAS_PRESENT, TEST_SCORE_SET_RANGES_ONLY_INVESTIGATOR_PROVIDED, TEST_SCORE_SET_RANGES_ONLY_ZEIBERG_CALIBRATION, - TEST_SCORE_SET_RANGES_ALL_SCHEMAS_PRESENT, - 
SAVED_PUBMED_PUBLICATION, - TEST_BIORXIV_IDENTIFIER, - TEST_MINIMAL_SEQ_SCORESET_RESPONSE, VALID_EXPERIMENT_URN, VALID_SCORE_SET_URN, VALID_TMP_URN, @@ -383,65 +384,27 @@ def test_saved_score_set_synthetic_properties(): ) -def test_saved_score_set_data_set_columns_are_camelized(): - score_set = TEST_MINIMAL_SEQ_SCORESET_RESPONSE.copy() - score_set["urn"] = "urn:score-set-xxx" - - # Remove pre-set synthetic properties - score_set.pop("metaAnalyzesScoreSetUrns") - score_set.pop("metaAnalyzedByScoreSetUrns") - score_set.pop("primaryPublicationIdentifiers") - score_set.pop("secondaryPublicationIdentifiers") - score_set.pop("datasetColumns") - - # Convert fields expecting an object to attributed objects - external_identifiers = {"refseq_offset": None, "ensembl_offset": None, "uniprot_offset": None} - target_genes = [ - dummy_attributed_object_from_dict({**target, **external_identifiers}) for target in score_set["targetGenes"] - ] - score_set["targetGenes"] = [SavedTargetGene.model_validate(target) for target in target_genes] - - # Set synthetic properties with dummy attributed objects to mock SQLAlchemy model objects. - score_set["meta_analyzes_score_sets"] = [ - dummy_attributed_object_from_dict({"urn": "urn:meta-analyzes-xxx", "superseding_score_set": None}) - ] - score_set["meta_analyzed_by_score_sets"] = [ - dummy_attributed_object_from_dict({"urn": "urn:meta-analyzed-xxx", "superseding_score_set": None}) - ] - score_set["publication_identifier_associations"] = [ - dummy_attributed_object_from_dict( - { - "publication": PublicationIdentifier(**SAVED_PUBMED_PUBLICATION), - "primary": True, - } - ), - dummy_attributed_object_from_dict( - { - "publication": PublicationIdentifier( - **{**SAVED_PUBMED_PUBLICATION, **{"identifier": TEST_BIORXIV_IDENTIFIER}} - ), - "primary": False, - } - ), - dummy_attributed_object_from_dict( - { - "publication": PublicationIdentifier( - **{**SAVED_PUBMED_PUBLICATION, **{"identifier": TEST_BIORXIV_IDENTIFIER}} - ), - "primary": False, - } - ), - ] - - # The camelized dataset columns we are testing - score_set["dataset_columns"] = {"camelize_me": "test", "noNeed": "test"} - - score_set_attributed_object = dummy_attributed_object_from_dict(score_set) - saved_score_set = SavedScoreSet.model_validate(score_set_attributed_object) - - assert sorted(list(saved_score_set.dataset_columns.keys())) == sorted( - [camelize(k) for k in score_set["dataset_columns"].keys()] - ) +@pytest.mark.parametrize( + "attribute,updated_data", + [ + ("title", "Updated Title"), + ("method_text", "Updated Method Text"), + ("abstract_text", "Updated Abstract Text"), + ("short_description", "Updated Abstract Text"), + ("title", "Updated Title"), + ("extra_metadata", {"updated": "metadata"}), + ("data_usage_policy", "data_usage_policy"), + ("contributors", [{"orcid_id": EXTRA_USER["username"]}]), + ("primary_publication_identifiers", [{"identifier": TEST_PUBMED_IDENTIFIER}]), + ("secondary_publication_identifiers", [{"identifier": TEST_PUBMED_IDENTIFIER}]), + ("doi_identifiers", [{"identifier": TEST_CROSSREF_IDENTIFIER}]), + ("license_id", EXTRA_LICENSE["id"]), + ("target_genes", TEST_MINIMAL_SEQ_SCORESET["targetGenes"]), + ("score_ranges", TEST_SCORE_SET_RANGES_ALL_SCHEMAS_PRESENT), + ], +) +def test_score_set_update_all_optional(attribute, updated_data): + ScoreSetUpdateAllOptional(**{attribute: updated_data}) @pytest.mark.parametrize( diff --git a/tests/view_models/test_score_set_dataset_columns.py b/tests/view_models/test_score_set_dataset_columns.py new file mode 100644 index 
00000000..a5b304e7 --- /dev/null +++ b/tests/view_models/test_score_set_dataset_columns.py @@ -0,0 +1,18 @@ +from mavedb.view_models.score_set_dataset_columns import DatasetColumnMetadata, SavedDatasetColumns +from tests.helpers.constants import TEST_SCORE_SET_DATASET_COLUMNS + + +def test_score_set_dataset_columns(): + score_set_dataset_columns = TEST_SCORE_SET_DATASET_COLUMNS.copy() + + for k, v in score_set_dataset_columns["score_columns_metadata"].items(): + score_set_dataset_columns["score_columns_metadata"][k] = DatasetColumnMetadata.model_validate(v) + for k, v in score_set_dataset_columns["count_columns_metadata"].items(): + score_set_dataset_columns["count_columns_metadata"][k] = DatasetColumnMetadata.model_validate(v) + + saved_score_set_dataset_columns = SavedDatasetColumns.model_validate(score_set_dataset_columns) + + assert saved_score_set_dataset_columns.score_columns_metadata == score_set_dataset_columns["score_columns_metadata"] + assert saved_score_set_dataset_columns.count_columns_metadata == score_set_dataset_columns["count_columns_metadata"] + assert saved_score_set_dataset_columns.score_columns == score_set_dataset_columns["score_columns"] + assert saved_score_set_dataset_columns.count_columns == score_set_dataset_columns["count_columns"] diff --git a/tests/worker/data/count_columns_metadata.json b/tests/worker/data/count_columns_metadata.json new file mode 100644 index 00000000..9aaaa355 --- /dev/null +++ b/tests/worker/data/count_columns_metadata.json @@ -0,0 +1,10 @@ +{ + "c_0": { + "description": "c_0 description", + "details": "c_0 details" + }, + "c_1": { + "description": "c_1 description", + "details": "c_1 details" + } +} diff --git a/tests/worker/data/score_columns_metadata.json b/tests/worker/data/score_columns_metadata.json new file mode 100644 index 00000000..a21bc31e --- /dev/null +++ b/tests/worker/data/score_columns_metadata.json @@ -0,0 +1,10 @@ +{ + "s_0": { + "description": "s_0 description", + "details": "s_0 details" + }, + "s_1": { + "description": "s_0 description", + "details": "s_0 details" + } +} diff --git a/tests/worker/data/scores.csv b/tests/worker/data/scores.csv index f23cafcb..11fce498 100644 --- a/tests/worker/data/scores.csv +++ b/tests/worker/data/scores.csv @@ -1,4 +1,4 @@ -hgvs_nt,hgvs_pro,score -c.1A>T,p.Thr1Ser,0.3 -c.2C>T,p.Thr1Met,0.0 -c.6T>A,p.Phe2Leu,-1.65 +hgvs_nt,hgvs_pro,score,s_0,s_1 +c.1A>T,p.Thr1Ser,0.3,val1,val1 +c.2C>T,p.Thr1Met,0.0,val2,val2 +c.6T>A,p.Phe2Leu,-1.65,val3,val3 diff --git a/tests/worker/data/scores_acc.csv b/tests/worker/data/scores_acc.csv index 30b0d836..1440bc8c 100644 --- a/tests/worker/data/scores_acc.csv +++ b/tests/worker/data/scores_acc.csv @@ -1,4 +1,4 @@ -hgvs_nt,score -NM_001637.3:c.1G>C,0.3 -NM_001637.3:c.2A>G,0.0 -NM_001637.3:c.6C>A,-1.65 +hgvs_nt,score,s_0,s_1 +NM_001637.3:c.1G>C,0.3,val1,val1 +NM_001637.3:c.2A>G,0.0,val2,val2 +NM_001637.3:c.6C>A,-1.65,val3,val3 diff --git a/tests/worker/data/scores_multi_target.csv b/tests/worker/data/scores_multi_target.csv index 11dcc55f..903b8cbc 100644 --- a/tests/worker/data/scores_multi_target.csv +++ b/tests/worker/data/scores_multi_target.csv @@ -1,4 +1,4 @@ -hgvs_nt,score -TEST3:n.1A>T,0.3 -TEST3:n.6T>A,-1.65 -TEST4:n.2A>T,0.1 +hgvs_nt,score,s_0,s_1 +TEST3:n.1A>T,0.3,val1,val1 +TEST3:n.6T>A,-1.65,val2,val2 +TEST4:n.2A>T,0.1,val3,val3 diff --git a/tests/worker/test_jobs.py b/tests/worker/test_jobs.py index 3f42fc8d..e7fd0b39 100644 --- a/tests/worker/test_jobs.py +++ b/tests/worker/test_jobs.py @@ -1,5 +1,6 @@ # ruff: noqa: E402 +import json from 
asyncio.unix_events import _UnixSelectorEventLoop from copy import deepcopy from datetime import date @@ -18,13 +19,13 @@ pyathena = pytest.importorskip("pyathena") from mavedb.data_providers.services import VRSMap -from mavedb.lib.mave.constants import HGVS_NT_COLUMN -from mavedb.lib.score_sets import csv_data_to_df from mavedb.lib.clingen.services import ( ClinGenAlleleRegistryService, ClinGenLdhService, clingen_allele_id_from_ldh_variation, ) +from mavedb.lib.mave.constants import HGVS_NT_COLUMN +from mavedb.lib.score_sets import csv_data_to_df from mavedb.lib.uniprot.id_mapping import UniProtIDMappingAPI from mavedb.lib.validation.exceptions import ValidationError from mavedb.models.enums.mapping_state import MappingState @@ -39,39 +40,39 @@ MAPPING_CURRENT_ID_NAME, MAPPING_QUEUE_NAME, create_variants_for_score_set, + link_clingen_variants, link_gnomad_variants, map_variants_for_score_set, - variant_mapper_manager, - submit_score_set_mappings_to_ldh, - link_clingen_variants, + poll_uniprot_mapping_jobs_for_score_set, submit_score_set_mappings_to_car, + submit_score_set_mappings_to_ldh, submit_uniprot_mapping_jobs_for_score_set, - poll_uniprot_mapping_jobs_for_score_set, + variant_mapper_manager, ) from tests.helpers.constants import ( TEST_ACC_SCORESET_VARIANT_MAPPING_SCAFFOLD, TEST_CLINGEN_ALLELE_OBJECT, - TEST_CLINGEN_SUBMISSION_RESPONSE, + TEST_CLINGEN_LDH_LINKING_RESPONSE, TEST_CLINGEN_SUBMISSION_BAD_RESQUEST_RESPONSE, + TEST_CLINGEN_SUBMISSION_RESPONSE, TEST_CLINGEN_SUBMISSION_UNAUTHORIZED_RESPONSE, - TEST_CLINGEN_LDH_LINKING_RESPONSE, TEST_GNOMAD_DATA_VERSION, - TEST_NT_CDOT_TRANSCRIPT, TEST_MINIMAL_ACC_SCORESET, TEST_MINIMAL_EXPERIMENT, TEST_MINIMAL_MULTI_TARGET_SCORESET, TEST_MINIMAL_SEQ_SCORESET, TEST_MULTI_TARGET_SCORESET_VARIANT_MAPPING_SCAFFOLD, + TEST_NT_CDOT_TRANSCRIPT, TEST_SEQ_SCORESET_VARIANT_MAPPING_SCAFFOLD, - VALID_NT_ACCESSION, - TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, - TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, + TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, TEST_UNIPROT_JOB_SUBMISSION_RESPONSE, TEST_UNIPROT_SWISS_PROT_TYPE, - TEST_UNIPROT_ID_MAPPING_SWISS_PROT_RESPONSE, - VALID_UNIPROT_ACCESSION, + TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, + TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, VALID_CHR_ACCESSION, VALID_CLINGEN_CA_ID, + VALID_NT_ACCESSION, + VALID_UNIPROT_ACCESSION, ) from tests.helpers.util.exceptions import awaitable_exception from tests.helpers.util.experiment import create_experiment @@ -118,15 +119,21 @@ async def setup_records_and_files(async_client, data_files, input_score_set): with ( open(data_files / scores_fp, "rb") as score_file, open(data_files / counts_fp, "rb") as count_file, + open(data_files / "score_columns_metadata.json", "rb") as score_columns_file, + open(data_files / "count_columns_metadata.json", "rb") as count_columns_file, ): scores = csv_data_to_df(score_file) counts = csv_data_to_df(count_file) + score_columns_metadata = json.load(score_columns_file) + count_columns_metadata = json.load(count_columns_file) - return score_set["urn"], scores, counts + return score_set["urn"], scores, counts, score_columns_metadata, count_columns_metadata async def setup_records_files_and_variants(session, async_client, data_files, input_score_set, worker_ctx): - score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set) + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) score_set = 
session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() # Patch CDOT `_get_transcript`, in the event this function is called on an accesssion based scoreset. @@ -135,7 +142,9 @@ async def setup_records_files_and_variants(session, async_client, data_files, in "_get_transcript", return_value=TEST_NT_CDOT_TRANSCRIPT, ): - result = await create_variants_for_score_set(worker_ctx, uuid4().hex, score_set.id, 1, scores, counts) + result = await create_variants_for_score_set( + worker_ctx, uuid4().hex, score_set.id, 1, scores, counts, score_columns_metadata, count_columns_metadata + ) score_set_with_variants = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() @@ -248,7 +257,9 @@ async def test_create_variants_for_score_set_with_validation_error( session, data_files, ): - score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set) + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() if input_score_set == TEST_MINIMAL_SEQ_SCORESET: @@ -266,7 +277,14 @@ async def test_create_variants_for_score_set_with_validation_error( ) as hdp, ): result = await create_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, ) # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. @@ -298,7 +316,9 @@ async def test_create_variants_for_score_set_with_caught_exception( session, data_files, ): - score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set) + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee @@ -307,7 +327,14 @@ async def test_create_variants_for_score_set_with_caught_exception( patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc, ): result = await create_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, ) mocked_exc.assert_called() @@ -334,7 +361,9 @@ async def test_create_variants_for_score_set_with_caught_base_exception( session, data_files, ): - score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set) + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() # This is somewhat (extra) dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee @@ -343,7 +372,14 @@ async def test_create_variants_for_score_set_with_caught_base_exception( patch.object(pd.DataFrame, "isnull", side_effect=BaseException), ): result 
= await create_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, ) db_variants = session.scalars(select(Variant)).all() @@ -369,7 +405,9 @@ async def test_create_variants_for_score_set_with_existing_variants( session, data_files, ): - score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set) + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() with patch.object( @@ -378,7 +416,14 @@ async def test_create_variants_for_score_set_with_existing_variants( return_value=TEST_NT_CDOT_TRANSCRIPT, ) as hdp: result = await create_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, ) # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. @@ -401,7 +446,14 @@ async def test_create_variants_for_score_set_with_existing_variants( return_value=TEST_NT_CDOT_TRANSCRIPT, ) as hdp: result = await create_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, ) db_variants = session.scalars(select(Variant)).all() @@ -427,7 +479,9 @@ async def test_create_variants_for_score_set_with_existing_exceptions( session, data_files, ): - score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set) + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee @@ -440,7 +494,14 @@ async def test_create_variants_for_score_set_with_existing_exceptions( ) as mocked_exc, ): result = await create_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, ) mocked_exc.assert_called() @@ -461,7 +522,14 @@ async def test_create_variants_for_score_set_with_existing_exceptions( return_value=TEST_NT_CDOT_TRANSCRIPT, ) as hdp: result = await create_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, ) # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. 
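
For orientation: the score_columns_metadata and count_columns_metadata values threaded through these worker tests are plain dictionaries keyed by column name, loaded from the JSON fixtures added earlier in this diff. A small sketch of the expected shape, validated with the new DatasetColumnMetadata view model; the column names are illustrative:

from mavedb.view_models.score_set_dataset_columns import DatasetColumnMetadata

# Mapping of column name -> metadata: description is required, details is optional
# (mirrors tests/worker/data/score_columns_metadata.json).
score_columns_metadata = {
    "s_0": {"description": "s_0 description", "details": "s_0 details"},
    "s_1": {"description": "s_1 description"},
}

# Each entry validates against the DatasetColumnMetadata view model.
for metadata in score_columns_metadata.values():
    DatasetColumnMetadata.model_validate(metadata)
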
@@ -493,7 +561,9 @@ async def test_create_variants_for_score_set( session, data_files, ): - score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set) + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() with patch.object( @@ -502,7 +572,14 @@ async def test_create_variants_for_score_set( return_value=TEST_NT_CDOT_TRANSCRIPT, ) as hdp: result = await create_variants_for_score_set( - standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts + standalone_worker_context, + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, ) # Call data provider _get_transcript method if this is an accession based score set, otherwise do not. @@ -536,7 +613,9 @@ async def test_create_variants_for_score_set_enqueues_manager_and_successful_map ): score_set_is_seq = all(["targetSequence" in target for target in input_score_set["targetGenes"]]) score_set_is_multi_target = len(input_score_set["targetGenes"]) > 1 - score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set) + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() async def dummy_mapping_job(): @@ -573,7 +652,16 @@ async def dummy_linking_job(): patch("mavedb.worker.jobs.LINKING_BACKOFF_IN_SECONDS", 0), patch("mavedb.worker.jobs.CLIN_GEN_SUBMISSION_ENABLED", True), ): - await arq_redis.enqueue_job("create_variants_for_score_set", uuid4().hex, score_set.id, 1, scores, counts) + await arq_redis.enqueue_job( + "create_variants_for_score_set", + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) await arq_worker.async_run() await arq_worker.run_check() @@ -612,11 +700,22 @@ async def test_create_variants_for_score_set_exception_skips_mapping( arq_worker, arq_redis, ): - score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set) + score_set_urn, scores, counts, score_columns_metadata, count_columns_metadata = await setup_records_and_files( + async_client, data_files, input_score_set + ) score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one() with patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc: - await arq_redis.enqueue_job("create_variants_for_score_set", uuid4().hex, score_set.id, 1, scores, counts) + await arq_redis.enqueue_job( + "create_variants_for_score_set", + uuid4().hex, + score_set.id, + 1, + scores, + counts, + score_columns_metadata, + count_columns_metadata, + ) await arq_worker.async_run() await arq_worker.run_check()
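
A closing sketch of the key-consistency rule enforced by the new dataset-columns view models: metadata may only describe columns that are actually declared. Column names are illustrative, and the exact exception type surfaced to callers depends on how Pydantic wraps the library's ValidationError, so the example catches broadly:

from mavedb.view_models.score_set_dataset_columns import DatasetColumnsCreate

# Valid: every metadata key names a declared score column.
columns = DatasetColumnsCreate(
    score_columns=["score", "s_0"],
    score_columns_metadata={"s_0": {"description": "s_0 description"}},
)
assert columns.score_columns == ["score", "s_0"]

# Invalid: metadata for a column that is not listed in score_columns is rejected with
# "Score column metadata key 's_0' does not exist in score_columns list."
try:
    DatasetColumnsCreate(
        score_columns=["score"],
        score_columns_metadata={"s_0": {"description": "s_0 description"}},
    )
except Exception:
    pass
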