Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 62 additions & 7 deletions malariagen_data/anoph/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,28 @@ def __init__(
`site_class`, `cohort_size`, `min_cohort_size`, `max_cohort_size`,
`random_seed`.

.. versionchanged:: 9.0.0
The `cohorts` parameter has been added to enable cohort-based
downsampling via the `max_cohort_size` parameter.
""",
returns=("df_pca", "evr"),
notes="""
This computation may take some time to run, depending on your computing
environment. Results of this computation will be cached and re-used if
the `results_cache` parameter was set when instantiating the API client.
""",
examples="""
Run a PCA, downsampling to a maximum of 20 samples per country::

>>> import malariagen_data
>>> ag3 = malariagen_data.Ag3()
>>> df_pca, evr = ag3.pca(
... region="3R",
... n_snps=1000,
... cohorts="country",
... max_cohort_size=20,
... )
""",
)
def pca(
self,
Expand All @@ -61,6 +76,10 @@ def pca(
sample_query: Optional[base_params.sample_query] = None,
sample_query_options: Optional[base_params.sample_query_options] = None,
sample_indices: Optional[base_params.sample_indices] = None,
cohorts: Optional[base_params.cohorts] = None,
cohort_size: Optional[base_params.cohort_size] = None,
Copy link
Preview

Copilot AI Jun 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The cohort_size and min_cohort_size parameters are accepted but not implemented beyond raising errors; consider either implementing their logic or removing them to avoid confusion.

Copilot uses AI. Check for mistakes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to note that Copilot's nitpick above does not appear to be true because cohort_size and min_cohort_size are either used to retrieve the cached results (which are based on certain params, including these) or they are passed to the _pca function, i.e.

        params = dict(
            region=region_prepped,
            n_snps=n_snps,
            thin_offset=thin_offset,
            sample_sets=sample_sets_prepped,
            sample_indices=sample_indices_prepped,
            site_mask=site_mask_prepped,
            site_class=site_class,
            min_minor_ac=min_minor_ac,
            max_missing_an=max_missing_an,
            n_components=n_components,
            cohorts=cohorts,
            cohort_size=cohort_size,
            min_cohort_size=min_cohort_size,
            max_cohort_size=max_cohort_size,
            exclude_samples=exclude_samples,
            fit_exclude_samples=fit_exclude_samples,
            random_seed=random_seed,
        )

        # Try to retrieve results from the cache.
        try:
            results = self.results_cache_get(name=name, params=params)

        except CacheMiss:
            results = self._pca(chunks=chunks, inline_array=inline_array, **params)
            self.results_cache_set(name=name, params=params, results=results)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thanks @leehart. We decided on Friday that Copilot was less nitpicking and more plain old wrong ;).

min_cohort_size: Optional[base_params.min_cohort_size] = None,
max_cohort_size: Optional[base_params.max_cohort_size] = None,
site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
site_class: Optional[base_params.site_class] = None,
min_minor_ac: Optional[
Expand All @@ -69,9 +88,6 @@ def pca(
max_missing_an: Optional[
base_params.max_missing_an
] = pca_params.max_missing_an_default,
cohort_size: Optional[base_params.cohort_size] = None,
min_cohort_size: Optional[base_params.min_cohort_size] = None,
max_cohort_size: Optional[base_params.max_cohort_size] = None,
exclude_samples: Optional[base_params.samples] = None,
fit_exclude_samples: Optional[base_params.samples] = None,
random_seed: base_params.random_seed = 42,
Expand All @@ -82,6 +98,43 @@ def pca(
# invalidate any previously cached data.
name = "pca_v4"

# Handle cohort downsampling.
if cohorts is not None:
if max_cohort_size is None:
raise ValueError(
"`max_cohort_size` is required when `cohorts` is provided."
)
if sample_indices is not None:
raise ValueError(
"Cannot use `sample_indices` with `cohorts` and `max_cohort_size`."
)
if cohort_size is not None or min_cohort_size is not None:
raise ValueError(
"Cannot use `cohort_size` or `min_cohort_size` with `cohorts`."
)
df_samples = self.sample_metadata(
sample_sets=sample_sets,
sample_query=sample_query,
sample_query_options=sample_query_options,
)
# N.B., we are going to overwrite the sample_indices parameter here.
groups = df_samples.groupby(cohorts, sort=False)
ix = []
for _, group in groups:
if len(group) > max_cohort_size:
ix.extend(
group.sample(
n=max_cohort_size, random_state=random_seed, replace=False
).index
)
else:
ix.extend(group.index)
sample_indices = ix
# From this point onwards, the sample_query is no longer needed, because
# the sample selection is defined by the sample_indices.
sample_query = None
sample_query_options = None

# Normalize params for consistent hash value.
(
sample_sets_prepped,
Expand All @@ -105,6 +158,7 @@ def pca(
min_minor_ac=min_minor_ac,
max_missing_an=max_missing_an,
n_components=n_components,
cohorts=cohorts,
cohort_size=cohort_size,
min_cohort_size=min_cohort_size,
max_cohort_size=max_cohort_size,
Expand All @@ -122,10 +176,10 @@ def pca(
self.results_cache_set(name=name, params=params, results=results)

# Unpack results.
coords = results["coords"]
evr = results["evr"]
samples = results["samples"]
loc_keep_fit = results["loc_keep_fit"]
coords = np.array(results["coords"])
evr = np.array(results["evr"])
samples = np.array(results["samples"])
loc_keep_fit = np.array(results["loc_keep_fit"])

# Load sample metadata.
df_samples = self.sample_metadata(
Expand Down Expand Up @@ -166,6 +220,7 @@ def _pca(
random_seed,
chunks,
inline_array,
**kwargs,
):
# Load diplotypes.
gn, samples = self.biallelic_diplotypes(
Expand Down
30 changes: 29 additions & 1 deletion notebooks/plot_pca.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -620,10 +620,38 @@
")"
]
},
{
"cell_type": "markdown",
"id": "f1e8c954",
"metadata": {},
"source": [
"## PCA with cohort downsampling"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4a484f3",
"metadata": {},
"outputs": [],
"source": [
"df_pca_cohorts, evr_cohorts = ag3.pca(\n",
" region=\"3L:15,000,000-16,000,000\",\n",
" sample_sets=\"3.0\",\n",
" n_snps=10_000,\n",
" cohorts=\"country\",\n",
" max_cohort_size=20,\n",
")\n",
"ag3.plot_pca_coords(\n",
" df_pca_cohorts,\n",
" color=\"country\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33d788a2-f256-4930-b1e5-b4f31e681a36",
"id": "abb2ee83",
"metadata": {},
"outputs": [],
"source": []
Expand Down
79 changes: 79 additions & 0 deletions tests/anoph/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,82 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca):
len(pca_df.query(f"sample_id in {exclude_samples} and not pca_fit"))
== n_samples_excluded
)


@parametrize_with_cases("fixture,api", cases=".")
def test_pca_cohort_downsampling(fixture, api: AnophelesPca):
    """Test PCA with cohort-based downsampling via the `cohorts` and
    `max_cohort_size` parameters, including error handling for invalid
    parameter combinations."""

    # Parameters for selecting input data. N.B., the unseeded random
    # choices here are a deliberate design decision, to cover a wider
    # range of inputs across test runs (at the cost of occasional
    # data-dependent skips/failures).
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    sample_sets = random.sample(all_sample_sets, 2)
    data_params = dict(
        region=random.choice(api.contigs),
        sample_sets=sample_sets,
        site_mask=random.choice((None,) + api.site_mask_ids),
    )

    # Parameters for cohort downsampling.
    cohort_col = "country"
    max_cohort_size = 10
    random_seed = 42

    # Try to run the PCA with cohort downsampling.
    try:
        pca_df, pca_evr = api.pca(
            n_snps=100,  # Use a small number to avoid "Not enough SNPs" errors.
            n_components=2,
            cohorts=cohort_col,
            max_cohort_size=max_cohort_size,
            random_seed=random_seed,
            **data_params,
        )
    except ValueError as e:
        # Randomly chosen inputs may not yield enough SNPs; skip rather
        # than fail, but re-raise any other unexpected error.
        if "Not enough SNPs" in str(e):
            pytest.skip("Not enough SNPs available after downsampling to run test.")
        else:
            raise

    # Check return types.
    assert isinstance(pca_df, pd.DataFrame)
    assert isinstance(pca_evr, np.ndarray)

    # Check basic structure of the results.
    assert len(pca_df) > 0
    assert "PC1" in pca_df.columns
    assert "PC2" in pca_df.columns
    assert "pca_fit" in pca_df.columns
    assert pca_df["pca_fit"].all()
    assert pca_evr.ndim == 1
    assert pca_evr.shape[0] == 2

    # Check that no cohort exceeds the requested maximum size.
    final_cohort_counts = pca_df[cohort_col].value_counts()
    assert (final_cohort_counts <= max_cohort_size).all()

    # Bad parameter combination: `cohorts` requires `max_cohort_size`.
    with pytest.raises(ValueError):
        api.pca(
            n_snps=100,
            n_components=2,
            cohorts=cohort_col,
            # max_cohort_size is deliberately omitted.
            **data_params,
        )

    # Bad parameter combination: `sample_indices` conflicts with `cohorts`.
    with pytest.raises(ValueError):
        api.pca(
            n_snps=100,
            n_components=2,
            cohorts=cohort_col,
            max_cohort_size=max_cohort_size,
            sample_indices=[0, 1, 2],
            **data_params,
        )

    # Bad parameter combination: `cohort_size` conflicts with `cohorts`.
    with pytest.raises(ValueError):
        api.pca(
            n_snps=100,
            n_components=2,
            cohorts=cohort_col,
            max_cohort_size=max_cohort_size,
            cohort_size=10,
            **data_params,
        )
Loading