Skip to content

Commit 7d3e575

Browse files
authored
Merge branch 'master' into plink-converter-2024-03-26-tristanpwdennis-shadow
2 parents e07f92c + 1ba4714 commit 7d3e575

File tree

4 files changed

+171
-15
lines changed

4 files changed

+171
-15
lines changed

malariagen_data/anoph/frq_params.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Parameter definitions for functions computing and plotting allele frequencies."""
22

3-
from typing import Literal
3+
from typing import Literal, List, Optional, Tuple, Union
44

55
import xarray as xr
66
from typing_extensions import Annotated, TypeAlias
@@ -70,3 +70,13 @@
7070
bool,
7171
"Include columns with allele counts and number of non-missing allele calls (nobs).",
7272
]
73+
74+
taxa: TypeAlias = Annotated[
75+
Optional[Union[str, List[str], Tuple[str, ...]]],
76+
"The taxon or taxa to restrict the dataset to.",
77+
]
78+
79+
areas: TypeAlias = Annotated[
80+
Optional[Union[str, List[str], Tuple[str, ...]]],
81+
"The area or areas to restrict the dataset to.",
82+
]

malariagen_data/anoph/snp_frq.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,8 @@ def plot_frequencies_time_series(
936936
legend_sizing: plotly_params.legend_sizing = "constant",
937937
show: plotly_params.show = True,
938938
renderer: plotly_params.renderer = None,
939+
taxa: frq_params.taxa = None,
940+
areas: frq_params.areas = None,
939941
**kwargs,
940942
) -> plotly_params.figure:
941943
# Handle title.
@@ -947,6 +949,18 @@ def plot_frequencies_time_series(
947949
df_cohorts = ds[cohort_vars].to_dataframe()
948950
df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore
949951

952+
# If specified, restrict the dataframe by taxa.
953+
if isinstance(taxa, str):
954+
df_cohorts = df_cohorts[df_cohorts["taxon"] == taxa]
955+
elif isinstance(taxa, (list, tuple)):
956+
df_cohorts = df_cohorts[df_cohorts["taxon"].isin(taxa)]
957+
958+
# If specified, restrict the dataframe by areas.
959+
if isinstance(areas, str):
960+
df_cohorts = df_cohorts[df_cohorts["area"] == areas]
961+
elif isinstance(areas, (list, tuple)):
962+
df_cohorts = df_cohorts[df_cohorts["area"].isin(areas)]
963+
950964
# Extract variant labels.
951965
variant_labels = ds["variant_label"].values
952966

notebooks/plot_frequencies_space_time.ipynb

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "47f669f3",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import malariagen_data"
11+
]
12+
},
313
{
414
"cell_type": "code",
515
"execution_count": null,
616
"id": "f820bc66-2fb2-4ca2-9b54-824e50d61a0a",
717
"metadata": {},
818
"outputs": [],
919
"source": [
10-
"import malariagen_data\n",
11-
"\n",
1220
"ag3 = malariagen_data.Ag3(\n",
1321
" \"simplecache::gs://vo_agam_release_master_us_central1\",\n",
1422
" simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
@@ -23,8 +31,6 @@
2331
"metadata": {},
2432
"outputs": [],
2533
"source": [
26-
"import malariagen_data\n",
27-
"\n",
2834
"af1 = malariagen_data.Af1(\n",
2935
" \"simplecache::gs://vo_afun_release_master_us_central1\",\n",
3036
" simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
@@ -69,6 +75,26 @@
6975
"ag3.plot_frequencies_time_series(ds, height=500, width=1000)"
7076
]
7177
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"id": "790c99e8",
82+
"metadata": {},
83+
"outputs": [],
84+
"source": [
85+
"ag3.plot_frequencies_time_series(ds, taxa=\"gambiae\", height=500, width=1000)"
86+
]
87+
},
88+
{
89+
"cell_type": "code",
90+
"execution_count": null,
91+
"id": "1bfc7298",
92+
"metadata": {},
93+
"outputs": [],
94+
"source": [
95+
"ag3.plot_frequencies_time_series(ds, taxa=(\"gambiae\", \"arabiensis\"), height=500, width=1000)"
96+
]
97+
},
7298
{
7399
"cell_type": "code",
74100
"execution_count": null,
@@ -252,6 +278,26 @@
252278
"ag3.plot_frequencies_time_series(ds, height=900, width=900)"
253279
]
254280
},
281+
{
282+
"cell_type": "code",
283+
"execution_count": null,
284+
"id": "e16ab3fe",
285+
"metadata": {},
286+
"outputs": [],
287+
"source": [
288+
"ag3.plot_frequencies_time_series(ds, areas=\"BF-09\", height=400, width=900)"
289+
]
290+
},
291+
{
292+
"cell_type": "code",
293+
"execution_count": null,
294+
"id": "26af27a1",
295+
"metadata": {},
296+
"outputs": [],
297+
"source": [
298+
"ag3.plot_frequencies_time_series(ds, areas=(\"BF-09\", \"TZ-25\"), height=400, width=900)"
299+
]
300+
},
255301
{
256302
"cell_type": "code",
257303
"execution_count": null,
@@ -336,19 +382,11 @@
336382
"source": [
337383
"af1.plot_frequencies_interactive_map(ds)"
338384
]
339-
},
340-
{
341-
"cell_type": "code",
342-
"execution_count": null,
343-
"id": "a512b459",
344-
"metadata": {},
345-
"outputs": [],
346-
"source": []
347385
}
348386
],
349387
"metadata": {
350388
"kernelspec": {
351-
"display_name": "Python 3 (ipykernel)",
389+
"display_name": "mgen_data_py3.11",
352390
"language": "python",
353391
"name": "python3"
354392
},
@@ -362,7 +400,7 @@
362400
"name": "python",
363401
"nbconvert_exporter": "python",
364402
"pygments_lexer": "ipython3",
365-
"version": "3.10.12"
403+
"version": "3.11.5"
366404
},
367405
"widgets": {
368406
"application/vnd.jupyter.widget-state+json": {

tests/anoph/test_snp_frq.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,6 +1528,100 @@ def test_plot_frequencies_time_series(
15281528
assert isinstance(fig, go.Figure)
15291529

15301530

1531+
@parametrize_with_cases("fixture,api", cases=".")
1532+
def test_plot_frequencies_time_series_with_taxa(
1533+
fixture,
1534+
api: AnophelesSnpFrequencyAnalysis,
1535+
):
1536+
# Pick test parameters at random.
1537+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
1538+
sample_sets = random.choice(all_sample_sets)
1539+
site_mask = random.choice(api.site_mask_ids + (None,))
1540+
transcript = random_transcript(api=api).name
1541+
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
1542+
period_by = random.choice(["year", "quarter", "month"])
1543+
1544+
# Pick a random taxon and taxa from valid taxa.
1545+
sample_sets_taxa = (
1546+
api.sample_metadata(sample_sets=sample_sets)["taxon"].dropna().unique().tolist()
1547+
)
1548+
taxon = random.choice(sample_sets_taxa)
1549+
taxa = random.sample(sample_sets_taxa, random.randint(1, len(sample_sets_taxa)))
1550+
1551+
# Compute SNP frequencies.
1552+
ds = api.snp_allele_frequencies_advanced(
1553+
transcript=transcript,
1554+
area_by=area_by,
1555+
period_by=period_by,
1556+
sample_sets=sample_sets,
1557+
min_cohort_size=1, # Don't exclude any samples.
1558+
site_mask=site_mask,
1559+
)
1560+
1561+
# Trim things down a bit for speed.
1562+
ds = ds.isel(variants=slice(0, 100))
1563+
1564+
# Plot with taxon.
1565+
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon)
1566+
1567+
# Test taxon plot.
1568+
assert isinstance(fig, go.Figure)
1569+
1570+
# Plot with taxa.
1571+
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxa)
1572+
1573+
# Test taxa plot.
1574+
assert isinstance(fig, go.Figure)
1575+
1576+
1577+
@parametrize_with_cases("fixture,api", cases=".")
1578+
def test_plot_frequencies_time_series_with_areas(
1579+
fixture,
1580+
api: AnophelesSnpFrequencyAnalysis,
1581+
):
1582+
# Pick test parameters at random.
1583+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
1584+
sample_sets = random.choice(all_sample_sets)
1585+
site_mask = random.choice(api.site_mask_ids + (None,))
1586+
transcript = random_transcript(api=api).name
1587+
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
1588+
period_by = random.choice(["year", "quarter", "month"])
1589+
1590+
# Compute SNP frequencies.
1591+
ds = api.snp_allele_frequencies_advanced(
1592+
transcript=transcript,
1593+
area_by=area_by,
1594+
period_by=period_by,
1595+
sample_sets=sample_sets,
1596+
min_cohort_size=1, # Don't exclude any samples.
1597+
site_mask=site_mask,
1598+
)
1599+
1600+
# Trim things down a bit for speed.
1601+
ds = ds.isel(variants=slice(0, 100))
1602+
1603+
# Extract cohorts into a DataFrame.
1604+
cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
1605+
df_cohorts = ds[cohort_vars].to_dataframe()
1606+
1607+
# Pick a random area and areas from valid areas.
1608+
cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
1609+
area = random.choice(cohorts_areas)
1610+
areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))
1611+
1612+
# Plot with area.
1613+
fig = api.plot_frequencies_time_series(ds, show=False, areas=area)
1614+
1615+
# Test areas plot.
1616+
assert isinstance(fig, go.Figure)
1617+
1618+
# Plot with areas.
1619+
fig = api.plot_frequencies_time_series(ds, show=False, areas=areas)
1620+
1621+
# Test area plot.
1622+
assert isinstance(fig, go.Figure)
1623+
1624+
15311625
@parametrize_with_cases("fixture,api", cases=".")
15321626
def test_plot_frequencies_interactive_map(
15331627
fixture,

0 commit comments

Comments
 (0)