Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 71 additions & 39 deletions src/mavedb/lib/score_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from operator import attrgetter
import re
from typing import Any, BinaryIO, Iterable, Optional, TYPE_CHECKING, Sequence, Literal
from typing import Any, BinaryIO, Iterable, List, Optional, TYPE_CHECKING, Sequence, Literal

from mavedb.models.mapped_variant import MappedVariant
import numpy as np
Expand Down Expand Up @@ -501,12 +501,13 @@ def find_publish_or_private_superseded_score_set_tail(
def get_score_set_variants_as_csv(
db: Session,
score_set: ScoreSet,
data_type: Literal["scores", "counts"],
namespaces: List[Literal["scores", "counts"]],
namespaced: Optional[bool] = None,
start: Optional[int] = None,
limit: Optional[int] = None,
drop_na_columns: Optional[bool] = None,
include_custom_columns: bool = True,
include_post_mapped_hgvs: bool = False,
include_custom_columns: Optional[bool] = True,
include_post_mapped_hgvs: Optional[bool] = False,
) -> str:
"""
Get the variant data from a score set as a CSV string.
Expand All @@ -517,8 +518,10 @@ def get_score_set_variants_as_csv(
The database session to use.
score_set : ScoreSet
The score set to get the variants from.
data_type : {'scores', 'counts'}
The type of data to get. Either 'scores' or 'counts'.
namespaces : List[Literal["scores", "counts"]]
The namespaces of data to include. Currently only "scores" and "counts" are supported; ClinVar and gnomAD may be added in the future.
namespaced : Optional[bool] = None
Whether to namespace the column names or not.
start : int, optional
The index to start from. If None, starts from the beginning.
limit : int, optional
Expand All @@ -537,20 +540,26 @@ def get_score_set_variants_as_csv(
The CSV string containing the variant data.
"""
assert type(score_set.dataset_columns) is dict
custom_columns_set = "score_columns" if data_type == "scores" else "count_columns"
type_column = "score_data" if data_type == "scores" else "count_data"

columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"]
namespaced_score_set_columns: dict[str, list[str]] = {
"core": ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"],
"mavedb": [],
}
if include_post_mapped_hgvs:
columns.append("post_mapped_hgvs_g")
columns.append("post_mapped_hgvs_p")

namespaced_score_set_columns["mavedb"].append("post_mapped_hgvs_g")
namespaced_score_set_columns["mavedb"].append("post_mapped_hgvs_p")
for namespace in namespaces:
namespaced_score_set_columns[namespace] = []
if include_custom_columns:
custom_columns = [str(x) for x in list(score_set.dataset_columns.get(custom_columns_set, []))]
columns += custom_columns
elif data_type == "scores":
columns.append(REQUIRED_SCORE_COLUMN)

if "scores" in namespaced_score_set_columns:
namespaced_score_set_columns["scores"] = [
col for col in [str(x) for x in list(score_set.dataset_columns.get("score_columns", []))]
]
if "counts" in namespaced_score_set_columns:
namespaced_score_set_columns["counts"] = [
col for col in [str(x) for x in list(score_set.dataset_columns.get("count_columns", []))]
]
elif "scores" in namespaced_score_set_columns:
namespaced_score_set_columns["scores"].append(REQUIRED_SCORE_COLUMN)
variants: Sequence[Variant] = []
mappings: Optional[list[Optional[MappedVariant]]] = None

Expand Down Expand Up @@ -587,13 +596,22 @@ def get_score_set_variants_as_csv(
if limit:
variants_query = variants_query.limit(limit)
variants = db.scalars(variants_query).all()
rows_data = variants_to_csv_rows(variants, columns=namespaced_score_set_columns, namespaced=namespaced, mappings=mappings) # type: ignore
rows_columns = [
(
f"{namespace}.{col}"
if (namespaced and namespace not in ["core", "mavedb"])
else (f"mavedb.{col}" if namespaced and namespace == "mavedb" else col)
)
for namespace, cols in namespaced_score_set_columns.items()
for col in cols
]

rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column, mappings=mappings) # type: ignore
if drop_na_columns:
rows_data, columns = drop_na_columns_from_csv_file_rows(rows_data, columns)
rows_data, rows_columns = drop_na_columns_from_csv_file_rows(rows_data, rows_columns)

stream = io.StringIO()
writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
writer = csv.DictWriter(stream, fieldnames=rows_columns, quoting=csv.QUOTE_MINIMAL)
writer.writeheader()
writer.writerows(rows_data)
return stream.getvalue()
Expand Down Expand Up @@ -631,9 +649,9 @@ def is_null(value):

def variant_to_csv_row(
variant: Variant,
columns: list[str],
dtype: str,
columns: dict[str, list[str]],
mapping: Optional[MappedVariant] = None,
namespaced: Optional[bool] = None,
na_rep="NA",
) -> dict[str, Any]:
"""
Expand All @@ -645,17 +663,18 @@ def variant_to_csv_row(
The variant to serialize.
columns : dict[str, list[str]]
Mapping from namespace to the list of columns to serialize.
dtype : str, {'scores', 'counts'}
The type of data requested. Either the 'score_data' or 'count_data'.
namespaced : Optional[bool] = None
Whether to namespace the column names or not.
na_rep : str
String to represent null values.

Returns
-------
dict[str, Any]
"""
row = {}
for column_key in columns:
row: dict[str, Any] = {}
# Handle each column key explicitly as part of its namespace.
for column_key in columns.get("core", []):
if column_key == "hgvs_nt":
value = str(variant.hgvs_nt)
elif column_key == "hgvs_pro":
Expand All @@ -664,7 +683,13 @@ def variant_to_csv_row(
value = str(variant.hgvs_splice)
elif column_key == "accession":
value = str(variant.urn)
elif column_key == "post_mapped_hgvs_g":
if is_null(value):
value = na_rep

# export columns in the `core` namespace without a namespace
row[column_key] = value
for column_key in columns.get("mavedb", []):
if column_key == "post_mapped_hgvs_g":
hgvs_str = get_hgvs_from_post_mapped(mapping.post_mapped) if mapping and mapping.post_mapped else None
if hgvs_str is not None and is_hgvs_g(hgvs_str):
value = hgvs_str
Expand All @@ -676,21 +701,28 @@ def variant_to_csv_row(
value = hgvs_str
else:
value = ""
else:
parent = variant.data.get(dtype) if variant.data else None
value = str(parent.get(column_key)) if parent else na_rep
if is_null(value):
value = na_rep
row[column_key] = value

key = f"mavedb.{column_key}" if namespaced else column_key
row[key] = value
for column_key in columns.get("scores", []):
parent = variant.data.get("score_data") if variant.data else None
value = str(parent.get(column_key)) if parent else na_rep
key = f"scores.{column_key}" if namespaced else column_key
row[key] = value
for column_key in columns.get("counts", []):
parent = variant.data.get("count_data") if variant.data else None
value = str(parent.get(column_key)) if parent else na_rep
key = f"counts.{column_key}" if namespaced else column_key
row[key] = value
return row


def variants_to_csv_rows(
variants: Sequence[Variant],
columns: list[str],
dtype: str,
columns: dict[str, list[str]],
mappings: Optional[Sequence[Optional[MappedVariant]]] = None,
namespaced: Optional[bool] = None,
na_rep="NA",
) -> Iterable[dict[str, Any]]:
"""
Expand All @@ -702,8 +734,8 @@ def variants_to_csv_rows(
List of variants.
columns : dict[str, list[str]]
Mapping from namespace to the list of columns to serialize.
dtype : str, {'scores', 'counts'}
The type of data requested. Either the 'score_data' or 'count_data'.
namespaced : Optional[bool] = None
Whether to namespace the column names or not.
na_rep : str
String to represent null values.

Expand All @@ -713,10 +745,10 @@ def variants_to_csv_rows(
"""
if mappings is not None:
return map(
lambda pair: variant_to_csv_row(pair[0], columns, dtype, mapping=pair[1], na_rep=na_rep),
lambda pair: variant_to_csv_row(pair[0], columns, mapping=pair[1], namespaced=namespaced, na_rep=na_rep),
zip(variants, mappings),
)
return map(lambda v: variant_to_csv_row(v, columns, dtype, na_rep=na_rep), variants)
return map(lambda v: variant_to_csv_row(v, columns, namespaced=namespaced, na_rep=na_rep), variants)


def find_meta_analyses_for_score_sets(db: Session, urns: list[str]) -> list[ScoreSet]:
Expand Down
47 changes: 31 additions & 16 deletions src/mavedb/routers/score_sets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
from datetime import date
from typing import Any, List, Optional, Sequence, TypedDict, Union
from typing import Any, List, Literal, Optional, Sequence, TypedDict, Union

import pandas as pd
from arq import ArqRedis
Expand Down Expand Up @@ -110,20 +110,28 @@ async def enqueue_variant_creation(
"hgvs_splice",
"hgvs_pro",
] + item.dataset_columns.get("score_columns", [])
# score_columns = {
# "core": ["hgvs_nt", "hgvs_splice", "hgvs_pro"],
# "counts": item.dataset_columns["score_columns"],
# }
existing_scores_df = pd.DataFrame(
variants_to_csv_rows(item.variants, columns=score_columns, dtype="score_data")
variants_to_csv_rows(item.variants, columns=score_columns, namespaced=False)
).replace("NA", pd.NA)

# create CSV from existing variants on the score set if no new dataframe provided
existing_counts_df = None
if new_counts_df is None and item.dataset_columns.get("count_columns"):
# count_columns = {
# "core": ["hgvs_nt", "hgvs_splice", "hgvs_pro"],
# "counts": item.dataset_columns["count_columns"],
# }
count_columns = [
"hgvs_nt",
"hgvs_splice",
"hgvs_pro",
] + item.dataset_columns["count_columns"]
"hgvs_nt",
"hgvs_splice",
"hgvs_pro",
] + item.dataset_columns["count_columns"]
existing_counts_df = pd.DataFrame(
variants_to_csv_rows(item.variants, columns=count_columns, dtype="count_data")
variants_to_csv_rows(item.variants, columns=count_columns, namespaced=False)
).replace("NA", pd.NA)

# Await the insertion of this job into the worker queue, not the job itself.
Expand Down Expand Up @@ -638,7 +646,13 @@ def get_score_set_variants_csv(
urn: str,
start: int = Query(default=None, description="Start index for pagination"),
limit: int = Query(default=None, description="Maximum number of variants to return"),
namespaces: List[Literal["scores", "counts"]] = Query(
default=["scores"],
description="One or more data types to include: scores, counts, clinVar, gnomAD"
),
drop_na_columns: Optional[bool] = None,
include_custom_columns: Optional[bool] = None,
include_post_mapped_hgvs: Optional[bool] = None,
db: Session = Depends(deps.get_db),
user_data: Optional[UserData] = Depends(get_current_user),
) -> Any:
Expand All @@ -648,12 +662,9 @@ def get_score_set_variants_csv(
This differs from get_score_set_scores_csv() in that it returns only the HGVS columns, score column, and mapped HGVS
string.

TODO (https://github.com/VariantEffect/mavedb-api/issues/446) We may want to turn this into a general-purpose CSV
TODO (https://github.com/VariantEffect/mavedb-api/issues/446) We may add another general-purpose CSV
export function for ClinVar and gnomAD data, with options governing which columns to include.

Parameters
__________
urn : str
Expand All @@ -662,6 +673,9 @@ def get_score_set_variants_csv(
The index to start from. If None, starts from the beginning.
limit : Optional[int]
The maximum number of variants to return. If None, returns all variants.
namespaces: List[Literal["scores", "counts"]]
The namespaces of all columns except for accession, hgvs_nt, hgvs_pro, and hgvs_splice.
We may add ClinVar and gnomAD in the future.
drop_na_columns : bool, optional
Whether to drop columns that contain only NA values. Defaults to False.
db : Session
Expand Down Expand Up @@ -701,12 +715,13 @@ def get_score_set_variants_csv(
csv_str = get_score_set_variants_as_csv(
db,
score_set,
"scores",
namespaces,
True,
start,
limit,
drop_na_columns,
include_custom_columns=False,
include_post_mapped_hgvs=True,
include_custom_columns,
include_post_mapped_hgvs,
)
return StreamingResponse(iter([csv_str]), media_type="text/csv")

Expand Down Expand Up @@ -762,7 +777,7 @@ def get_score_set_scores_csv(

assert_permission(user_data, score_set, Action.READ)

csv_str = get_score_set_variants_as_csv(db, score_set, "scores", start, limit, drop_na_columns)
csv_str = get_score_set_variants_as_csv(db, score_set, ["scores"], False, start, limit, drop_na_columns)
return StreamingResponse(iter([csv_str]), media_type="text/csv")


Expand Down Expand Up @@ -817,7 +832,7 @@ async def get_score_set_counts_csv(

assert_permission(user_data, score_set, Action.READ)

csv_str = get_score_set_variants_as_csv(db, score_set, "counts", start, limit, drop_na_columns)
csv_str = get_score_set_variants_as_csv(db, score_set, ["counts"], False, start, limit, drop_na_columns)
return StreamingResponse(iter([csv_str]), media_type="text/csv")


Expand Down
4 changes: 2 additions & 2 deletions src/mavedb/scripts/export_public_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ def export_public_data(db: Session):
logger.info(f"{i + 1}/{num_score_sets} Exporting variants for score set {score_set.urn}")
csv_filename_base = score_set.urn.replace(":", "-")

csv_str = get_score_set_variants_as_csv(db, score_set, "scores")
csv_str = get_score_set_variants_as_csv(db, score_set, ["scores"])
zipfile.writestr(f"csv/{csv_filename_base}.scores.csv", csv_str)

count_columns = score_set.dataset_columns["count_columns"] if score_set.dataset_columns else None
if count_columns and len(count_columns) > 0:
csv_str = get_score_set_variants_as_csv(db, score_set, "counts")
csv_str = get_score_set_variants_as_csv(db, score_set, ["counts"])
zipfile.writestr(f"csv/{csv_filename_base}.counts.csv", csv_str)


Expand Down
Loading