diff --git a/pyproject.toml b/pyproject.toml index 0930bbfc2..f583a8714 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,9 @@ dependencies = [ "matplotlib>=3.9.0", "numpy", "xarray", - "pytorch-metric-learning>2.0.0"] + "pytorch-metric-learning>2.0.0", + "anndata>=0.12.2", +] dynamic = ["version"] [project.optional-dependencies] diff --git a/tests/representation/test_annotations.py b/tests/representation/test_annotations.py new file mode 100644 index 000000000..7d9767f59 --- /dev/null +++ b/tests/representation/test_annotations.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from iohub import open_ome_zarr +from pytest import TempPathFactory + +from viscy.representation.evaluation import ( + convert_xarray_annotation_to_anndata, + load_annotation_anndata, +) + + +@pytest.fixture(scope="function") +def xr_embeddings_dataset( + tracks_hcs_dataset: Path, tmp_path_factory: TempPathFactory +) -> Path: + """ + Provides a mock xarray embeddings dataset with tracking information from tracks_hcs_dataset. + + Parameters + ---------- + tracks_hcs_dataset : Path + Path to the HCS dataset with tracking CSV files. + + Returns + ------- + Path + Path to the zarr store containing the embeddings dataset. + """ + dataset_path = tmp_path_factory.mktemp("xr_embeddings.zarr") + + all_tracks = [] + + dataset = open_ome_zarr(tracks_hcs_dataset) + + for fov_name, _ in dataset.positions(): + tracks_csv_path = tracks_hcs_dataset / fov_name / "tracks.csv" + tracks_df = pd.read_csv(tracks_csv_path) + tracks_df["fov_name"] = fov_name + all_tracks.append(tracks_df) + + # Combine all tracks + tracks_df = pd.concat(all_tracks, ignore_index=True) + + n_samples = len(tracks_df) + n_features = 32 + + rng = np.random.default_rng(42) + + # Generate synthetic features (embeddings) + features = rng.normal(size=(n_samples, n_features)).astype(np.float32) + + # Create coordinates (PCA, UMAP, PHATE, projections) + pca_coords = rng.normal(size=(n_samples, 2)).astype(np.float32) + umap_coords = rng.normal(size=(n_samples, 2)).astype(np.float32) + phate_coords = rng.normal(size=(n_samples, 2)).astype(np.float32) + projections = rng.normal(size=(n_samples, 2)).astype( + np.float32 + ) # 2 projection dims + + # Create the xarray dataset + ds = xr.Dataset( + data_vars={ + "features": (["sample", "feature"], features), + }, + coords={ + "fov_name": ("sample", tracks_df["fov_name"]), + "track_id": ("sample", tracks_df["track_id"]), + "t": ("sample", tracks_df["t"]), + "id": ("sample", tracks_df["id"]), + "parent_track_id": ("sample", tracks_df["parent_track_id"]), + "parent_id": ("sample", tracks_df["parent_id"]), + "y": ("sample", tracks_df["y"]), + "x": ("sample", tracks_df["x"]), + "PCA_0": ("sample", pca_coords[:, 0]), + "PCA_1": ("sample", pca_coords[:, 1]), + "UMAP_0": ("sample", umap_coords[:, 0]), + "UMAP_1": ("sample", umap_coords[:, 1]), + "PHATE_0": ("sample", phate_coords[:, 0]), + "PHATE_1": ("sample", phate_coords[:, 1]), + "projections": (["sample", "projection"], projections), + "sample": np.arange(n_samples), + "feature": np.arange(n_features), + "projection": np.arange(projections.shape[1]), + }, + ) + + # Save to zarr + ds.to_zarr(dataset_path) + + return dataset_path + + +@pytest.fixture(scope="function") +def anndata_embeddings( + xr_embeddings_dataset: Path, tmp_path_factory: TempPathFactory +) -> Path: + """ + Provides an AnnData zarr store created from xarray embeddings dataset. + + Parameters + ---------- + xr_embeddings_dataset : Path + Path to the xarray embeddings dataset. + + Returns + ------- + Path + Path to the AnnData zarr store. + """ + rng = np.random.default_rng(42) + + # Create output path for AnnData + adata_path = tmp_path_factory.mktemp("anndata_embeddings.zarr") + + # Load the xarray dataset + embeddings_ds = xr.open_zarr(xr_embeddings_dataset) + + # Extract features as X matrix + n_samples = len(embeddings_ds.coords["sample"]) + + X = rng.normal(size=(n_samples, 32)).astype(np.float32) + + obs_data = { + "id": embeddings_ds.coords["id"].values, + "fov_name": embeddings_ds.coords["fov_name"].values.astype(str), + "track_id": embeddings_ds.coords["track_id"].values, + "parent_track_id": embeddings_ds.coords["parent_track_id"].values, + "parent_id": embeddings_ds.coords["parent_id"].values, + "t": embeddings_ds.coords["t"].values, + "y": embeddings_ds.coords["y"].values, + "x": embeddings_ds.coords["x"].values, + } + obs_df = pd.DataFrame(obs_data) + + # Get the number of samples from the dataset + + adata = ad.AnnData( + X=X, + obs=obs_df, + obsm={ + "X_projections": rng.normal(size=(n_samples, 3)).astype(np.float32), + "X_pca": rng.normal(size=(n_samples, 3)).astype(np.float32), + "X_umap": rng.uniform(-10, 10, size=(n_samples, 3)).astype(np.float32), + "X_phate": rng.normal(scale=0.5, size=(n_samples, 3)).astype(np.float32), + }, + ) + + # Write to zarr + adata.write_zarr(adata_path) + + return adata_path + + +def test_convert_xarray_annotation_to_anndata(xr_embeddings_dataset, tmp_path): + """Test that convert_xarray_annotation_to_anndata correctly converts xarray to AnnData.""" + # Load the xarray dataset + embeddings_ds = xr.open_zarr(xr_embeddings_dataset) + + # Define output path + output_path = tmp_path / "test_converted.zarr" + + # Convert to AnnData using the function we're testing + adata_result = convert_xarray_annotation_to_anndata( + embeddings_ds=embeddings_ds, + output_path=output_path, + overwrite=True, + return_anndata=True, + ) + # Second conversion with overwrite=False should raise FileExistsError + with pytest.raises( + FileExistsError, match=f"Output path {output_path} already exists" + ): + convert_xarray_annotation_to_anndata( + embeddings_ds=embeddings_ds, + output_path=output_path, + overwrite=False, + return_anndata=False, + ) + + assert isinstance(adata_result, ad.AnnData) + + assert output_path.exists() + + adata_loaded = ad.read_zarr(output_path) + + np.testing.assert_allclose(adata_loaded.X, embeddings_ds["features"].values) + + # Verify obs columns + expected_obs_columns = [ + "id", + "fov_name", + "track_id", + "parent_track_id", + "parent_id", + "t", + "y", + "x", + ] + for col in expected_obs_columns: + assert col in adata_loaded.obs.columns + if col == "fov_name": + assert list(adata_loaded.obs[col]) == list( + embeddings_ds.coords[col].values.astype(str) + ) + else: + np.testing.assert_allclose( + adata_loaded.obs[col].values, embeddings_ds.coords[col].values + ) + + # Verify obsm (embeddings) + assert all( + embedding_key in adata_loaded.obsm + for embedding_key in ["X_pca", "X_umap", "X_phate", "X_projections"] + ) + + # Check projections + np.testing.assert_allclose( + adata_loaded.obsm["X_projections"], embeddings_ds.coords["projections"].values + ) + + # Check PCA + np.testing.assert_allclose( + adata_loaded.obsm["X_pca"][:, 0], embeddings_ds.coords["PCA_0"].values + ) + np.testing.assert_allclose( + adata_loaded.obsm["X_pca"][:, 1], embeddings_ds.coords["PCA_1"].values + ) + + # Check UMAP + np.testing.assert_allclose( + adata_loaded.obsm["X_umap"][:, 0], embeddings_ds.coords["UMAP_0"].values + ) + np.testing.assert_allclose( + adata_loaded.obsm["X_umap"][:, 1], embeddings_ds.coords["UMAP_1"].values + ) + + # Check PHATE + np.testing.assert_allclose( + adata_loaded.obsm["X_phate"][:, 0], embeddings_ds.coords["PHATE_0"].values + ) + np.testing.assert_allclose( + adata_loaded.obsm["X_phate"][:, 1], embeddings_ds.coords["PHATE_1"].values + ) + + +def test_load_annotation_anndata(tracks_hcs_dataset, anndata_embeddings, tmp_path): + """Test that load_annotation_anndata correctly loads annotations from an AnnData object.""" + # Load the AnnData object + adata = ad.read_zarr(anndata_embeddings) + + A11_annotations_path = tracks_hcs_dataset / "A" / "1" / "1" / "tracks.csv" + + A11_annotations_df = pd.read_csv(A11_annotations_path) + + rng = np.random.default_rng(42) + A11_annotations_df["fov_name"] = "A/1/1" + A11_annotations_df["infection_state"] = rng.choice( + [-1, 0, 1], size=len(A11_annotations_df) + ) + + # Save the modified annotations to a new CSV file + annotations_path = tmp_path / "test_annotations.csv" + A11_annotations_df.to_csv(annotations_path, index=False) + + # Test the function with the new CSV file + result = load_annotation_anndata(adata, str(annotations_path), "infection_state") + + assert len(result) == 2 # Only 2 observations from A/1/1 have annotations + + expected_values = A11_annotations_df["infection_state"].values + actual_values = result.values + np.testing.assert_array_equal(actual_values, expected_values) + + # Verify the index structure + assert result.index.names == ["fov_name", "id"] + assert all(result.index.get_level_values("fov_name") == "A/1/1") diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 9bfa2c5bc..2eb464c64 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Any, Dict, Literal, Optional, Sequence +import anndata as ad import numpy as np import pandas as pd import torch @@ -18,10 +19,46 @@ compute_phate, ) -__all__ = ["read_embedding_dataset", "EmbeddingWriter", "write_embedding_dataset"] +__all__ = [ + "read_embedding_dataset", + "EmbeddingWriter", + "write_embedding_dataset", + "get_available_index_columns", +] _logger = logging.getLogger("lightning.pytorch") +def get_available_index_columns( + dataset: Dataset, dataset_path: str | None = None +) -> list[str]: + """ + Get available index columns from a dataset with logging for missing columns. + + Parameters + ---------- + dataset : Dataset + The xarray dataset to check for index columns. + dataset_path : str, optional + Path to the dataset for logging purposes. If None, uses generic message. + + Returns + ------- + list[str] + List of available index columns from INDEX_COLUMNS. + """ + available_cols = [col for col in INDEX_COLUMNS if col in dataset.coords] + missing_cols = set(INDEX_COLUMNS) - set(available_cols) + + if missing_cols: + path_msg = f" at {dataset_path}" if dataset_path else "" + _logger.warning( + f"Dataset{path_msg} is missing index columns: {sorted(missing_cols)}. " + "This appears to be a legacy dataset format." + ) + + return available_cols + + def read_embedding_dataset(path: Path) -> Dataset: """ Read the embedding dataset written by the EmbeddingWriter callback. @@ -38,17 +75,7 @@ def read_embedding_dataset(path: Path) -> Dataset: Xarray dataset with features and projections. """ dataset = open_zarr(path) - # Check which index columns are present in the dataset - available_cols = [col for col in INDEX_COLUMNS if col in dataset.coords] - - # Warn if any INDEX_COLUMNS are missing - missing_cols = set(INDEX_COLUMNS) - set(available_cols) - if missing_cols: - _logger.warning( - f"Dataset at {path} is missing index columns: {sorted(missing_cols)}. " - "This appears to be a legacy dataset format." - ) - + available_cols = get_available_index_columns(dataset, str(path)) return dataset.set_index(sample=available_cols) @@ -70,7 +97,7 @@ def write_embedding_dataset( overwrite: bool = False, ) -> None: """ - Write embeddings to a zarr store in an Xarray-compatible format. + Write embeddings to an AnnData Zarr Store. Parameters ---------- @@ -118,8 +145,13 @@ def write_embedding_dataset( # Create a copy of the index DataFrame to avoid modifying the original ultrack_indices = index_df.copy() + ultrack_indices["fov_name"] = ultrack_indices["fov_name"].str.strip("/") n_samples = len(features) + adata = ad.AnnData(X=features, obs=ultrack_indices) + if projections is not None: + adata.obsm["X_projections"] = projections + # Set up default kwargs for each method if umap_kwargs: if umap_kwargs["n_neighbors"] >= n_samples: @@ -130,8 +162,7 @@ def write_embedding_dataset( _logger.debug(f"Using UMAP kwargs: {umap_kwargs}") _, UMAP = _fit_transform_umap(features, **umap_kwargs) - for i in range(UMAP.shape[1]): - ultrack_indices[f"UMAP{i + 1}"] = UMAP[:, i] + adata.obsm["X_umap"] = UMAP if phate_kwargs: # Update with user-provided kwargs @@ -147,8 +178,7 @@ def write_embedding_dataset( try: _logger.debug("Computing PHATE") _, PHATE = compute_phate(features, **phate_kwargs) - for i in range(PHATE.shape[1]): - ultrack_indices[f"PHATE{i + 1}"] = PHATE[:, i] + adata.obsm["X_phate"] = PHATE except Exception as e: _logger.warning(f"PHATE computation failed: {str(e)}") @@ -158,27 +188,12 @@ def write_embedding_dataset( try: _logger.debug("Computing PCA") PCA_features, _ = compute_pca(features, **pca_kwargs) - for i in range(PCA_features.shape[1]): - ultrack_indices[f"PC{i + 1}"] = PCA_features[:, i] + adata.obsm["X_pca"] = PCA_features except Exception as e: _logger.warning(f"PCA computation failed: {str(e)}") - # Create multi-index and dataset - index = pd.MultiIndex.from_frame(ultrack_indices) - - # Create dataset dictionary with features - dataset_dict = {"features": (("sample", "features"), features)} - - # Add projections if provided - if projections is not None: - dataset_dict["projections"] = (("sample", "projections"), projections) - - # Create the dataset - dataset = Dataset(dataset_dict, coords={"sample": index}).reset_index("sample") - _logger.debug(f"Writing dataset to {output_path}") - with dataset.to_zarr(output_path, mode="w") as zarr_store: - zarr_store.close() + adata.write_zarr(output_path) class EmbeddingWriter(BasePredictionWriter): diff --git a/viscy/representation/evaluation/__init__.py b/viscy/representation/evaluation/__init__.py index c474aec82..3bb2edb49 100644 --- a/viscy/representation/evaluation/__init__.py +++ b/viscy/representation/evaluation/__init__.py @@ -14,6 +14,7 @@ https://github.com/mehta-lab/dynacontrast/blob/master/analysis/gmm.py """ +import anndata as ad import pandas as pd from viscy.data.triplet import TripletDataModule @@ -42,19 +43,18 @@ def load_annotation(da, path, name, categories: dict | None = None): # Read the annotation CSV file annotation = pd.read_csv(path) - # Add a leading slash to 'fov name' column and set it as 'fov_name' - annotation["fov_name"] = "/" + annotation["fov_name"] - # Set the index of the annotation DataFrame to ['fov_name', 'id'] + annotation["fov_name"] = annotation["fov_name"].str.strip("/") annotation = annotation.set_index(["fov_name", "id"]) # Create a MultiIndex from the dataset array's 'fov_name' and 'id' values mi = pd.MultiIndex.from_arrays( - [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] + [da["fov_name"].to_pandas().str.strip("/"), da["id"].values], + names=["fov_name", "id"], ) - # Select the annotations corresponding to the MultiIndex - selected = annotation.loc[mi][name] + # This will return NaN for observations that don't have annotations, then just drop'em + selected = annotation.reindex(mi)[name].dropna() # If categories are provided, rename the categories in the selected annotations if categories: @@ -63,6 +63,41 @@ def load_annotation(da, path, name, categories: dict | None = None): return selected +def load_annotation_anndata( + adata: ad.AnnData, path: str, name: str, categories: dict | None = None +): + """ + Load annotations from a CSV file and map them to the AnnData object. + + Parameters + ---------- + adata : anndata.AnnData + The AnnData object to map the annotations to. + path : str + Path to the CSV file containing annotations. + name : str + The column name in the CSV file to be used as annotations. + categories : dict, optional + A dictionary to rename categories in the annotation column. Default is None. + """ + annotation = pd.read_csv(path) + annotation["fov_name"] = annotation["fov_name"].str.strip("/") + + annotation = annotation.set_index(["fov_name", "id"]) + + mi = pd.MultiIndex.from_arrays( + [adata.obs["fov_name"], adata.obs["id"]], names=["fov_name", "id"] + ) + + # Use reindex to handle missing annotations gracefully + # This will return NaN for observations that don't have annotations, then just drop'em + selected = annotation.reindex(mi)[name].dropna() + + if categories: + selected = selected.astype("category").cat.rename_categories(categories) + return selected + + def dataset_of_tracks( data_path, tracks_path, diff --git a/viscy/representation/evaluation/convert_xarray_annotation_to_anndata.py b/viscy/representation/evaluation/convert_xarray_annotation_to_anndata.py new file mode 100644 index 000000000..5623bf008 --- /dev/null +++ b/viscy/representation/evaluation/convert_xarray_annotation_to_anndata.py @@ -0,0 +1,122 @@ +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import xarray as xr +from natsort import natsorted + +from viscy.representation.embedding_writer import get_available_index_columns + + +def convert_xarray_annotation_to_anndata( + embeddings_ds: xr.Dataset | Path, + output_path: Path, + overwrite: bool = False, + return_anndata: bool = False, +) -> ad.AnnData | None: + """ + Convert an Xarray embeddings dataset to an AnnData object. + + Parameters + ---------- + embeddings_ds : xr.Dataset | Path + The Xarray embeddings dataset to convert or the path to the embeddings dataset. + output_path : Path + Path to the zarr store to write the AnnData object to. + overwrite : bool, optional + Whether to overwrite existing zarr store, by default False. + return_anndata : bool, optional + Whether to return the AnnData object, by default False. + + Returns + ------- + ad.AnnData | None + The AnnData object if return_anndata is True, otherwise None. + + Raises + ------ + FileExistsError + If output_path exists and overwrite is False. + + Examples + -------- + >>> embeddings_ds = xr.open_zarr(embeddings_path) + >>> adata = convert_xarray_annotation_to_anndata(embeddings_ds, output_path, overwrite=True, return_anndata=True) + >>> adata + AnnData object with n_obs × n_vars = 18861 × 768 + obs: 'id', 'fov_name', 'track_id', 'parent_track_id', 'parent_id', 't', 'y', 'x' + obsm: 'X_projections', 'X_PCA', 'X_UMAP', 'X_PHATE' + """ + # Check if output_path exists + if output_path.exists() and not overwrite: + raise FileExistsError(f"Output path {output_path} already exists.") + + # Tracking + if isinstance(embeddings_ds, Path): + embeddings_ds = xr.open_zarr(embeddings_ds) + + available_cols = get_available_index_columns(embeddings_ds) + tracking_df = pd.DataFrame( + { + col: embeddings_ds.coords[col].data + if col != "fov_name" + else embeddings_ds.coords[col].to_pandas().str.strip("/") + for col in available_cols + } + ) + + obsm = {} + # Projections + if "projections" in embeddings_ds.coords: + obsm["X_projections"] = embeddings_ds.coords["projections"].data + + # Embeddings + for embedding in ["PCA", "UMAP", "PHATE"]: + embedding_coords = natsorted( + [coord for coord in embeddings_ds.coords if embedding in coord] + ) + if embedding_coords: + obsm[f"X_{embedding.lower()}"] = np.column_stack( + [embeddings_ds.coords[coord] for coord in embedding_coords] + ) + + # X, "expression" matrix (NN embedding features) + X = embeddings_ds["features"].data + + adata = ad.AnnData(X=X, obs=tracking_df, obsm=obsm) + + adata.write_zarr(output_path) + if return_anndata: + return adata + + +def main( + input_path: Path, + output_path: Path, + overwrite: bool = False, +): + """ + CLI entry point for converting Xarray embeddings to AnnData format. + + Parameters + ---------- + input_path : Path + Path to the input Xarray zarr store. + output_path : Path + Path to the output AnnData zarr store. + overwrite : bool, optional + Whether to overwrite existing output, by default False. + """ + return convert_xarray_annotation_to_anndata( + embeddings_ds=input_path, + output_path=output_path, + overwrite=overwrite, + return_anndata=False, + ) + + +if __name__ == "__main__": + from jsonargparse import CLI + + CLI(main) diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index ee22b7f8e..81d0194f6 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -131,12 +131,17 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True): def _fit_transform_umap( - embeddings: NDArray, n_components: int = 2, normalize: bool = True + embeddings: NDArray, + n_components: int = 2, + n_neighbors: int = 15, + normalize: bool = True, ) -> tuple[umap.UMAP, NDArray]: """Fit UMAP model and transform embeddings.""" if normalize: embeddings = StandardScaler().fit_transform(embeddings) - umap_model = umap.UMAP(n_components=n_components, random_state=42) + umap_model = umap.UMAP( + n_components=n_components, n_neighbors=n_neighbors, random_state=42 + ) umap_embedding = umap_model.fit_transform(embeddings) return umap_model, umap_embedding diff --git a/viscy/scripts/anndata_annotations.py b/viscy/scripts/anndata_annotations.py new file mode 100644 index 000000000..5d056d678 --- /dev/null +++ b/viscy/scripts/anndata_annotations.py @@ -0,0 +1,77 @@ +# %% +""" +Example script for converting xarray embeddings to AnnData format and loading annotations. + +This script demonstrates: +1. Converting xarray embeddings to AnnData format +2. Loading annotations into AnnData objects +3. Simple Plotting example +""" + +from pathlib import Path + +import pandas as pd +import seaborn as sns + +# Optional for plotting directly with AnnData objects w/o manual accessing patterns +# import scanpy as sc +import xarray as xr + +from viscy.representation.evaluation import ( + convert_xarray_annotation_to_anndata, + load_annotation_anndata, +) + +# %% +# Define paths +embeddings_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/4-phenotyping/1-predictions/phase_160patch_104ckpt_ver3max.zarr" +) +annotations_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/4-phenotyping/0-annotations/track_infection_annotation.csv" +) +output_path = Path("../output/track_data_anndata.zarr") + +# %% +# Load embeddings +embeddings_dataset = xr.open_zarr(embeddings_path) + +# %% +# Convert xarray to AnnData +adata = convert_xarray_annotation_to_anndata( + embeddings_dataset, + output_path, + overwrite=True, + return_anndata=True, +) +print(adata) + +# %% +# Load annotations +adata_annotated = load_annotation_anndata( + adata=adata, + path=annotations_path, + name="infection_status", +) + +# %% +# Show results +print(adata_annotated) + +# %% +# Simple Accessing and Plotting (matplotlib) +# Plot the first two Phate embeddings colored by fov_name + +sns.scatterplot( + data=pd.DataFrame(adata.obsm["X_phate"], index=adata.obs_names).join(adata.obs), + x=0, + y=1, + hue="fov_name", +) + + +# %% +# Simple Plotting with scanpy if you have it installed. +# Plot the first two Phate embeddings colored by fov_name + +# sc.pl.embedding(basis="phate", adata=adata, color="fov_name")