diff --git a/malariagen_data/anoph/pca.py b/malariagen_data/anoph/pca.py
index 18cd0bdb..bde19495 100644
--- a/malariagen_data/anoph/pca.py
+++ b/malariagen_data/anoph/pca.py
@@ -43,6 +43,9 @@ def __init__(
             `site_class`, `cohort_size`, `min_cohort_size`, `max_cohort_size`,
             `random_seed`.
 
+            .. versionchanged:: 9.0.0
+                The `cohorts` parameter has been added to enable cohort-based
+                downsampling via the `max_cohort_size` parameter.
         """,
         returns=("df_pca", "evr"),
         notes="""
@@ -50,6 +53,18 @@ def __init__(
            environment. Results of this computation will be cached and re-used if
            the `results_cache` parameter was set when instantiating the API client.
         """,
+        examples="""
+            Run a PCA, downsampling to a maximum of 20 samples per country::
+
+                >>> import malariagen_data
+                >>> ag3 = malariagen_data.Ag3()
+                >>> df_pca, evr = ag3.pca(
+                ...     region="3R",
+                ...     n_snps=1000,
+                ...     cohorts="country",
+                ...     max_cohort_size=20,
+                ... )
+        """,
     )
     def pca(
         self,
@@ -61,6 +76,10 @@ def pca(
         sample_query: Optional[base_params.sample_query] = None,
         sample_query_options: Optional[base_params.sample_query_options] = None,
         sample_indices: Optional[base_params.sample_indices] = None,
+        cohorts: Optional[base_params.cohorts] = None,
+        cohort_size: Optional[base_params.cohort_size] = None,
+        min_cohort_size: Optional[base_params.min_cohort_size] = None,
+        max_cohort_size: Optional[base_params.max_cohort_size] = None,
         site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
         site_class: Optional[base_params.site_class] = None,
         min_minor_ac: Optional[
@@ -69,9 +88,6 @@ def pca(
         max_missing_an: Optional[
             base_params.max_missing_an
         ] = pca_params.max_missing_an_default,
-        cohort_size: Optional[base_params.cohort_size] = None,
-        min_cohort_size: Optional[base_params.min_cohort_size] = None,
-        max_cohort_size: Optional[base_params.max_cohort_size] = None,
         exclude_samples: Optional[base_params.samples] = None,
         fit_exclude_samples: Optional[base_params.samples] = None,
         random_seed: base_params.random_seed = 42,
@@ -82,6 +98,43 @@ def pca(
         # invalidate any previously cached data.
         name = "pca_v4"
 
+        # Handle cohort downsampling.
+        if cohorts is not None:
+            if max_cohort_size is None:
+                raise ValueError(
+                    "`max_cohort_size` is required when `cohorts` is provided."
+                )
+            if sample_indices is not None:
+                raise ValueError(
+                    "Cannot use `sample_indices` with `cohorts` and `max_cohort_size`."
+                )
+            if cohort_size is not None or min_cohort_size is not None:
+                raise ValueError(
+                    "Cannot use `cohort_size` or `min_cohort_size` with `cohorts`."
+                )
+            df_samples = self.sample_metadata(
+                sample_sets=sample_sets,
+                sample_query=sample_query,
+                sample_query_options=sample_query_options,
+            )
+            # N.B., we are going to overwrite the sample_indices parameter here.
+            groups = df_samples.groupby(cohorts, sort=False)
+            ix = []
+            for _, group in groups:
+                if len(group) > max_cohort_size:
+                    ix.extend(
+                        group.sample(
+                            n=max_cohort_size, random_state=random_seed, replace=False
+                        ).index
+                    )
+                else:
+                    ix.extend(group.index)
+            sample_indices = ix
+            # From this point onwards, the sample_query is no longer needed, because
+            # the sample selection is defined by the sample_indices.
+            sample_query = None
+            sample_query_options = None
+
         # Normalize params for consistent hash value.
         (
             sample_sets_prepped,
@@ -105,6 +158,7 @@ def pca(
             min_minor_ac=min_minor_ac,
             max_missing_an=max_missing_an,
             n_components=n_components,
+            cohorts=cohorts,
             cohort_size=cohort_size,
             min_cohort_size=min_cohort_size,
             max_cohort_size=max_cohort_size,
@@ -122,10 +176,10 @@ def pca(
             self.results_cache_set(name=name, params=params, results=results)
 
         # Unpack results.
-        coords = results["coords"]
-        evr = results["evr"]
-        samples = results["samples"]
-        loc_keep_fit = results["loc_keep_fit"]
+        coords = np.array(results["coords"])
+        evr = np.array(results["evr"])
+        samples = np.array(results["samples"])
+        loc_keep_fit = np.array(results["loc_keep_fit"])
 
         # Load sample metadata.
         df_samples = self.sample_metadata(
@@ -166,6 +220,7 @@ def _pca(
         random_seed,
         chunks,
         inline_array,
+        **kwargs,
     ):
         # Load diplotypes.
         gn, samples = self.biallelic_diplotypes(
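For reference, the downsampling strategy added to `pca()` above can be exercised standalone. This is a minimal sketch with hypothetical toy metadata (the frame, column values, and cohort sizes are invented for illustration); the `groupby`/`sample` pattern mirrors the patch, so a fixed `random_state` makes the selection reproducible:

```python
import pandas as pd

# Hypothetical toy metadata, standing in for the frame returned by
# sample_metadata(); the values and sizes are invented for illustration.
df_samples = pd.DataFrame(
    {
        "sample_id": [f"S{i}" for i in range(10)],
        "country": ["Mali"] * 7 + ["Ghana"] * 3,
    }
)
max_cohort_size = 5
random_seed = 42

# Same pattern as the patch: cap each cohort at max_cohort_size by
# sampling without replacement; smaller cohorts pass through untouched.
ix = []
for _, group in df_samples.groupby("country", sort=False):
    if len(group) > max_cohort_size:
        ix.extend(group.sample(n=max_cohort_size, random_state=random_seed).index)
    else:
        ix.extend(group.index)

print(df_samples.loc[ix, "country"].value_counts())
# Mali capped at 5, Ghana kept at 3.
```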
diff --git a/notebooks/plot_pca.ipynb b/notebooks/plot_pca.ipynb
index 1ff57aaa..fcda6d53 100644
--- a/notebooks/plot_pca.ipynb
+++ b/notebooks/plot_pca.ipynb
@@ -620,10 +620,38 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "f1e8c954",
+   "metadata": {},
+   "source": [
+    "## PCA with cohort downsampling"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e4a484f3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_pca_cohorts, evr_cohorts = ag3.pca(\n",
+    "    region=\"3L:15,000,000-16,000,000\",\n",
+    "    sample_sets=\"3.0\",\n",
+    "    n_snps=10_000,\n",
+    "    cohorts=\"country\",\n",
+    "    max_cohort_size=20,\n",
+    ")\n",
+    "ag3.plot_pca_coords(\n",
+    "    df_pca_cohorts,\n",
+    "    color=\"country\",\n",
+    ")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "33d788a2-f256-4930-b1e5-b4f31e681a36",
+   "id": "abb2ee83",
    "metadata": {},
    "outputs": [],
    "source": []
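A quick sanity check could follow the notebook cell above. This sketch assumes the cell has been run, so `df_pca_cohorts` is in scope, and reuses the cap of 20 set via `max_cohort_size`:

```python
# Confirm the downsampling: no country should contribute more than
# max_cohort_size (here 20) samples to the PCA result.
counts = df_pca_cohorts["country"].value_counts()
assert (counts <= 20).all(), counts
print(counts)
```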
diff --git a/tests/anoph/test_pca.py b/tests/anoph/test_pca.py
index e5fa667a..41a8f7c9 100644
--- a/tests/anoph/test_pca.py
+++ b/tests/anoph/test_pca.py
@@ -288,3 +288,82 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca):
         len(pca_df.query(f"sample_id in {exclude_samples} and not pca_fit"))
         == n_samples_excluded
     )
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_pca_cohort_downsampling(fixture, api: AnophelesPca):
+    # Parameters for selecting input data.
+    all_sample_sets = api.sample_sets()["sample_set"].to_list()
+    sample_sets = random.sample(all_sample_sets, 2)
+    data_params = dict(
+        region=random.choice(api.contigs),
+        sample_sets=sample_sets,
+        site_mask=random.choice((None,) + api.site_mask_ids),
+    )
+
+    # Test cohort downsampling.
+    cohort_col = "country"
+    max_cohort_size = 10
+    random_seed = 42
+
+    # Try to run the PCA with cohort downsampling.
+    try:
+        pca_df, pca_evr = api.pca(
+            n_snps=100,  # Use a small number to avoid "Not enough SNPs" errors
+            n_components=2,
+            cohorts=cohort_col,
+            max_cohort_size=max_cohort_size,
+            random_seed=random_seed,
+            **data_params,
+        )
+    except ValueError as e:
+        if "Not enough SNPs" in str(e):
+            pytest.skip("Not enough SNPs available after downsampling to run test.")
+        else:
+            raise
+
+    # Check types.
+    assert isinstance(pca_df, pd.DataFrame)
+    assert isinstance(pca_evr, np.ndarray)
+
+    # Check basic structure.
+    assert len(pca_df) > 0
+    assert "PC1" in pca_df.columns
+    assert "PC2" in pca_df.columns
+    assert "pca_fit" in pca_df.columns
+    assert pca_df["pca_fit"].all()
+    assert pca_evr.ndim == 1
+    assert pca_evr.shape[0] == 2
+
+    # Check cohort counts.
+    final_cohort_counts = pca_df[cohort_col].value_counts()
+    for cohort, count in final_cohort_counts.items():
+        assert count <= max_cohort_size
+
+    # Test bad parameter combinations.
+    with pytest.raises(ValueError):
+        api.pca(
+            n_snps=100,
+            n_components=2,
+            cohorts=cohort_col,
+            # max_cohort_size is missing
+            **data_params,
+        )
+    with pytest.raises(ValueError):
+        api.pca(
+            n_snps=100,
+            n_components=2,
+            cohorts=cohort_col,
+            max_cohort_size=max_cohort_size,
+            sample_indices=[0, 1, 2],
+            **data_params,
+        )
+    with pytest.raises(ValueError):
+        api.pca(
+            n_snps=100,
+            n_components=2,
+            cohorts=cohort_col,
+            max_cohort_size=max_cohort_size,
+            cohort_size=10,
+            **data_params,
+        )
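One check the tests above do not cover is reproducibility. Because the patch seeds `group.sample` with `random_seed`, two calls with identical parameters should select the same samples. A possible extra assertion, not part of this patch, sketched here reusing the `api`, `cohort_col`, `max_cohort_size`, `random_seed`, and `data_params` names from the test above (it also assumes the result dataframe exposes a `sample_id` column, as the existing tests do):

```python
# Reproducibility sketch (would sit at the end of the test above): with a
# fixed random_seed, repeated calls should downsample to the same samples.
df_a, _ = api.pca(
    n_snps=100,
    n_components=2,
    cohorts=cohort_col,
    max_cohort_size=max_cohort_size,
    random_seed=random_seed,
    **data_params,
)
df_b, _ = api.pca(
    n_snps=100,
    n_components=2,
    cohorts=cohort_col,
    max_cohort_size=max_cohort_size,
    random_seed=random_seed,
    **data_params,
)
assert sorted(df_a["sample_id"]) == sorted(df_b["sample_id"])
```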