1 change: 1 addition & 0 deletions .gitignore
@@ -47,3 +47,4 @@ slurm*.out

#lightning_logs directory
lightning_logs/
data/
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -28,7 +28,9 @@ dependencies = [
"matplotlib>=3.9.0",
"numpy",
"xarray",
"pytorch-metric-learning>2.0.0"]
"pytorch-metric-learning>2.0.0",
"anndata>=0.12.2",
]
dynamic = ["version"]

[project.optional-dependencies]
@@ -38,7 +40,7 @@ metrics = [
"imbalanced-learn>=0.12.0",
"torchmetrics[detection]>=1.6.3",
"ptflops>=0.7",
"umap-learn",
"umap-learn>=0.5.9",
"captum>=0.7.0",
"mahotas",
]
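The new dependency floors can be sanity-checked at runtime. A minimal sketch using only the standard library; the floors are copied from the hunks above and the check itself is illustrative, not part of this PR:

```python
from importlib.metadata import PackageNotFoundError, version

# Floors taken from the pyproject.toml hunks above.
floors = {"anndata": "0.12.2", "umap-learn": "0.5.9"}

for pkg, floor in floors.items():
    try:
        print(f"{pkg}: installed {version(pkg)}, requires >= {floor}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed (requires >= {floor})")
```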
33 changes: 11 additions & 22 deletions viscy/representation/embedding_writer.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Sequence

import anndata as ad
import numpy as np
import pandas as pd
import torch
@@ -70,7 +71,7 @@ def write_embedding_dataset(
overwrite: bool = False,
) -> None:
"""
Write embeddings to a zarr store in an Xarray-compatible format.
Write embeddings to an AnnData Zarr store.

Parameters
----------
@@ -118,8 +119,13 @@

# Create a copy of the index DataFrame to avoid modifying the original
ultrack_indices = index_df.copy()
ultrack_indices["fov_name"] = ultrack_indices["fov_name"].str.strip("/")
n_samples = len(features)

adata = ad.AnnData(X=features, obs=ultrack_indices)
if projections is not None:
adata.obsm["X_projections"] = projections

# Set up default kwargs for each method
if umap_kwargs:
if umap_kwargs["n_neighbors"] >= n_samples:
@@ -130,8 +136,7 @@

_logger.debug(f"Using UMAP kwargs: {umap_kwargs}")
_, UMAP = _fit_transform_umap(features, **umap_kwargs)
for i in range(UMAP.shape[1]):
ultrack_indices[f"UMAP{i + 1}"] = UMAP[:, i]
adata.obsm["X_umap"] = UMAP

if phate_kwargs:
# Update with user-provided kwargs
@@ -147,8 +152,7 @@
try:
_logger.debug("Computing PHATE")
_, PHATE = compute_phate(features, **phate_kwargs)
for i in range(PHATE.shape[1]):
ultrack_indices[f"PHATE{i + 1}"] = PHATE[:, i]
adata.obsm["X_phate"] = PHATE
except Exception as e:
_logger.warning(f"PHATE computation failed: {str(e)}")

@@ -158,27 +162,12 @@
try:
_logger.debug("Computing PCA")
PCA_features, _ = compute_pca(features, **pca_kwargs)
for i in range(PCA_features.shape[1]):
ultrack_indices[f"PCA{i + 1}"] = PCA_features[:, i]
adata.obsm["X_pca"] = PCA_features
except Exception as e:
_logger.warning(f"PCA computation failed: {str(e)}")

# Create multi-index and dataset
index = pd.MultiIndex.from_frame(ultrack_indices)

# Create dataset dictionary with features
dataset_dict = {"features": (("sample", "features"), features)}

# Add projections if provided
if projections is not None:
dataset_dict["projections"] = (("sample", "projections"), projections)

# Create the dataset
dataset = Dataset(dataset_dict, coords={"sample": index}).reset_index("sample")

_logger.debug(f"Writing dataset to {output_path}")
with dataset.to_zarr(output_path, mode="w") as zarr_store:
zarr_store.close()
adata.write_zarr(output_path)


class EmbeddingWriter(BasePredictionWriter):
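With the switch from an Xarray dataset to AnnData, downstream consumers read the store back with `anndata.read_zarr`. A minimal sketch; the `embeddings.zarr` path is hypothetical, and the `obsm` keys are present only when the corresponding reductions were computed:

```python
import anndata as ad

# Read the store written by write_embedding_dataset.
adata = ad.read_zarr("embeddings.zarr")

print(adata.X.shape)                # (n_samples, n_features) embedding matrix
print(adata.obs.columns.tolist())   # index columns, e.g. fov_name, id
print(adata.obsm["X_umap"].shape)   # (n_samples, 2), present if UMAP ran
print(adata.obsm["X_projections"])  # present only if projections were passed
```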
36 changes: 33 additions & 3 deletions viscy/representation/evaluation/__init__.py
@@ -14,6 +14,7 @@
https://github.com/mehta-lab/dynacontrast/blob/master/analysis/gmm.py
"""

import anndata as ad
import pandas as pd

from viscy.data.triplet import TripletDataModule
@@ -42,9 +43,6 @@ def load_annotation(da, path, name, categories: dict | None = None):
# Read the annotation CSV file
annotation = pd.read_csv(path)

# Add a leading slash to 'fov name' column and set it as 'fov_name'
annotation["fov_name"] = "/" + annotation["fov_name"]

# Set the index of the annotation DataFrame to ['fov_name', 'id']
annotation = annotation.set_index(["fov_name", "id"])

@@ -63,6 +61,38 @@
return selected


def load_annotation_adata(
adata: ad.AnnData, path: str, name: str, categories: dict | None = None
):
"""
Load annotations from a CSV file and map them to the AnnData object.

Parameters
----------
adata : anndata.AnnData
The AnnData object to map the annotations to.
path : str
Path to the CSV file containing annotations.
name : str
The column name in the CSV file to be used as annotations.
categories : dict, optional
A dictionary to rename categories in the annotation column. Default is None.

Returns
-------
pandas.Series
The annotation column from the CSV, aligned to the rows of ``adata.obs``.
"""
annotation = pd.read_csv(path)
annotation["fov_name"] = annotation["fov_name"].str.strip("/")

annotation = annotation.set_index(["fov_name", "id"])

mi = pd.MultiIndex.from_arrays(
[adata.obs["fov_name"], adata.obs["id"]], names=["fov_name", "id"]
)
selected = annotation.loc[mi][name]
if categories:
selected = selected.astype("category").cat.rename_categories(categories)
return selected


def dataset_of_tracks(
data_path,
tracks_path,
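A usage sketch for the new helper; the file names, annotation column, and category mapping below are hypothetical:

```python
import anndata as ad

from viscy.representation.evaluation import load_annotation_adata

adata = ad.read_zarr("embeddings.zarr")

# Map a per-track annotation column onto the embeddings by (fov_name, id).
infection = load_annotation_adata(
    adata,
    "tracking_annotations.csv",
    "infection_state",
    categories={0: "uninfected", 1: "infected"},
)
adata.obs["infection_state"] = infection.values
```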
9 changes: 7 additions & 2 deletions viscy/representation/evaluation/dimensionality_reduction.py
@@ -131,12 +131,17 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True):


def _fit_transform_umap(
embeddings: NDArray, n_components: int = 2, normalize: bool = True
embeddings: NDArray,
n_components: int = 2,
n_neighbors: int = 15,
normalize: bool = True,
) -> tuple[umap.UMAP, NDArray]:
"""Fit UMAP model and transform embeddings."""
if normalize:
embeddings = StandardScaler().fit_transform(embeddings)
umap_model = umap.UMAP(n_components=n_components, random_state=42)
umap_model = umap.UMAP(
n_components=n_components, n_neighbors=n_neighbors, random_state=42
)
umap_embedding = umap_model.fit_transform(embeddings)
return umap_model, umap_embedding

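The newly exposed `n_neighbors` pass-through can be exercised directly. A minimal sketch with synthetic data; note that `_fit_transform_umap` is module-private, and `n_neighbors` must stay below the sample count, which is why `write_embedding_dataset` clamps it upstream:

```python
import numpy as np

from viscy.representation.evaluation.dimensionality_reduction import (
    _fit_transform_umap,
)

# 100 synthetic 32-dimensional embeddings.
features = np.random.default_rng(42).normal(size=(100, 32))

# Smaller n_neighbors emphasizes local structure; larger, global structure.
model, umap_coords = _fit_transform_umap(features, n_components=2, n_neighbors=10)
print(umap_coords.shape)  # (100, 2)
```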