
Commit 96149bd

Modified and debugged the code and some related tests.
1 parent c6aab8f commit 96149bd

File tree

3 files changed: +90 -118 lines changed

src/mavedb/lib/score_sets.py

Lines changed: 52 additions & 88 deletions
@@ -401,7 +401,7 @@ def find_publish_or_private_superseded_score_set_tail(
 def get_score_set_variants_as_csv(
     db: Session,
     score_set: ScoreSet,
-    data_types: List[Literal["scores", "counts", "clinVar", "gnomAD"]],
+    namespaces: List[Literal["scores", "counts"]],
     start: Optional[int] = None,
     limit: Optional[int] = None,
     drop_na_columns: Optional[bool] = None,
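For orientation, a minimal sketch of a call against the updated signature; `db` and `score_set` stand in for an open Session and a loaded ScoreSet and are assumptions of this sketch, not part of the diff:

csv_text = get_score_set_variants_as_csv(
    db,
    score_set,
    namespaces=["scores", "counts"],  # replaces the old data_types argument
    start=0,
    limit=100,
    drop_na_columns=True,
)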
@@ -417,8 +417,8 @@ def get_score_set_variants_as_csv(
         The database session to use.
     score_set : ScoreSet
         The score set to get the variants from.
-    data_types : List[Literal["scores", "counts", "clinVar", "gnomAD"]]
-        The data types to get. Either one of 'scores', 'counts', 'clinVar', 'gnomAD' or some of them.
+    namespaces : List[Literal["scores", "counts"]]
+        The column namespaces to include. Currently only 'scores' and 'counts' are supported; ClinVar and gnomAD may be added later.
     start : int, optional
         The index to start from. If None, starts from the beginning.
     limit : int, optional
@@ -437,35 +437,26 @@ def get_score_set_variants_as_csv(
         The CSV string containing the variant data.
     """
     assert type(score_set.dataset_columns) is dict
-    custom_columns = {
-        "scores": "score_columns",
-        "counts": "count_columns",
+    namespaced_score_set_columns: dict[str, list[str]] = {
+        "core": ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"],
+        "mavedb": [],
     }
-    custom_columns_set = [custom_columns[dt] for dt in data_types if dt in custom_columns]
-    type_to_column = {
-        "scores": "score_data",
-        "counts": "count_data"
-    }
-    type_columns = [type_to_column[dt] for dt in data_types if dt in type_to_column]
-    columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"]
     if include_post_mapped_hgvs:
-        columns.append("post_mapped_hgvs_g")
-        columns.append("post_mapped_hgvs_p")
-
+        namespaced_score_set_columns["mavedb"].append("post_mapped_hgvs_g")
+        namespaced_score_set_columns["mavedb"].append("post_mapped_hgvs_p")
+    for namespace in namespaces:
+        namespaced_score_set_columns[namespace] = []
     if include_custom_columns:
-        for column in custom_columns_set:
-            dataset_columns = [str(x) for x in list(score_set.dataset_columns.get(column, []))]
-            if column == "score_columns":
-                for c in dataset_columns:
-                    prefixed = "scores." + c
-                    columns.append(prefixed)
-            elif column == "count_columns":
-                for c in dataset_columns:
-                    prefixed = "counts." + c
-                    columns.append(prefixed)
-    elif len(data_types) == 1 and data_types[0] == "scores":
-        columns.append(REQUIRED_SCORE_COLUMN)
-
+        if "scores" in namespaced_score_set_columns:
+            namespaced_score_set_columns["scores"] = [
+                col for col in [str(x) for x in list(score_set.dataset_columns.get("score_columns", []))]
+            ]
+        if "counts" in namespaced_score_set_columns:
+            namespaced_score_set_columns["counts"] = [
+                col for col in [str(x) for x in list(score_set.dataset_columns.get("count_columns", []))]
+            ]
+    elif "scores" in namespaced_score_set_columns:
+        namespaced_score_set_columns["scores"].append(REQUIRED_SCORE_COLUMN)
     variants: Sequence[Variant] = []
     mappings: Optional[list[Optional[MappedVariant]]] = None
 
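To make the new grouping concrete, here is roughly what `namespaced_score_set_columns` could hold for a hypothetical score set with one score column and two count columns (the custom column names are invented for illustration):

namespaced_score_set_columns = {
    "core": ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"],
    "mavedb": ["post_mapped_hgvs_g", "post_mapped_hgvs_p"],  # only when include_post_mapped_hgvs is set
    "scores": ["score"],        # from score_set.dataset_columns["score_columns"]
    "counts": ["c_0", "c_1"],   # from score_set.dataset_columns["count_columns"]
}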
@@ -503,40 +494,18 @@ def get_score_set_variants_as_csv(
         variants_query = variants_query.limit(limit)
     variants = db.scalars(variants_query).all()
 
-    rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_columns, mappings=mappings)  # type: ignore
-
-    # TODO: will add len(data_types) == 1 and "scores"/"counts" are not in [data_types] and include_post_mapped_hgvs
-    # case when we get the clinVar and gnomAD
-    if len(data_types) > 1 and include_post_mapped_hgvs:
-        rename_map = {}
-        rename_map["post_mapped_hgvs_g"] = "mavedb.post_mapped_hgvs_g"
-        rename_map["post_mapped_hgvs_p"] = "mavedb.post_mapped_hgvs_p"
-
-        # Update column order list (preserve original order)
-        columns = [rename_map.get(col, col) for col in columns]
-
-        # Rename keys in each row
-        renamed_rows_data = []
-        for row in rows_data:
-            renamed_row = {rename_map.get(k, k): v for k, v in row.items()}
-            renamed_rows_data.append(renamed_row)
-
-        rows_data = renamed_rows_data
-    elif len(data_types) == 1:
-        prefix = f"{data_types[0]}."
-        columns = [col[len(prefix):] if col.startswith(prefix) else col for col in columns]
-
-        # Rename rows to remove the same prefix from keys
-        renamed_rows_data = []
-        for row in rows_data:
-            renamed_row = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in row.items()}
-            renamed_rows_data.append(renamed_row)
-        rows_data = renamed_rows_data
+    rows_data = variants_to_csv_rows(variants, columns=namespaced_score_set_columns, mappings=mappings)  # type: ignore
+    rows_columns = [
+        f"{namespace}.{col}" if namespace != "core" else col
+        for namespace, cols in namespaced_score_set_columns.items()
+        for col in cols
+    ]
+
     if drop_na_columns:
-        rows_data, columns = drop_na_columns_from_csv_file_rows(rows_data, columns)
+        rows_data, rows_columns = drop_na_columns_from_csv_file_rows(rows_data, rows_columns)
 
     stream = io.StringIO()
-    writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
+    writer = csv.DictWriter(stream, fieldnames=rows_columns, quoting=csv.QUOTE_MINIMAL)
     writer.writeheader()
     writer.writerows(rows_data)
     return stream.getvalue()
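Continuing the illustrative mapping above, the flattened `rows_columns` list that feeds csv.DictWriter would come out roughly as follows, with every namespace except `core` prefixed onto its column names:

rows_columns = [
    "accession", "hgvs_nt", "hgvs_splice", "hgvs_pro",          # core: no prefix
    "mavedb.post_mapped_hgvs_g", "mavedb.post_mapped_hgvs_p",
    "scores.score",
    "counts.c_0", "counts.c_1",
]

These prefixed names become the CSV header row written by `writer.writeheader()`.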
@@ -574,8 +543,7 @@ def is_null(value):
 
 def variant_to_csv_row(
     variant: Variant,
-    columns: list[str],
-    dtype: list[str],
+    columns: dict[str, list[str]],
     mapping: Optional[MappedVariant] = None,
     na_rep="NA",
 ) -> dict[str, Any]:
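A sketch of calling the reworked helper directly; the variant object and the returned values are placeholders assumed for illustration:

row = variant_to_csv_row(
    variant,
    columns={"core": ["accession", "hgvs_nt"], "scores": ["score"]},
    mapping=None,
)
# row might look like:
# {"accession": "urn:mavedb:...#1", "hgvs_nt": "c.1A>G", "scores.score": "0.75"}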
@@ -588,17 +556,16 @@
         List of variants.
     columns : list[str]
         Columns to serialize.
-    dtype : str, {'scores', 'counts'}
-        The type of data requested. ['score_data'], ['count_data'] or ['score_data', 'count_data'].
     na_rep : str
         String to represent null values.
 
     Returns
     -------
     dict[str, Any]
     """
-    row = {}
-    for column_key in columns:
+    row: dict[str, Any] = {}
+    # Handle each column key explicitly as part of its namespace.
+    for column_key in columns.get("core", []):
         if column_key == "hgvs_nt":
             value = str(variant.hgvs_nt)
         elif column_key == "hgvs_pro":
@@ -607,7 +574,13 @@
             value = str(variant.hgvs_splice)
         elif column_key == "accession":
             value = str(variant.urn)
-        elif column_key == "post_mapped_hgvs_g":
+        if is_null(value):
+            value = na_rep
+
+        # export columns in the `core` namespace without a namespace
+        row[column_key] = value
+    for column_key in columns.get("mavedb", []):
+        if column_key == "post_mapped_hgvs_g":
             hgvs_str = get_hgvs_from_post_mapped(mapping.post_mapped) if mapping and mapping.post_mapped else None
             if hgvs_str is not None and is_hgvs_g(hgvs_str):
                 value = hgvs_str
@@ -619,30 +592,23 @@
                 value = hgvs_str
             else:
                 value = ""
-        else:
-            for dt in dtype:
-                parent = variant.data.get(dt) if variant.data else None
-                if column_key.startswith("scores."):
-                    inner_key = column_key.replace("scores.", "")
-                elif column_key.startswith("counts."):
-                    inner_key = column_key.replace("counts.", "")
-                else:
-                    # fallback for non-prefixed columns
-                    inner_key = column_key
-                if parent and inner_key in parent:
-                    value = str(parent[inner_key])
-                    break
         if is_null(value):
             value = na_rep
-        row[column_key] = value
-
+        row[f"mavedb.{column_key}"] = value
+    for column_key in columns.get("scores", []):
+        parent = variant.data.get("score_data") if variant.data else None
+        value = str(parent.get(column_key)) if parent else na_rep
+        row[f"scores.{column_key}"] = value
+    for column_key in columns.get("counts", []):
+        parent = variant.data.get("count_data") if variant.data else None
+        value = str(parent.get(column_key)) if parent else na_rep
+        row[f"counts.{column_key}"] = value
     return row
 
 
 def variants_to_csv_rows(
     variants: Sequence[Variant],
-    columns: list[str],
-    dtype: List[str],
+    columns: dict[str, list[str]],
     mappings: Optional[Sequence[Optional[MappedVariant]]] = None,
     na_rep="NA",
 ) -> Iterable[dict[str, Any]]:
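At the sequence level the wrapper simply forwards the namespace map per variant; a brief usage sketch (all objects assumed to exist):

rows = variants_to_csv_rows(
    variants,
    columns=namespaced_score_set_columns,
    mappings=mappings,  # may be None when mapped HGVS columns are not requested
)
for row in rows:
    # keys: "accession", "hgvs_nt", ..., plus "mavedb.*", "scores.*", "counts.*" as requested
    print(row)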
@@ -655,8 +621,6 @@ def variants_to_csv_rows(
         List of variants.
     columns : list[str]
         Columns to serialize.
-    dtype : list, {'scores', 'counts'}
-        The type of data requested. ['score_data'], ['count_data'] or ['score_data', 'count_data'].
     na_rep : str
         String to represent null values.
 
@@ -666,10 +630,10 @@
     """
     if mappings is not None:
         return map(
-            lambda pair: variant_to_csv_row(pair[0], columns, dtype, mapping=pair[1], na_rep=na_rep),
+            lambda pair: variant_to_csv_row(pair[0], columns, mapping=pair[1], na_rep=na_rep),
             zip(variants, mappings),
         )
-    return map(lambda v: variant_to_csv_row(v, columns, dtype, na_rep=na_rep), variants)
+    return map(lambda v: variant_to_csv_row(v, columns, na_rep=na_rep), variants)
 
 
 def find_meta_analyses_for_score_sets(db: Session, urns: list[str]) -> list[ScoreSet]:

src/mavedb/routers/score_sets.py

Lines changed: 24 additions & 15 deletions
@@ -249,7 +249,7 @@ def get_score_set_variants_csv(
     urn: str,
     start: int = Query(default=None, description="Start index for pagination"),
    limit: int = Query(default=None, description="Maximum number of variants to return"),
-    data_types: List[Literal["scores", "counts", "clinVar", "gnomAD"]] = Query(
+    namespaces: List[Literal["scores", "counts"]] = Query(
        default=["scores"],
        description="One or more data types to include: scores, counts, clinVar, gnomAD"
    ),
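Because `namespaces` is declared as a List query parameter, clients repeat it in the query string. An illustrative request; the route path is an assumption based on the handler name, not something this hunk shows:

import requests  # any HTTP client works; requests is used here for illustration

response = requests.get(
    "https://api.example.org/score-sets/urn:mavedb:00000001-a-1/variants/data",  # hypothetical path
    params=[("namespaces", "scores"), ("namespaces", "counts"), ("drop_na_columns", "true")],
)
csv_text = response.text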
@@ -265,7 +265,7 @@ def get_score_set_variants_csv(
     This differs from get_score_set_scores_csv() in that it returns only the HGVS columns, score column, and mapped HGVS
     string.
 
-    TODO (https://github.com/VariantEffect/mavedb-api/issues/446) We may want to turn this into a general-purpose CSV
+    TODO (https://github.com/VariantEffect/mavedb-api/issues/446) We may add ClinVar and gnomAD namespaces to this CSV
     export endpoint, with options governing which columns to include.
 
     Parameters
@@ -276,6 +276,9 @@ def get_score_set_variants_csv(
         The index to start from. If None, starts from the beginning.
     limit : Optional[int]
         The maximum number of variants to return. If None, returns all variants.
+    namespaces : List[Literal["scores", "counts"]]
+        The namespaces for all columns other than accession, hgvs_nt, hgvs_pro, and hgvs_splice.
+        ClinVar and gnomAD namespaces may be added in the future.
     drop_na_columns : bool, optional
         Whether to drop columns that contain only NA values. Defaults to False.
     db : Session
@@ -315,7 +318,7 @@ def get_score_set_variants_csv(
     csv_str = get_score_set_variants_as_csv(
         db,
         score_set,
-        data_types,
+        namespaces,
         start,
         limit,
         drop_na_columns,
@@ -377,6 +380,10 @@ def get_score_set_scores_csv(
     assert_permission(user_data, score_set, Action.READ)
 
     csv_str = get_score_set_variants_as_csv(db, score_set, ["scores"], start, limit, drop_na_columns)
+    lines = csv_str.splitlines()
+    if lines:
+        header = lines[0].replace("scores.", "")
+        csv_str = "\n".join([header] + lines[1:])
     return StreamingResponse(iter([csv_str]), media_type="text/csv")
 
 
@@ -432,6 +439,10 @@ async def get_score_set_counts_csv(
     assert_permission(user_data, score_set, Action.READ)
 
     csv_str = get_score_set_variants_as_csv(db, score_set, ["counts"], start, limit, drop_na_columns)
+    lines = csv_str.splitlines()
+    if lines:
+        header = lines[0].replace("counts.", "")
+        csv_str = "\n".join([header] + lines[1:])
     return StreamingResponse(iter([csv_str]), media_type="text/csv")
 
 
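Both legacy endpoints strip the namespace prefix from the header line only, so their CSV output keeps the historical, unprefixed column names. A condensed sketch of that post-processing step (the helper name is invented; the endpoints inline this logic as shown above):

def strip_namespace_prefix(csv_str: str, prefix: str) -> str:
    # Rewrite only the header row; data rows are left untouched.
    lines = csv_str.splitlines()
    if lines:
        lines[0] = lines[0].replace(prefix, "")
    return "\n".join(lines)

# e.g. strip_namespace_prefix(csv_str, "scores.") in the scores endpoint,
#      strip_namespace_prefix(csv_str, "counts.") in the counts endpoint.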
@@ -1243,24 +1254,22 @@ async def update_score_set(
             # re-validate existing variants and clear them if they do not pass validation
             if item.variants:
                 assert item.dataset_columns is not None
-                score_columns = [
-                    "hgvs_nt",
-                    "hgvs_splice",
-                    "hgvs_pro",
-                ] + item.dataset_columns["score_columns"]
-                count_columns = [
-                    "hgvs_nt",
-                    "hgvs_splice",
-                    "hgvs_pro",
-                ] + item.dataset_columns["count_columns"]
+                score_columns = {
+                    "core": ["hgvs_nt", "hgvs_splice", "hgvs_pro"],
+                    "mavedb": item.dataset_columns["score_columns"],
+                }
+                count_columns = {
+                    "core": ["hgvs_nt", "hgvs_splice", "hgvs_pro"],
+                    "mavedb": item.dataset_columns["count_columns"],
+                }
 
                 scores_data = pd.DataFrame(
-                    variants_to_csv_rows(item.variants, columns=score_columns, dtype=["score_data"])
+                    variants_to_csv_rows(item.variants, columns=score_columns)
                 ).replace("NA", pd.NA)
 
                 if item.dataset_columns["count_columns"]:
                     count_data = pd.DataFrame(
-                        variants_to_csv_rows(item.variants, columns=count_columns, dtype=["count_data"])
+                        variants_to_csv_rows(item.variants, columns=count_columns)
                     ).replace("NA", pd.NA)
                 else:
                     count_data = None
