Skip to content

Commit b24f3c0

Browse files
committed
Modify the get_score_set_variants_csv function to allow downloading multiple data types together.
1 parent d088e1a commit b24f3c0

File tree

3 files changed

+89
-33
lines changed

3 files changed

+89
-33
lines changed

src/mavedb/lib/score_sets.py

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import re
55
from operator import attrgetter
6-
from typing import Any, BinaryIO, Iterable, Optional, TYPE_CHECKING, Sequence, Literal
6+
from typing import Any, BinaryIO, Iterable, List, Optional, TYPE_CHECKING, Sequence, Literal
77

88
from mavedb.models.mapped_variant import MappedVariant
99
import numpy as np
@@ -401,12 +401,12 @@ def find_publish_or_private_superseded_score_set_tail(
401401
def get_score_set_variants_as_csv(
402402
db: Session,
403403
score_set: ScoreSet,
404-
data_type: Literal["scores", "counts"],
404+
data_types: List[Literal["scores", "counts", "clinVar", "gnomAD"]],
405405
start: Optional[int] = None,
406406
limit: Optional[int] = None,
407407
drop_na_columns: Optional[bool] = None,
408-
include_custom_columns: bool = True,
409-
include_post_mapped_hgvs: bool = False,
408+
include_custom_columns: Optional[bool] = True,
409+
include_post_mapped_hgvs: Optional[bool] = False,
410410
) -> str:
411411
"""
412412
Get the variant data from a score set as a CSV string.
@@ -417,8 +417,8 @@ def get_score_set_variants_as_csv(
417417
The database session to use.
418418
score_set : ScoreSet
419419
The score set to get the variants from.
420-
data_type : {'scores', 'counts'}
421-
The type of data to get. Either 'scores' or 'counts'.
420+
data_types : List[Literal["scores", "counts", "clinVar", "gnomAD"]]
421+
The data types to get. One or more of 'scores', 'counts', 'clinVar', 'gnomAD'.
422422
start : int, optional
423423
The index to start from. If None, starts from the beginning.
424424
limit : int, optional
@@ -437,18 +437,33 @@ def get_score_set_variants_as_csv(
437437
The CSV string containing the variant data.
438438
"""
439439
assert type(score_set.dataset_columns) is dict
440-
custom_columns_set = "score_columns" if data_type == "scores" else "count_columns"
441-
type_column = "score_data" if data_type == "scores" else "count_data"
442-
440+
custom_columns = {
441+
"scores": "score_columns",
442+
"counts": "count_columns",
443+
}
444+
custom_columns_set = [custom_columns[dt] for dt in data_types if dt in custom_columns]
445+
type_to_column = {
446+
"scores": "score_data",
447+
"counts": "count_data"
448+
}
449+
type_columns = [type_to_column[dt] for dt in data_types if dt in type_to_column]
443450
columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"]
444451
if include_post_mapped_hgvs:
445452
columns.append("post_mapped_hgvs_g")
446453
columns.append("post_mapped_hgvs_p")
447454

448455
if include_custom_columns:
449-
custom_columns = [str(x) for x in list(score_set.dataset_columns.get(custom_columns_set, []))]
450-
columns += custom_columns
451-
elif data_type == "scores":
456+
for column in custom_columns_set:
457+
dataset_columns = [str(x) for x in list(score_set.dataset_columns.get(column, []))]
458+
if column == "score_columns":
459+
for c in dataset_columns:
460+
prefixed = "scores." + c
461+
columns.append(prefixed)
462+
elif column == "count_columns":
463+
for c in dataset_columns:
464+
prefixed = "counts." + c
465+
columns.append(prefixed)
466+
elif len(data_types) == 1 and data_types[0] == "scores":
452467
columns.append(REQUIRED_SCORE_COLUMN)
453468

454469
variants: Sequence[Variant] = []
@@ -488,7 +503,35 @@ def get_score_set_variants_as_csv(
488503
variants_query = variants_query.limit(limit)
489504
variants = db.scalars(variants_query).all()
490505

491-
rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column, mappings=mappings) # type: ignore
506+
rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_columns, mappings=mappings) # type: ignore
507+
508+
# TODO: handle the case where len(data_types) == 1, neither "scores" nor "counts" is in data_types, and include_post_mapped_hgvs is set,
509+
# once the clinVar and gnomAD data types are supported.
510+
if len(data_types) > 1 and include_post_mapped_hgvs:
511+
rename_map = {}
512+
rename_map["post_mapped_hgvs_g"] = "mavedb.post_mapped_hgvs_g"
513+
rename_map["post_mapped_hgvs_p"] = "mavedb.post_mapped_hgvs_p"
514+
515+
# Update column order list (preserve original order)
516+
columns = [rename_map.get(col, col) for col in columns]
517+
518+
# Rename keys in each row
519+
renamed_rows_data = []
520+
for row in rows_data:
521+
renamed_row = {rename_map.get(k, k): v for k, v in row.items()}
522+
renamed_rows_data.append(renamed_row)
523+
524+
rows_data = renamed_rows_data
525+
elif len(data_types) == 1:
526+
prefix = f"{data_types[0]}."
527+
columns = [col[len(prefix):] if col.startswith(prefix) else col for col in columns]
528+
529+
# Rename rows to remove the same prefix from keys
530+
renamed_rows_data = []
531+
for row in rows_data:
532+
renamed_row = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in row.items()}
533+
renamed_rows_data.append(renamed_row)
534+
rows_data = renamed_rows_data
492535
if drop_na_columns:
493536
rows_data, columns = drop_na_columns_from_csv_file_rows(rows_data, columns)
494537

@@ -532,7 +575,7 @@ def is_null(value):
532575
def variant_to_csv_row(
533576
variant: Variant,
534577
columns: list[str],
535-
dtype: str,
578+
dtype: list[str],
536579
mapping: Optional[MappedVariant] = None,
537580
na_rep="NA",
538581
) -> dict[str, Any]:
@@ -546,7 +589,7 @@ def variant_to_csv_row(
546589
columns : list[str]
547590
Columns to serialize.
548591
dtype : list[str], subset of {'score_data', 'count_data'}
549-
The type of data requested. Either the 'score_data' or 'count_data'.
592+
The type of data requested. ['score_data'], ['count_data'] or ['score_data', 'count_data'].
550593
na_rep : str
551594
String to represent null values.
552595
@@ -577,8 +620,18 @@ def variant_to_csv_row(
577620
else:
578621
value = ""
579622
else:
580-
parent = variant.data.get(dtype) if variant.data else None
581-
value = str(parent.get(column_key)) if parent else na_rep
623+
for dt in dtype:
624+
parent = variant.data.get(dt) if variant.data else None
625+
if column_key.startswith("scores."):
626+
inner_key = column_key.replace("scores.", "")
627+
elif column_key.startswith("counts."):
628+
inner_key = column_key.replace("counts.", "")
629+
else:
630+
# fallback for non-prefixed columns
631+
inner_key = column_key
632+
if parent and inner_key in parent:
633+
value = str(parent[inner_key])
634+
break
582635
if is_null(value):
583636
value = na_rep
584637
row[column_key] = value
@@ -589,7 +642,7 @@ def variant_to_csv_row(
589642
def variants_to_csv_rows(
590643
variants: Sequence[Variant],
591644
columns: list[str],
592-
dtype: str,
645+
dtype: List[str],
593646
mappings: Optional[Sequence[Optional[MappedVariant]]] = None,
594647
na_rep="NA",
595648
) -> Iterable[dict[str, Any]]:
@@ -602,8 +655,8 @@ def variants_to_csv_rows(
602655
List of variants.
603656
columns : list[str]
604657
Columns to serialize.
605-
dtype : str, {'scores', 'counts'}
606-
The type of data requested. Either the 'score_data' or 'count_data'.
658+
dtype : list[str], subset of {'score_data', 'count_data'}
659+
The type of data requested. ['score_data'], ['count_data'] or ['score_data', 'count_data'].
607660
na_rep : str
608661
String to represent null values.
609662

src/mavedb/routers/score_sets.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from datetime import date
3-
from typing import Any, List, Optional, Sequence, Union
3+
from typing import Any, List, Literal, Optional, Sequence, Union
44

55
import pandas as pd
66
from arq import ArqRedis
@@ -249,7 +249,13 @@ def get_score_set_variants_csv(
249249
urn: str,
250250
start: int = Query(default=None, description="Start index for pagination"),
251251
limit: int = Query(default=None, description="Maximum number of variants to return"),
252+
data_types: List[Literal["scores", "counts", "clinVar", "gnomAD"]] = Query(
253+
default=["scores"],
254+
description="One or more data types to include: scores, counts, clinVar, gnomAD"
255+
),
252256
drop_na_columns: Optional[bool] = None,
257+
include_custom_columns: Optional[bool] = None,
258+
include_post_mapped_hgvs: Optional[bool] = None,
253259
db: Session = Depends(deps.get_db),
254260
user_data: Optional[UserData] = Depends(get_current_user),
255261
) -> Any:
@@ -262,9 +268,6 @@ def get_score_set_variants_csv(
262268
TODO (https://github.com/VariantEffect/mavedb-api/issues/446) We may want to turn this into a general-purpose CSV
263269
export endpoint, with options governing which columns to include.
264270
265-
Parameters
266-
__________
267-
268271
Parameters
269272
__________
270273
urn : str
@@ -312,12 +315,12 @@ def get_score_set_variants_csv(
312315
csv_str = get_score_set_variants_as_csv(
313316
db,
314317
score_set,
315-
"scores",
318+
data_types,
316319
start,
317320
limit,
318321
drop_na_columns,
319-
include_custom_columns=False,
320-
include_post_mapped_hgvs=True,
322+
include_custom_columns,
323+
include_post_mapped_hgvs,
321324
)
322325
return StreamingResponse(iter([csv_str]), media_type="text/csv")
323326

@@ -373,7 +376,7 @@ def get_score_set_scores_csv(
373376

374377
assert_permission(user_data, score_set, Action.READ)
375378

376-
csv_str = get_score_set_variants_as_csv(db, score_set, "scores", start, limit, drop_na_columns)
379+
csv_str = get_score_set_variants_as_csv(db, score_set, ["scores"], start, limit, drop_na_columns)
377380
return StreamingResponse(iter([csv_str]), media_type="text/csv")
378381

379382

@@ -428,7 +431,7 @@ async def get_score_set_counts_csv(
428431

429432
assert_permission(user_data, score_set, Action.READ)
430433

431-
csv_str = get_score_set_variants_as_csv(db, score_set, "counts", start, limit, drop_na_columns)
434+
csv_str = get_score_set_variants_as_csv(db, score_set, ["counts"], start, limit, drop_na_columns)
432435
return StreamingResponse(iter([csv_str]), media_type="text/csv")
433436

434437

@@ -1252,12 +1255,12 @@ async def update_score_set(
12521255
] + item.dataset_columns["count_columns"]
12531256

12541257
scores_data = pd.DataFrame(
1255-
variants_to_csv_rows(item.variants, columns=score_columns, dtype="score_data")
1258+
variants_to_csv_rows(item.variants, columns=score_columns, dtype=["score_data"])
12561259
).replace("NA", pd.NA)
12571260

12581261
if item.dataset_columns["count_columns"]:
12591262
count_data = pd.DataFrame(
1260-
variants_to_csv_rows(item.variants, columns=count_columns, dtype="count_data")
1263+
variants_to_csv_rows(item.variants, columns=count_columns, dtype=["count_data"])
12611264
).replace("NA", pd.NA)
12621265
else:
12631266
count_data = None

src/mavedb/scripts/export_public_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,12 @@ def export_public_data(db: Session):
147147
logger.info(f"{i + 1}/{num_score_sets} Exporting variants for score set {score_set.urn}")
148148
csv_filename_base = score_set.urn.replace(":", "-")
149149

150-
csv_str = get_score_set_variants_as_csv(db, score_set, "scores")
150+
csv_str = get_score_set_variants_as_csv(db, score_set, ["scores"])
151151
zipfile.writestr(f"csv/{csv_filename_base}.scores.csv", csv_str)
152152

153153
count_columns = score_set.dataset_columns["count_columns"] if score_set.dataset_columns else None
154154
if count_columns and len(count_columns) > 0:
155-
csv_str = get_score_set_variants_as_csv(db, score_set, "counts")
155+
csv_str = get_score_set_variants_as_csv(db, score_set, ["counts"])
156156
zipfile.writestr(f"csv/{csv_filename_base}.counts.csv", csv_str)
157157

158158

0 commit comments

Comments
 (0)