Skip to content

Commit c9d8ec4

Browse files
committed
Merge branch '367-haps-freq' of github.com:malariagen/malariagen-data-python into 367-haps-freq
2 parents 0ec6cb3 + 80f8cc5 commit c9d8ec4

File tree

6 files changed

+183
-16
lines changed

6 files changed

+183
-16
lines changed

malariagen_data/anoph/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,18 @@ def _discover_releases(self) -> Tuple[str, ...]:
327327
)
328328
# Note: this matches v3, v3. and v3.1, but not v3001.1
329329
version_pattern = re.compile(f"^v{self._major_version_number}(\\..*)?$")
330+
# To sort the versions numerically, we use a lambda function for the "key" parameter of sorted().
331+
# The lambda function splits each version string into a list of its integer parts, using split('.') and int(), e.g. [3, 1],
332+
# which sorted() then uses to determine the order, as opposed to the default lexicographic order.
330333
discovered_releases = tuple(
331334
sorted(
332335
[
333336
self._path_to_release(d)
334337
for d in sub_dirs
335338
if version_pattern.match(d)
336339
and self._fs.exists(f"{self._base_path}/{d}/manifest.tsv")
337-
]
340+
],
341+
key=lambda v: [int(part) for part in v.split(".")],
338342
)
339343
)
340344
return discovered_releases

malariagen_data/anoph/frq_params.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Parameter definitions for functions computing and plotting allele frequencies."""
22

3-
from typing import Literal
3+
from typing import Literal, List, Optional, Tuple, Union
44

55
import xarray as xr
66
from typing_extensions import Annotated, TypeAlias
@@ -70,3 +70,13 @@
7070
bool,
7171
"Include columns with allele counts and number of non-missing allele calls (nobs).",
7272
]
73+
74+
taxa: TypeAlias = Annotated[
75+
Optional[Union[str, List[str], Tuple[str, ...]]],
76+
"The taxon or taxa to restrict the dataset to.",
77+
]
78+
79+
areas: TypeAlias = Annotated[
80+
Optional[Union[str, List[str], Tuple[str, ...]]],
81+
"The area or areas to restrict the dataset to.",
82+
]

malariagen_data/anoph/snp_frq.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,8 @@ def plot_frequencies_time_series(
931931
legend_sizing: plotly_params.legend_sizing = "constant",
932932
show: plotly_params.show = True,
933933
renderer: plotly_params.renderer = None,
934+
taxa: frq_params.taxa = None,
935+
areas: frq_params.areas = None,
934936
**kwargs,
935937
) -> plotly_params.figure:
936938
# Handle title.
@@ -942,6 +944,18 @@ def plot_frequencies_time_series(
942944
df_cohorts = ds[cohort_vars].to_dataframe()
943945
df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore
944946

947+
# If specified, restrict the dataframe by taxa.
948+
if isinstance(taxa, str):
949+
df_cohorts = df_cohorts[df_cohorts["taxon"] == taxa]
950+
elif isinstance(taxa, (list, tuple)):
951+
df_cohorts = df_cohorts[df_cohorts["taxon"].isin(taxa)]
952+
953+
# If specified, restrict the dataframe by areas.
954+
if isinstance(areas, str):
955+
df_cohorts = df_cohorts[df_cohorts["area"] == areas]
956+
elif isinstance(areas, (list, tuple)):
957+
df_cohorts = df_cohorts[df_cohorts["area"].isin(areas)]
958+
945959
# Extract variant labels.
946960
variant_labels = ds["variant_label"].values
947961

malariagen_data/anopheles.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,11 @@ def _gene_cnv_frequencies(
11851185

11861186
freq_cols[f"frq_{coh}"] = np.concatenate([amp_freq_coh, del_freq_coh])
11871187

1188+
if len(coh_dict) == 0:
1189+
raise ValueError(
1190+
"No cohorts available for the given sample selection parameters and minimum cohort size."
1191+
)
1192+
11881193
debug("build a dataframe with the frequency columns")
11891194
df_freqs = pd.DataFrame(freq_cols)
11901195

notebooks/plot_frequencies_space_time.ipynb

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "47f669f3",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import malariagen_data"
11+
]
12+
},
313
{
414
"cell_type": "code",
515
"execution_count": null,
616
"id": "f820bc66-2fb2-4ca2-9b54-824e50d61a0a",
717
"metadata": {},
818
"outputs": [],
919
"source": [
10-
"import malariagen_data\n",
11-
"\n",
1220
"ag3 = malariagen_data.Ag3(\n",
1321
" \"simplecache::gs://vo_agam_release_master_us_central1\",\n",
1422
" simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
@@ -23,8 +31,6 @@
2331
"metadata": {},
2432
"outputs": [],
2533
"source": [
26-
"import malariagen_data\n",
27-
"\n",
2834
"af1 = malariagen_data.Af1(\n",
2935
" \"simplecache::gs://vo_afun_release_master_us_central1\",\n",
3036
" simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
@@ -69,6 +75,26 @@
6975
"ag3.plot_frequencies_time_series(ds, height=500, width=1000)"
7076
]
7177
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"id": "790c99e8",
82+
"metadata": {},
83+
"outputs": [],
84+
"source": [
85+
"ag3.plot_frequencies_time_series(ds, taxa=\"gambiae\", height=500, width=1000)"
86+
]
87+
},
88+
{
89+
"cell_type": "code",
90+
"execution_count": null,
91+
"id": "1bfc7298",
92+
"metadata": {},
93+
"outputs": [],
94+
"source": [
95+
"ag3.plot_frequencies_time_series(ds, taxa=(\"gambiae\", \"arabiensis\"), height=500, width=1000)"
96+
]
97+
},
7298
{
7399
"cell_type": "code",
74100
"execution_count": null,
@@ -252,6 +278,26 @@
252278
"ag3.plot_frequencies_time_series(ds, height=900, width=900)"
253279
]
254280
},
281+
{
282+
"cell_type": "code",
283+
"execution_count": null,
284+
"id": "e16ab3fe",
285+
"metadata": {},
286+
"outputs": [],
287+
"source": [
288+
"ag3.plot_frequencies_time_series(ds, areas=\"BF-09\", height=400, width=900)"
289+
]
290+
},
291+
{
292+
"cell_type": "code",
293+
"execution_count": null,
294+
"id": "26af27a1",
295+
"metadata": {},
296+
"outputs": [],
297+
"source": [
298+
"ag3.plot_frequencies_time_series(ds, areas=(\"BF-09\", \"TZ-25\"), height=400, width=900)"
299+
]
300+
},
255301
{
256302
"cell_type": "code",
257303
"execution_count": null,
@@ -336,19 +382,11 @@
336382
"source": [
337383
"af1.plot_frequencies_interactive_map(ds)"
338384
]
339-
},
340-
{
341-
"cell_type": "code",
342-
"execution_count": null,
343-
"id": "a512b459",
344-
"metadata": {},
345-
"outputs": [],
346-
"source": []
347385
}
348386
],
349387
"metadata": {
350388
"kernelspec": {
351-
"display_name": "Python 3 (ipykernel)",
389+
"display_name": "mgen_data_py3.11",
352390
"language": "python",
353391
"name": "python3"
354392
},
@@ -362,7 +400,7 @@
362400
"name": "python",
363401
"nbconvert_exporter": "python",
364402
"pygments_lexer": "ipython3",
365-
"version": "3.10.12"
403+
"version": "3.11.5"
366404
},
367405
"widgets": {
368406
"application/vnd.jupyter.widget-state+json": {

tests/anoph/test_snp_frq.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,6 +1528,102 @@ def test_plot_frequencies_time_series(
15281528
assert isinstance(fig, go.Figure)
15291529

15301530

1531+
@parametrize_with_cases("fixture,api", cases=".")
1532+
def test_plot_frequencies_time_series_with_taxa(
1533+
fixture,
1534+
api: AnophelesSnpFrequencyAnalysis,
1535+
):
1536+
# Pick test parameters at random.
1537+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
1538+
sample_sets = random.choice(all_sample_sets)
1539+
site_mask = random.choice(api.site_mask_ids + (None,))
1540+
min_cohort_size = random.randint(0, 2)
1541+
transcript = random_transcript(api=api).name
1542+
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
1543+
period_by = random.choice(["year", "quarter", "month"])
1544+
1545+
# Pick a random taxon and taxa from valid taxa.
1546+
sample_sets_taxa = (
1547+
api.sample_metadata(sample_sets=sample_sets)["taxon"].dropna().unique().tolist()
1548+
)
1549+
taxon = random.choice(sample_sets_taxa)
1550+
taxa = random.sample(sample_sets_taxa, random.randint(1, len(sample_sets_taxa)))
1551+
1552+
# Compute SNP frequencies.
1553+
ds = api.snp_allele_frequencies_advanced(
1554+
transcript=transcript,
1555+
area_by=area_by,
1556+
period_by=period_by,
1557+
sample_sets=sample_sets,
1558+
min_cohort_size=min_cohort_size,
1559+
site_mask=site_mask,
1560+
)
1561+
1562+
# Trim things down a bit for speed.
1563+
ds = ds.isel(variants=slice(0, 100))
1564+
1565+
# Plot with taxon.
1566+
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon)
1567+
1568+
# Test taxon plot.
1569+
assert isinstance(fig, go.Figure)
1570+
1571+
# Plot with taxa.
1572+
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxa)
1573+
1574+
# Test taxa plot.
1575+
assert isinstance(fig, go.Figure)
1576+
1577+
1578+
@parametrize_with_cases("fixture,api", cases=".")
1579+
def test_plot_frequencies_time_series_with_areas(
1580+
fixture,
1581+
api: AnophelesSnpFrequencyAnalysis,
1582+
):
1583+
# Pick test parameters at random.
1584+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
1585+
sample_sets = random.choice(all_sample_sets)
1586+
site_mask = random.choice(api.site_mask_ids + (None,))
1587+
min_cohort_size = random.randint(0, 2)
1588+
transcript = random_transcript(api=api).name
1589+
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
1590+
period_by = random.choice(["year", "quarter", "month"])
1591+
1592+
# Compute SNP frequencies.
1593+
ds = api.snp_allele_frequencies_advanced(
1594+
transcript=transcript,
1595+
area_by=area_by,
1596+
period_by=period_by,
1597+
sample_sets=sample_sets,
1598+
min_cohort_size=min_cohort_size,
1599+
site_mask=site_mask,
1600+
)
1601+
1602+
# Trim things down a bit for speed.
1603+
ds = ds.isel(variants=slice(0, 100))
1604+
1605+
# Extract cohorts into a DataFrame.
1606+
cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
1607+
df_cohorts = ds[cohort_vars].to_dataframe()
1608+
1609+
# Pick a random area and areas from valid areas.
1610+
cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
1611+
area = random.choice(cohorts_areas)
1612+
areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))
1613+
1614+
# Plot with area.
1615+
fig = api.plot_frequencies_time_series(ds, show=False, areas=area)
1616+
1617+
# Test area plot.
1618+
assert isinstance(fig, go.Figure)
1619+
1620+
# Plot with areas.
1621+
fig = api.plot_frequencies_time_series(ds, show=False, areas=areas)
1622+
1623+
# Test areas plot.
1624+
assert isinstance(fig, go.Figure)
1625+
1626+
15311627
@parametrize_with_cases("fixture,api", cases=".")
15321628
def test_plot_frequencies_interactive_map(
15331629
fixture,

0 commit comments

Comments
 (0)