2 changes: 1 addition & 1 deletion .gitignore
@@ -46,4 +46,4 @@ coverage.xml
 slurm*.out
 
 #lightning_logs directory
-lightning_logs/
\ No newline at end of file
+lightning_logs/
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -28,7 +28,9 @@ dependencies = [
"matplotlib>=3.9.0",
"numpy",
"xarray",
"pytorch-metric-learning>2.0.0"]
"pytorch-metric-learning>2.0.0",
"anndata>=0.12.2",
]
dynamic = ["version"]

[project.optional-dependencies]
@@ -38,7 +40,7 @@ metrics = [
     "imbalanced-learn>=0.12.0",
     "torchmetrics[detection]>=1.6.3",
     "ptflops>=0.7",
-    "umap-learn",
+    "umap-learn>=0.5.9",
     "captum>=0.7.0",
     "mahotas",
 ]
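The new pins matter at runtime: the embedding writer below now relies on AnnData's zarr I/O. A quick environment check, purely illustrative, using only standard packaging tools (the version strings mirror the pins above):

```python
# Sanity-check installed versions against the new pins; illustrative only.
from importlib.metadata import version

from packaging.version import Version

assert Version(version("anndata")) >= Version("0.12.2")
assert Version(version("umap-learn")) >= Version("0.5.9")
print("dependency pins satisfied")
```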
33 changes: 11 additions & 22 deletions viscy/representation/embedding_writer.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Sequence
 
+import anndata as ad
 import numpy as np
 import pandas as pd
 import torch
@@ -70,7 +71,7 @@ def write_embedding_dataset(
     overwrite: bool = False,
 ) -> None:
     """
-    Write embeddings to a zarr store in an Xarray-compatible format.
+    Write embeddings to an AnnData zarr store.
 
     Parameters
     ----------
Expand Down Expand Up @@ -118,8 +119,13 @@ def write_embedding_dataset(
 
     # Create a copy of the index DataFrame to avoid modifying the original
     ultrack_indices = index_df.copy()
+    ultrack_indices["fov_name"] = ultrack_indices["fov_name"].str.strip("/")
     n_samples = len(features)
+
+    adata = ad.AnnData(X=features, obs=ultrack_indices)
+    if projections is not None:
+        adata.obsm["X_projections"] = projections
 
     # Set up default kwargs for each method
     if umap_kwargs:
         if umap_kwargs["n_neighbors"] >= n_samples:
@@ -130,8 +136,7 @@
 
         _logger.debug(f"Using UMAP kwargs: {umap_kwargs}")
         _, UMAP = _fit_transform_umap(features, **umap_kwargs)
-        for i in range(UMAP.shape[1]):
-            ultrack_indices[f"UMAP{i + 1}"] = UMAP[:, i]
+        adata.obsm["X_umap"] = UMAP
 
     if phate_kwargs:
         # Update with user-provided kwargs
@@ -147,8 +152,7 @@
         try:
             _logger.debug("Computing PHATE")
             _, PHATE = compute_phate(features, **phate_kwargs)
-            for i in range(PHATE.shape[1]):
-                ultrack_indices[f"PHATE{i + 1}"] = PHATE[:, i]
+            adata.obsm["X_phate"] = PHATE
         except Exception as e:
             _logger.warning(f"PHATE computation failed: {str(e)}")
 
@@ -158,27 +162,12 @@
         try:
             _logger.debug("Computing PCA")
             PCA_features, _ = compute_pca(features, **pca_kwargs)
-            for i in range(PCA_features.shape[1]):
-                ultrack_indices[f"PC{i + 1}"] = PCA_features[:, i]
+            adata.obsm["X_pca"] = PCA_features
         except Exception as e:
             _logger.warning(f"PCA computation failed: {str(e)}")
 
-    # Create multi-index and dataset
-    index = pd.MultiIndex.from_frame(ultrack_indices)
-
-    # Create dataset dictionary with features
-    dataset_dict = {"features": (("sample", "features"), features)}
-
-    # Add projections if provided
-    if projections is not None:
-        dataset_dict["projections"] = (("sample", "projections"), projections)
-
-    # Create the dataset
-    dataset = Dataset(dataset_dict, coords={"sample": index}).reset_index("sample")
-
     _logger.debug(f"Writing dataset to {output_path}")
-    with dataset.to_zarr(output_path, mode="w") as zarr_store:
-        zarr_store.close()
+    adata.write_zarr(output_path)
 
 
 class EmbeddingWriter(BasePredictionWriter):
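With the writer switched from an xarray `Dataset` to AnnData, downstream readers change too. A minimal sketch of reading the store back, assuming a placeholder path; which `obsm` keys exist depends on which reductions were actually computed:

```python
import anndata as ad

# Read the zarr store produced by write_embedding_dataset (path is a placeholder).
adata = ad.read_zarr("embeddings.zarr")
features = adata.X  # (n_samples, n_features) embedding matrix
tracks = adata.obs  # fov_name, track_id, t, ... as plain columns
# Reductions land in obsm only when their kwargs were passed:
for key in ("X_projections", "X_umap", "X_phate", "X_pca"):
    if key in adata.obsm:
        print(key, adata.obsm[key].shape)
```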
123 changes: 120 additions & 3 deletions viscy/representation/evaluation/__init__.py
@@ -14,7 +14,13 @@
 https://github.com/mehta-lab/dynacontrast/blob/master/analysis/gmm.py
 """
 
+from pathlib import Path
+
+import anndata as ad
+import natsort as ns
+import numpy as np
 import pandas as pd
+import xarray as xr
 
 from viscy.data.triplet import TripletDataModule
 
@@ -42,9 +48,6 @@ def load_annotation(da, path, name, categories: dict | None = None):
     # Read the annotation CSV file
     annotation = pd.read_csv(path)
 
-    # Add a leading slash to 'fov name' column and set it as 'fov_name'
-    annotation["fov_name"] = "/" + annotation["fov_name"]
-
     # Set the index of the annotation DataFrame to ['fov_name', 'id']
     annotation = annotation.set_index(["fov_name", "id"])
 
@@ -63,6 +66,120 @@
     return selected
 
 
+def load_annotation_anndata(
+    adata: ad.AnnData, path: str, name: str, categories: dict | None = None
+):
+    """
+    Load annotations from a CSV file and map them to the AnnData object.
+
+    Parameters
+    ----------
+    adata : anndata.AnnData
+        The AnnData object to map the annotations to.
+    path : str
+        Path to the CSV file containing annotations.
+    name : str
+        The column name in the CSV file to be used as annotations.
+    categories : dict, optional
+        A dictionary to rename categories in the annotation column. Default is None.
+
+    Returns
+    -------
+    pd.Series
+        The annotation column, ordered to match ``adata.obs``.
+    """
+    annotation = pd.read_csv(path)
+    annotation["fov_name"] = annotation["fov_name"].str.strip("/")
+
+    annotation = annotation.set_index(["fov_name", "id"])
+
+    mi = pd.MultiIndex.from_arrays(
+        [adata.obs["fov_name"], adata.obs["id"]], names=["fov_name", "id"]
+    )
+    selected = annotation.loc[mi][name]
+    if categories:
+        selected = selected.astype("category").cat.rename_categories(categories)
+    return selected
+
+
+def convert_xarray_annotation_to_anndata(
+    embeddings_ds: xr.Dataset,
+    output_path: Path,
+    overwrite: bool = False,
+    return_anndata: bool = False,
+) -> ad.AnnData | None:
+    """
+    Convert an Xarray embeddings dataset to an AnnData object.
+
+    Parameters
+    ----------
+    embeddings_ds : xr.Dataset
+        The Xarray embeddings dataset to convert.
+    output_path : Path
+        Path to the zarr store to write the AnnData object to.
+    overwrite : bool, optional
+        Whether to overwrite an existing zarr store, by default False.
+    return_anndata : bool, optional
+        Whether to return the AnnData object, by default False.
+
+    Returns
+    -------
+    ad.AnnData | None
+        The AnnData object if return_anndata is True, otherwise None.
+
+    Raises
+    ------
+    FileExistsError
+        If output_path exists and overwrite is False.
+
+    Examples
+    --------
+    >>> embeddings_ds = xr.open_zarr(embeddings_path)
+    >>> adata = convert_xarray_annotation_to_anndata(embeddings_ds, output_path, overwrite=True, return_anndata=True)
+    >>> adata
+    AnnData object with n_obs × n_vars = 18861 × 768
+        obs: 'id', 'fov_name', 'track_id', 'parent_track_id', 'parent_id', 't', 'y', 'x'
+        obsm: 'X_projections', 'X_pca', 'X_umap', 'X_phate'
+    """
+    # Check if output_path exists
+    if output_path.exists() and not overwrite:
+        raise FileExistsError(f"Output path {output_path} already exists.")
+
+    # Tracking
+    tracking_df = pd.DataFrame(
+        data={
+            "id": embeddings_ds.coords["id"].data,
+            "fov_name": embeddings_ds.coords["fov_name"].to_pandas().str.strip("/"),
+            "track_id": embeddings_ds.coords["track_id"].data,
+            "parent_track_id": embeddings_ds.coords["parent_track_id"].data,
+            "parent_id": embeddings_ds.coords["parent_id"].data,
+            "t": embeddings_ds.coords["t"].data,
+            "y": embeddings_ds.coords["y"].data,
+            "x": embeddings_ds.coords["x"].data,
+        },
+    )
+
+    obsm = {}
+    # Projections
+    if "projections" in embeddings_ds.coords:
+        obsm["X_projections"] = embeddings_ds.coords["projections"].data
+
+    # Embeddings: stack e.g. PCA1, PCA2, ... coordinates into one obsm matrix
+    for embedding in ["PCA", "UMAP", "PHATE"]:
+        embedding_coords = ns.natsorted(
+            [coord for coord in embeddings_ds.coords if embedding in coord]
+        )
+        if embedding_coords:
+            obsm[f"X_{embedding.lower()}"] = np.column_stack(
+                [embeddings_ds.coords[coord] for coord in embedding_coords]
+            )
+
+    # X, "expression" matrix
+    X = embeddings_ds["features"].data
+
+    adata = ad.AnnData(X=X, obs=tracking_df, obsm=obsm)
+
+    adata.write_zarr(output_path)
+    if return_anndata:
+        return adata
+
+
 def dataset_of_tracks(
     data_path,
     tracks_path,
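The alignment trick in `load_annotation_anndata` is worth spelling out: building a `MultiIndex` from `adata.obs` and indexing the annotation table with it returns annotations in observation order, regardless of CSV row order. A toy illustration with made-up values:

```python
import pandas as pd

# Annotation table as it might come from a CSV, indexed by (fov_name, id).
annotation = pd.DataFrame(
    {"fov_name": ["A/1", "A/1", "B/2"], "id": [1, 2, 1], "label": [10, 20, 30]}
).set_index(["fov_name", "id"])

# Observation order from a hypothetical adata.obs.
obs = pd.DataFrame({"fov_name": ["B/2", "A/1"], "id": [1, 2]})
mi = pd.MultiIndex.from_arrays([obs["fov_name"], obs["id"]], names=["fov_name", "id"])

# Rows come back in obs order, not CSV order.
print(annotation.loc[mi]["label"].to_list())  # [30, 20]
```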
9 changes: 7 additions & 2 deletions viscy/representation/evaluation/dimensionality_reduction.py
@@ -131,12 +131,17 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True):
 
 
 def _fit_transform_umap(
-    embeddings: NDArray, n_components: int = 2, normalize: bool = True
+    embeddings: NDArray,
+    n_components: int = 2,
+    n_neighbors: int = 15,
+    normalize: bool = True,
 ) -> tuple[umap.UMAP, NDArray]:
     """Fit UMAP model and transform embeddings."""
     if normalize:
         embeddings = StandardScaler().fit_transform(embeddings)
-    umap_model = umap.UMAP(n_components=n_components, random_state=42)
+    umap_model = umap.UMAP(
+        n_components=n_components, n_neighbors=n_neighbors, random_state=42
+    )
     umap_embedding = umap_model.fit_transform(embeddings)
     return umap_model, umap_embedding
 
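The new `n_neighbors` pass-through is what makes the clamp in `write_embedding_dataset` meaningful, since UMAP requires fewer neighbors than samples. A small usage sketch on random data (the helper is private, so this mirrors internal use rather than public API):

```python
import numpy as np

from viscy.representation.evaluation.dimensionality_reduction import (
    _fit_transform_umap,
)

rng = np.random.default_rng(42)
embeddings = rng.normal(size=(100, 768)).astype(np.float32)

# n_neighbors must stay below the number of samples, hence the clamp upstream.
model, coords = _fit_transform_umap(embeddings, n_components=2, n_neighbors=15)
print(coords.shape)  # (100, 2)
```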
71 changes: 71 additions & 0 deletions viscy/scripts/anndata_annotations.py
@@ -0,0 +1,71 @@
+# %%
+"""
+Example script for converting xarray embeddings to AnnData format and loading annotations.
+
+This script demonstrates:
+1. Converting xarray embeddings to AnnData format
+2. Loading annotations into AnnData objects
+3. A simple plotting example
+"""
+
+from pathlib import Path
+
+import pandas as pd
+import scanpy as sc
+import seaborn as sns
+import xarray as xr
+
+from viscy.representation.evaluation import (
+    convert_xarray_annotation_to_anndata,
+    load_annotation_anndata,
+)
+
+# %%
+# Define paths
+data_path = Path("/hpc/mydata/sricharan.varra/repos/VisCy/data/2024_11_21_A549_TOMM20_DENV/")
+annotations_path = data_path / "annotations" / "track_infection_annotation.csv"
+embeddings_path = data_path / "embeddings" / "phase_160patch_104ckpt_ver3max.zarr"
+output_path = data_path / "track_data_anndata.zarr"
+
+# %%
+# Load embeddings
+embeddings_dataset = xr.open_zarr(embeddings_path)
+
+# %%
+# Convert xarray to AnnData
+adata = convert_xarray_annotation_to_anndata(
+    embeddings_dataset,
+    output_path,
+    overwrite=True,
+    return_anndata=True,
+)
+
+# %%
+# Load annotations (a pandas Series ordered like adata.obs) and attach them
+adata.obs["infection_status"] = load_annotation_anndata(
+    adata=adata,
+    path=annotations_path,
+    name="infection_status",
+).to_numpy()
+
+# %%
+# Show results
+print(adata.obs)
+
+# %%
+# Simple accessing and plotting (seaborn/matplotlib)
+# Plot the first two PCs colored by fov_name
+sns.scatterplot(
+    data=pd.DataFrame(adata.obsm["X_pca"], index=adata.obs_names).join(adata.obs),
+    x=0,
+    y=1,
+    hue="fov_name",
+)
+
+
+# %%
+# Simple plotting (scanpy)
+# Plot the first two PCs colored by fov_name
+sc.pl.pca(adata, color="fov_name")
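A natural closing cell, not part of this diff, would persist the annotated object so later sessions can skip the conversion; this sketch reuses the `output_path` defined above:

```python
# %%
# Persist the annotated AnnData and reload it in a fresh session.
import anndata as ad

adata.write_zarr(output_path)
adata_reloaded = ad.read_zarr(output_path)
print(adata_reloaded.obs["infection_status"].value_counts())
```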