diff --git a/src/mavedb/lib/score_sets.py b/src/mavedb/lib/score_sets.py
index a8a42736..f1234de5 100644
--- a/src/mavedb/lib/score_sets.py
+++ b/src/mavedb/lib/score_sets.py
@@ -4,7 +4,7 @@
 import logging
 from operator import attrgetter
 import re
-from typing import Any, BinaryIO, Iterable, Optional, TYPE_CHECKING, Sequence, Literal
+from typing import Any, BinaryIO, Iterable, List, Optional, TYPE_CHECKING, Sequence, Literal
 
 from mavedb.models.mapped_variant import MappedVariant
 import numpy as np
@@ -501,12 +501,13 @@ def find_publish_or_private_superseded_score_set_tail(
 def get_score_set_variants_as_csv(
     db: Session,
     score_set: ScoreSet,
-    data_type: Literal["scores", "counts"],
+    namespaces: List[Literal["scores", "counts"]],
+    namespaced: Optional[bool] = None,
     start: Optional[int] = None,
     limit: Optional[int] = None,
     drop_na_columns: Optional[bool] = None,
-    include_custom_columns: bool = True,
-    include_post_mapped_hgvs: bool = False,
+    include_custom_columns: Optional[bool] = True,
+    include_post_mapped_hgvs: Optional[bool] = False,
 ) -> str:
     """
     Get the variant data from a score set as a CSV string.
@@ -517,8 +518,10 @@
         The database session to use.
     score_set : ScoreSet
         The score set to get the variants from.
-    data_type : {'scores', 'counts'}
-        The type of data to get. Either 'scores' or 'counts'.
+    namespaces : List[Literal["scores", "counts"]]
+        The data namespaces to include. Currently only 'scores' and 'counts' exist; ClinVar and gnomAD may be added.
+    namespaced : bool, optional
+        Whether to prefix each column name with its namespace.
     start : int, optional
         The index to start from. If None, starts from the beginning.
     limit : int, optional
@@ -537,20 +540,26 @@
         The CSV string containing the variant data.
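+
+    Examples
+    --------
+    A hypothetical sketch; ``db`` and ``score_set`` are assumed to already exist. Export both namespaces with
+    namespaced column headers::
+
+        csv_str = get_score_set_variants_as_csv(
+            db, score_set, ["scores", "counts"], namespaced=True
+        )
+        # Header row, e.g.: accession,hgvs_nt,hgvs_splice,hgvs_pro,scores.score,...,counts.c_0,...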
""" assert type(score_set.dataset_columns) is dict - custom_columns_set = "score_columns" if data_type == "scores" else "count_columns" - type_column = "score_data" if data_type == "scores" else "count_data" - - columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"] + namespaced_score_set_columns: dict[str, list[str]] = { + "core": ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"], + "mavedb": [], + } if include_post_mapped_hgvs: - columns.append("post_mapped_hgvs_g") - columns.append("post_mapped_hgvs_p") - + namespaced_score_set_columns["mavedb"].append("post_mapped_hgvs_g") + namespaced_score_set_columns["mavedb"].append("post_mapped_hgvs_p") + for namespace in namespaces: + namespaced_score_set_columns[namespace] = [] if include_custom_columns: - custom_columns = [str(x) for x in list(score_set.dataset_columns.get(custom_columns_set, []))] - columns += custom_columns - elif data_type == "scores": - columns.append(REQUIRED_SCORE_COLUMN) - + if "scores" in namespaced_score_set_columns: + namespaced_score_set_columns["scores"] = [ + col for col in [str(x) for x in list(score_set.dataset_columns.get("score_columns", []))] + ] + if "counts" in namespaced_score_set_columns: + namespaced_score_set_columns["counts"] = [ + col for col in [str(x) for x in list(score_set.dataset_columns.get("count_columns", []))] + ] + elif "scores" in namespaced_score_set_columns: + namespaced_score_set_columns["scores"].append(REQUIRED_SCORE_COLUMN) variants: Sequence[Variant] = [] mappings: Optional[list[Optional[MappedVariant]]] = None @@ -587,13 +596,22 @@ def get_score_set_variants_as_csv( if limit: variants_query = variants_query.limit(limit) variants = db.scalars(variants_query).all() + rows_data = variants_to_csv_rows(variants, columns=namespaced_score_set_columns, namespaced=namespaced, mappings=mappings) # type: ignore + rows_columns = [ + ( + f"{namespace}.{col}" + if (namespaced and namespace not in ["core", "mavedb"]) + else (f"mavedb.{col}" if namespaced and namespace == "mavedb" else col) + ) + for namespace, cols in namespaced_score_set_columns.items() + for col in cols + ] - rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column, mappings=mappings) # type: ignore if drop_na_columns: - rows_data, columns = drop_na_columns_from_csv_file_rows(rows_data, columns) + rows_data, rows_columns = drop_na_columns_from_csv_file_rows(rows_data, rows_columns) stream = io.StringIO() - writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL) + writer = csv.DictWriter(stream, fieldnames=rows_columns, quoting=csv.QUOTE_MINIMAL) writer.writeheader() writer.writerows(rows_data) return stream.getvalue() @@ -631,9 +649,9 @@ def is_null(value): def variant_to_csv_row( variant: Variant, - columns: list[str], - dtype: str, + columns: dict[str, list[str]], mapping: Optional[MappedVariant] = None, + namespaced: Optional[bool] = None, na_rep="NA", ) -> dict[str, Any]: """ @@ -645,8 +663,8 @@ def variant_to_csv_row( List of variants. columns : list[str] Columns to serialize. - dtype : str, {'scores', 'counts'} - The type of data requested. Either the 'score_data' or 'count_data'. + namespaced: Optional[bool] = None + Namespace the columns or not. na_rep : str String to represent null values. @@ -654,8 +672,9 @@ def variant_to_csv_row( ------- dict[str, Any] """ - row = {} - for column_key in columns: + row: dict[str, Any] = {} + # Handle each column key explicitly as part of its namespace. 
+    for column_key in columns.get("core", []):
         if column_key == "hgvs_nt":
             value = str(variant.hgvs_nt)
         elif column_key == "hgvs_pro":
             value = str(variant.hgvs_pro)
         elif column_key == "hgvs_splice":
             value = str(variant.hgvs_splice)
         elif column_key == "accession":
             value = str(variant.urn)
-        elif column_key == "post_mapped_hgvs_g":
+        else:
+            # Unknown core keys should never occur; guard against reusing a stale value.
+            value = na_rep
+        if is_null(value):
+            value = na_rep
+
+        # Export columns in the `core` namespace without a namespace prefix.
+        row[column_key] = value
+
+    for column_key in columns.get("mavedb", []):
+        if column_key == "post_mapped_hgvs_g":
             hgvs_str = get_hgvs_from_post_mapped(mapping.post_mapped) if mapping and mapping.post_mapped else None
             if hgvs_str is not None and is_hgvs_g(hgvs_str):
                 value = hgvs_str
@@ -676,21 +701,28 @@
                 value = hgvs_str
             else:
                 value = ""
-        else:
-            parent = variant.data.get(dtype) if variant.data else None
-            value = str(parent.get(column_key)) if parent else na_rep
 
         if is_null(value):
             value = na_rep
 
-        row[column_key] = value
-
+        key = f"mavedb.{column_key}" if namespaced else column_key
+        row[key] = value
+
+    for column_key in columns.get("scores", []):
+        parent = variant.data.get("score_data") if variant.data else None
+        value = str(parent.get(column_key)) if parent else na_rep
+        # Normalize null-ish values to na_rep, matching the other namespaces.
+        if is_null(value):
+            value = na_rep
+        key = f"scores.{column_key}" if namespaced else column_key
+        row[key] = value
+
+    for column_key in columns.get("counts", []):
+        parent = variant.data.get("count_data") if variant.data else None
+        value = str(parent.get(column_key)) if parent else na_rep
+        if is_null(value):
+            value = na_rep
+        key = f"counts.{column_key}" if namespaced else column_key
+        row[key] = value
+
     return row
 
 
 def variants_to_csv_rows(
     variants: Sequence[Variant],
-    columns: list[str],
-    dtype: str,
+    columns: dict[str, list[str]],
     mappings: Optional[Sequence[Optional[MappedVariant]]] = None,
+    namespaced: Optional[bool] = None,
     na_rep="NA",
 ) -> Iterable[dict[str, Any]]:
     """
@@ -702,8 +734,8 @@
         List of variants.
-    columns : list[str]
-        Columns to serialize.
-    dtype : str, {'scores', 'counts'}
-        The type of data requested. Either the 'score_data' or 'count_data'.
+    columns : dict[str, list[str]]
+        Mapping from namespace ('core', 'mavedb', 'scores', 'counts') to the column names to serialize.
+    namespaced : bool, optional
+        Whether to prefix each column name with its namespace.
     na_rep : str
         String to represent null values.
@@ -713,10 +745,10 @@
     """
     if mappings is not None:
         return map(
-            lambda pair: variant_to_csv_row(pair[0], columns, dtype, mapping=pair[1], na_rep=na_rep),
+            lambda pair: variant_to_csv_row(pair[0], columns, mapping=pair[1], namespaced=namespaced, na_rep=na_rep),
             zip(variants, mappings),
         )
-    return map(lambda v: variant_to_csv_row(v, columns, dtype, na_rep=na_rep), variants)
+    return map(lambda v: variant_to_csv_row(v, columns, namespaced=namespaced, na_rep=na_rep), variants)
 
 
 def find_meta_analyses_for_score_sets(db: Session, urns: list[str]) -> list[ScoreSet]:
diff --git a/src/mavedb/routers/score_sets.py b/src/mavedb/routers/score_sets.py
index 2e1aa87b..6ee49235 100644
--- a/src/mavedb/routers/score_sets.py
+++ b/src/mavedb/routers/score_sets.py
@@ -1,7 +1,7 @@
 import json
 import logging
 from datetime import date
-from typing import Any, List, Optional, Sequence, TypedDict, Union
+from typing import Any, List, Literal, Optional, Sequence, TypedDict, Union
 
 import pandas as pd
 from arq import ArqRedis
@@ -110,20 +110,28 @@ async def enqueue_variant_creation(
         "hgvs_splice",
         "hgvs_pro",
     ] + item.dataset_columns.get("score_columns", [])
+    # Rebuild as a namespaced mapping, since variants_to_csv_rows now expects a dict of namespaces.
+    score_columns = {
+        "core": ["hgvs_nt", "hgvs_splice", "hgvs_pro"],
+        "scores": item.dataset_columns.get("score_columns", []),
+    }
     existing_scores_df = pd.DataFrame(
-        variants_to_csv_rows(item.variants, columns=score_columns, dtype="score_data")
+        variants_to_csv_rows(item.variants, columns=score_columns, namespaced=False)
     ).replace("NA", pd.NA)
 
     # create CSV from existing variants on the score set if no new dataframe provided
     existing_counts_df = None
     if new_counts_df is None and item.dataset_columns.get("count_columns"):
-        count_columns = [
-            "hgvs_nt",
-            "hgvs_splice",
-            "hgvs_pro",
-        ] + item.dataset_columns["count_columns"]
+        # Namespaced mapping for the new variants_to_csv_rows signature.
+        count_columns = {
+            "core": ["hgvs_nt", "hgvs_splice", "hgvs_pro"],
+            "counts": item.dataset_columns["count_columns"],
+        }
         existing_counts_df = pd.DataFrame(
-            variants_to_csv_rows(item.variants, columns=count_columns, dtype="count_data")
+            variants_to_csv_rows(item.variants, columns=count_columns, namespaced=False)
        ).replace("NA", pd.NA)
 
     # Await the insertion of this job into the worker queue, not the job itself.
@@ -638,7 +646,13 @@ def get_score_set_variants_csv(
     urn: str,
     start: int = Query(default=None, description="Start index for pagination"),
     limit: int = Query(default=None, description="Maximum number of variants to return"),
+    namespaces: List[Literal["scores", "counts"]] = Query(
+        default=["scores"],
+        description="One or more data namespaces to include. Currently 'scores' and 'counts'; ClinVar and gnomAD may be added later.",
+    ),
     drop_na_columns: Optional[bool] = None,
+    include_custom_columns: Optional[bool] = None,
+    include_post_mapped_hgvs: Optional[bool] = None,
     db: Session = Depends(deps.get_db),
     user_data: Optional[UserData] = Depends(get_current_user),
 ) -> Any:
     """
@@ -648,12 +662,9 @@
     This differs from get_score_set_scores_csv() in that it returns only the HGVS columns, score column, and mapped
     HGVS string.
 
-    TODO (https://github.com/VariantEffect/mavedb-api/issues/446) We may want to turn this into a general-purpose CSV
-    export endpoint, with options governing which columns to include.
+    TODO (https://github.com/VariantEffect/mavedb-api/issues/446) We may add a separate function for ClinVar and gnomAD
+    export, with options governing which columns to include.
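+
+    For example, a request combining both namespaces (the URN is hypothetical) might look like::
+
+        GET /api/v1/score-sets/urn:mavedb:00000001-a-1/variants/data?namespaces=scores&namespaces=counts&include_custom_columns=true&drop_na_columns=true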
-    Parameters
-    __________
-
     Parameters
-    __________
+    ----------
     urn : str
         The URN of the score set to fetch variants from.
@@ -662,6 +673,9 @@
         The index to start from. If None, starts from the beginning.
     limit : Optional[int]
         The maximum number of variants to return. If None, returns all variants.
+    namespaces : List[Literal["scores", "counts"]]
+        The namespaces whose columns should be included, alongside the core columns (accession, hgvs_nt, hgvs_pro,
+        and hgvs_splice). ClinVar and gnomAD may be added in the future.
     drop_na_columns : bool, optional
         Whether to drop columns that contain only NA values. Defaults to False.
     db : Session
@@ -701,12 +715,13 @@
     csv_str = get_score_set_variants_as_csv(
         db,
         score_set,
-        "scores",
+        namespaces,
+        True,  # namespaced
         start,
         limit,
         drop_na_columns,
-        include_custom_columns=False,
-        include_post_mapped_hgvs=True,
+        include_custom_columns,
+        include_post_mapped_hgvs,
     )
 
     return StreamingResponse(iter([csv_str]), media_type="text/csv")
 
@@ -762,7 +777,7 @@ def get_score_set_scores_csv(
 
     assert_permission(user_data, score_set, Action.READ)
 
-    csv_str = get_score_set_variants_as_csv(db, score_set, "scores", start, limit, drop_na_columns)
+    csv_str = get_score_set_variants_as_csv(db, score_set, ["scores"], False, start, limit, drop_na_columns)
 
     return StreamingResponse(iter([csv_str]), media_type="text/csv")
 
@@ -817,7 +832,7 @@ async def get_score_set_counts_csv(
 
     assert_permission(user_data, score_set, Action.READ)
 
-    csv_str = get_score_set_variants_as_csv(db, score_set, "counts", start, limit, drop_na_columns)
+    csv_str = get_score_set_variants_as_csv(db, score_set, ["counts"], False, start, limit, drop_na_columns)
 
     return StreamingResponse(iter([csv_str]), media_type="text/csv")
 
diff --git a/src/mavedb/scripts/export_public_data.py b/src/mavedb/scripts/export_public_data.py
index 9d7d8e7f..2172878d 100644
--- a/src/mavedb/scripts/export_public_data.py
+++ b/src/mavedb/scripts/export_public_data.py
@@ -147,12 +147,12 @@ def export_public_data(db: Session):
         logger.info(f"{i + 1}/{num_score_sets} Exporting variants for score set {score_set.urn}")
 
         csv_filename_base = score_set.urn.replace(":", "-")
-        csv_str = get_score_set_variants_as_csv(db, score_set, "scores")
+        csv_str = get_score_set_variants_as_csv(db, score_set, ["scores"])
         zipfile.writestr(f"csv/{csv_filename_base}.scores.csv", csv_str)
 
         count_columns = score_set.dataset_columns["count_columns"] if score_set.dataset_columns else None
         if count_columns and len(count_columns) > 0:
-            csv_str = get_score_set_variants_as_csv(db, score_set, "counts")
+            csv_str = get_score_set_variants_as_csv(db, score_set, ["counts"])
             zipfile.writestr(f"csv/{csv_filename_base}.counts.csv", csv_str)
 
diff --git a/tests/routers/test_score_set.py b/tests/routers/test_score_set.py
index 47e41f80..d1f960e2 100644
--- a/tests/routers/test_score_set.py
+++ b/tests/routers/test_score_set.py
@@ -2725,7 +2725,7 @@ def test_download_variants_data_file(
         worker_queue.assert_called_once()
 
     download_scores_csv_response = client.get(
-        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?drop_na_columns=true"
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?drop_na_columns=true&include_post_mapped_hgvs=true"
     )
     assert download_scores_csv_response.status_code == 200
     download_scores_csv = download_scores_csv_response.text
@@ -2736,21 +2736,21 @@
             "accession",
             "hgvs_nt",
             "hgvs_pro",
-            "post_mapped_hgvs_g",
-            "post_mapped_hgvs_p",
-            "score",
+            "mavedb.post_mapped_hgvs_g",
+            "mavedb.post_mapped_hgvs_p",
+            "scores.score",
         ]
     )
     rows = list(reader)
     for row in rows:
         if has_hgvs_g:
-            assert row["post_mapped_hgvs_g"] == mapped_variant["post_mapped"]["expressions"][0]["value"]
+            assert row["mavedb.post_mapped_hgvs_g"] == mapped_variant["post_mapped"]["expressions"][0]["value"]
         else:
-            assert row["post_mapped_hgvs_g"] == "NA"
+            assert row["mavedb.post_mapped_hgvs_g"] == "NA"
         if has_hgvs_p:
-            assert row["post_mapped_hgvs_p"] == mapped_variant["post_mapped"]["expressions"][0]["value"]
+            assert row["mavedb.post_mapped_hgvs_p"] == mapped_variant["post_mapped"]["expressions"][0]["value"]
         else:
-            assert row["post_mapped_hgvs_p"] == "NA"
+            assert row["mavedb.post_mapped_hgvs_p"] == "NA"
 
     # Test file doesn't have hgvs_splice so its values are all NA.
@@ -2797,6 +2797,130 @@ def test_download_counts_file(session, data_provider, client, setup_router_db, d
     assert "hgvs_splice" not in columns
 
 
+# Namespaced variant CSV export tests.
+def test_download_scores_file_in_variant_data_path(session, data_provider, client, setup_router_db, data_files):
+    experiment = create_experiment(client)
+    score_set = create_seq_score_set(client, experiment["urn"])
+    score_set = mock_worker_variant_insertion(
+        client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv"
+    )
+    with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue:
+        published_score_set = publish_score_set(client, score_set["urn"])
+        worker_queue.assert_called_once()
+
+    download_scores_csv_response = client.get(
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?namespaces=scores&drop_na_columns=true"
+    )
+    assert download_scores_csv_response.status_code == 200
+    download_scores_csv = download_scores_csv_response.text
+    reader = csv.reader(StringIO(download_scores_csv))
+    columns = next(reader)
+    assert "hgvs_nt" in columns
+    assert "hgvs_pro" in columns
+    assert "hgvs_splice" not in columns
+    assert "scores.score" in columns
+
+
+def test_download_counts_file_in_variant_data_path(session, data_provider, client, setup_router_db, data_files):
+    experiment = create_experiment(client)
+    score_set = create_seq_score_set(client, experiment["urn"])
+    score_set = mock_worker_variant_insertion(
+        client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv"
+    )
+    with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue:
+        published_score_set = publish_score_set(client, score_set["urn"])
+        worker_queue.assert_called_once()
+
+    download_counts_csv_response = client.get(
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?namespaces=counts&include_custom_columns=true&drop_na_columns=true"
+    )
+    assert download_counts_csv_response.status_code == 200
+    download_counts_csv = download_counts_csv_response.text
+    reader = csv.reader(StringIO(download_counts_csv))
+    columns = next(reader)
+    assert "hgvs_nt" in columns
+    assert "hgvs_pro" in columns
+    assert "hgvs_splice" not in columns
+    assert "counts.c_0" in columns
+    assert "counts.c_1" in columns
+
+
+def test_download_scores_and_counts_file(session, data_provider, client, setup_router_db, data_files):
+    experiment = create_experiment(client)
+    score_set = create_seq_score_set(client, experiment["urn"])
+    score_set = mock_worker_variant_insertion(
+        client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv"
+    )
+    with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue:
+        published_score_set = publish_score_set(client, score_set["urn"])
+        worker_queue.assert_called_once()
+
+    download_scores_and_counts_csv_response = client.get(
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?namespaces=counts&namespaces=scores&include_custom_columns=true&drop_na_columns=true"
+    )
+    assert download_scores_and_counts_csv_response.status_code == 200
+    download_scores_and_counts_csv = download_scores_and_counts_csv_response.text
+    reader = csv.DictReader(StringIO(download_scores_and_counts_csv))
+    assert sorted(reader.fieldnames) == sorted(
+        [
+            "accession",
+            "hgvs_nt",
+            "hgvs_pro",
+            "scores.score",
+            "scores.s_0",
+            "scores.s_1",
+            "counts.c_0",
+            "counts.c_1",
+        ]
+    )
+
+
+@pytest.mark.parametrize(
+    "mapped_variant,has_hgvs_g,has_hgvs_p",
+    [
+        (None, False, False),
+        (TEST_MAPPED_VARIANT_WITH_HGVS_G_EXPRESSION, True, False),
+        (TEST_MAPPED_VARIANT_WITH_HGVS_P_EXPRESSION, False, True),
+    ],
+    ids=["without_post_mapped_vrs", "with_post_mapped_hgvs_g", "with_post_mapped_hgvs_p"],
+)
+def test_download_scores_counts_and_post_mapped_variants_file(
+    session, data_provider, client, setup_router_db, data_files, mapped_variant, has_hgvs_g, has_hgvs_p
+):
+    experiment = create_experiment(client)
+    score_set = create_seq_score_set(client, experiment["urn"])
+    score_set = mock_worker_variant_insertion(
+        client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv"
+    )
+    if mapped_variant is not None:
+        create_mapped_variants_for_score_set(session, score_set["urn"], mapped_variant)
+
+    with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue:
+        published_score_set = publish_score_set(client, score_set["urn"])
+        worker_queue.assert_called_once()
+
+    download_multiple_data_csv_response = client.get(
+        f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?namespaces=scores&namespaces=counts&include_custom_columns=true&include_post_mapped_hgvs=true&drop_na_columns=true"
+    )
+    assert download_multiple_data_csv_response.status_code == 200
+    download_multiple_data_csv = download_multiple_data_csv_response.text
+    reader = csv.DictReader(StringIO(download_multiple_data_csv))
+    assert sorted(reader.fieldnames) == sorted(
+        [
+            "accession",
+            "hgvs_nt",
+            "hgvs_pro",
+            "mavedb.post_mapped_hgvs_g",
+            "mavedb.post_mapped_hgvs_p",
+            "scores.score",
+            "scores.s_0",
+            "scores.s_1",
+            "counts.c_0",
+            "counts.c_1",
+        ]
+    )
+
+
 ########################################################################################################################
 # Fetching clinical controls and control options for a score set
 ########################################################################################################################