Skip to content

Commit 9733b91

Browse files
committed
Solve #230 All-NA hgvs columns in input problem. Remove all-NA columns from the downloaded file. Add some related tests. Tests failed due to a bug.
1 parent af5606f commit 9733b91

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

src/mavedb/lib/score_sets.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from mavedb.lib.mave.utils import is_csv_null
2323
from mavedb.lib.validation.constants.general import null_values_list
24+
from mavedb.lib.validation.utilities import is_null as validate_is_null
2425
from mavedb.models.contributor import Contributor
2526
from mavedb.models.controlled_keyword import ControlledKeyword
2627
from mavedb.models.doi_identifier import DoiIdentifier
@@ -311,6 +312,7 @@ def get_score_set_counts_as_csv(
311312
score_set: ScoreSet,
312313
start: Optional[int] = None,
313314
limit: Optional[int] = None,
315+
download: Optional[bool] = None,
314316
) -> str:
315317
assert type(score_set.dataset_columns) is dict
316318
count_columns = [str(x) for x in list(score_set.dataset_columns.get("count_columns", []))]
@@ -329,6 +331,9 @@ def get_score_set_counts_as_csv(
329331
variants = db.scalars(variants_query).all()
330332

331333
rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column)
334+
if download:
335+
rows_data, columns = process_downloadable_data(rows_data, columns)
336+
332337
stream = io.StringIO()
333338
writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
334339
writer.writeheader()
@@ -341,6 +346,7 @@ def get_score_set_scores_as_csv(
341346
score_set: ScoreSet,
342347
start: Optional[int] = None,
343348
limit: Optional[int] = None,
349+
download: Optional[bool] = None,
344350
) -> str:
345351
assert type(score_set.dataset_columns) is dict
346352
score_columns = [str(x) for x in list(score_set.dataset_columns.get("score_columns", []))]
@@ -359,13 +365,38 @@ def get_score_set_scores_as_csv(
359365
variants = db.scalars(variants_query).all()
360366

361367
rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column)
368+
if download:
369+
rows_data, columns = process_downloadable_data(rows_data, columns)
370+
362371
stream = io.StringIO()
363372
writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
364373
writer.writeheader()
365374
writer.writerows(rows_data)
366375
return stream.getvalue()
367376

368377

378+
def process_downloadable_data(
    rows_data: Iterable[dict[str, Any]],
    columns: list[str],
) -> tuple[list[dict[str, Any]], list[str]]:
    """Prepare rows for a downloadable CSV by dropping all-null HGVS columns.

    Args:
        rows_data: Iterable of CSV row dicts, one per variant.
        columns: Ordered list of CSV header names.

    Returns:
        A ``(rows, columns)`` tuple: the materialized rows with all-null HGVS
        columns removed from each dict, and the header list filtered to match.
    """
    # Materialize the iterable (it may be a map/generator) so it can be
    # scanned twice: once to detect empty columns, once to strip them.
    rows = list(rows_data)
    hgvs_columns = ("hgvs_nt", "hgvs_splice", "hgvs_pro")

    # A column is removable only when rows exist and every value is null/NA.
    # Without the `rows and` guard, all() over zero rows is vacuously true
    # and an empty result set would lose every HGVS column from the header.
    # row.get() (not row[col]) avoids a KeyError when a column is absent;
    # validate_is_null treats the resulting None as null.
    columns_to_remove = [
        col
        for col in hgvs_columns
        if rows and all(validate_is_null(row.get(col)) for row in rows)
    ]
    for row in rows:
        for col in columns_to_remove:
            row.pop(col, None)

    # Filter the header list to mirror the per-row removals.
    columns = [col for col in columns if col not in columns_to_remove]
    return rows, columns
398+
399+
369400
null_values_re = re.compile(r"\s+|none|nan|na|undefined|n/a|null|nil", flags=re.IGNORECASE)
370401

371402

src/mavedb/routers/score_sets.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def get_score_set_scores_csv(
180180
urn: str,
181181
start: int = Query(default=None, description="Start index for pagination"),
182182
limit: int = Query(default=None, description="Number of variants to return"),
183+
download: Optional[bool] = None,
183184
db: Session = Depends(deps.get_db),
184185
user_data: Optional[UserData] = Depends(get_current_user),
185186
) -> Any:
@@ -214,7 +215,7 @@ def get_score_set_scores_csv(
214215

215216
assert_permission(user_data, score_set, Action.READ)
216217

217-
csv_str = get_score_set_scores_as_csv(db, score_set, start, limit)
218+
csv_str = get_score_set_scores_as_csv(db, score_set, start, limit, download)
218219
return StreamingResponse(iter([csv_str]), media_type="text/csv")
219220

220221

@@ -234,6 +235,7 @@ async def get_score_set_counts_csv(
234235
urn: str,
235236
start: int = Query(default=None, description="Start index for pagination"),
236237
limit: int = Query(default=None, description="Number of variants to return"),
238+
download: Optional[bool] = None,
237239
db: Session = Depends(deps.get_db),
238240
user_data: Optional[UserData] = Depends(get_current_user),
239241
) -> Any:
@@ -268,7 +270,7 @@ async def get_score_set_counts_csv(
268270

269271
assert_permission(user_data, score_set, Action.READ)
270272

271-
csv_str = get_score_set_counts_as_csv(db, score_set, start, limit)
273+
csv_str = get_score_set_counts_as_csv(db, score_set, start, limit, download)
272274
return StreamingResponse(iter([csv_str]), media_type="text/csv")
273275

274276

tests/routers/test_score_set.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,3 +1749,50 @@ def test_score_set_not_found_for_non_existent_score_set_when_adding_score_calibr
17491749

17501750
assert response.status_code == 404
17511751
assert "score_calibrations" not in response_data
1752+
1753+
1754+
########################################################################################################################
1755+
# Score set download files
1756+
########################################################################################################################
1757+
1758+
# The scores fixture has no hgvs_splice data, so that column is all NA.
def test_download_scores_file(session, data_provider, client, setup_router_db, data_files):
    """Downloading scores with ?download=true strips all-NA HGVS columns."""
    experiment = create_experiment(client)
    score_set = create_seq_score_set_with_variants(
        client, session, data_provider, experiment["urn"], data_files / "scores.csv"
    )

    publish_score_set_response = client.post(f"/api/v1/score-sets/{score_set['urn']}/publish")
    assert publish_score_set_response.status_code == 200
    publish_score_set = publish_score_set_response.json()

    download_scores_csv_response = client.get(f"/api/v1/score-sets/{publish_score_set['urn']}/scores?download=true")
    assert download_scores_csv_response.status_code == 200
    download_scores_csv = download_scores_csv_response.text
    csv_header = download_scores_csv.split("\n")[0]
    columns = csv_header.split(",")
    # Populated HGVS columns survive; the all-NA hgvs_splice column is dropped.
    assert "hgvs_nt" in columns
    assert "hgvs_pro" in columns
    assert "hgvs_splice" not in columns
1778+
1779+
1780+
def test_download_counts_file(session, data_provider, client, setup_router_db, data_files):
    """Downloading counts with ?download=true strips all-NA HGVS columns."""
    experiment = create_experiment(client)
    score_set = create_seq_score_set_with_variants(
        client,
        session,
        data_provider,
        experiment["urn"],
        scores_csv_path=data_files / "scores.csv",
        counts_csv_path=data_files / "counts.csv",
    )
    publish_score_set_response = client.post(f"/api/v1/score-sets/{score_set['urn']}/publish")
    assert publish_score_set_response.status_code == 200
    publish_score_set = publish_score_set_response.json()

    download_counts_csv_response = client.get(f"/api/v1/score-sets/{publish_score_set['urn']}/counts?download=true")
    assert download_counts_csv_response.status_code == 200
    download_counts_csv = download_counts_csv_response.text
    csv_header = download_counts_csv.split("\n")[0]
    columns = csv_header.split(",")
    # The counts fixture has no hgvs_splice data, so that column is all NA
    # and should be absent from the downloaded file.
    assert "hgvs_nt" in columns
    assert "hgvs_pro" in columns
    assert "hgvs_splice" not in columns

0 commit comments

Comments
 (0)