Skip to content

Commit 936a954

Browse files
committed
reverted some type hints
1 parent f0df162 commit 936a954

File tree

6 files changed

+62
-36
lines changed

6 files changed

+62
-36
lines changed

viscy/representation/embedding_writer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import numpy as np
77
import pandas as pd
88
import torch
9-
import xarray as xr
109
from lightning.pytorch import LightningModule, Trainer
1110
from lightning.pytorch.callbacks import BasePredictionWriter
1211
from numpy.typing import NDArray
@@ -23,7 +22,7 @@
2322
_logger = logging.getLogger("lightning.pytorch")
2423

2524

26-
def read_embedding_dataset(path: Path) -> xr.Dataset:
25+
def read_embedding_dataset(path: Path) -> Dataset:
2726
"""Read the embedding dataset written by the EmbeddingWriter callback.
2827
2928
Supports both legacy datasets (without x/y coordinates) and new datasets.
@@ -35,7 +34,7 @@ def read_embedding_dataset(path: Path) -> xr.Dataset:
3534
3635
Returns
3736
-------
38-
xr.Dataset
37+
Dataset
3938
Xarray dataset with features and projections.
4039
"""
4140
dataset = open_zarr(path)

viscy/representation/evaluation/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,19 @@
1818
from pathlib import Path
1919

2020
import pandas as pd
21-
import xarray as xr
2221
from viscy.data.triplet import TripletDataModule
22+
from xarray import DataArray
2323

2424

2525
def load_annotation(
26-
da: xr.DataArray, path: str, name: str, categories: dict | None = None
26+
da: DataArray, path: str, name: str, categories: dict | None = None
2727
) -> pd.Series:
2828
"""
2929
Load annotations from a CSV file and map them to the dataset.
3030
3131
Parameters
3232
----------
33-
da : xr.DataArray
33+
da : DataArray
3434
The dataset array containing 'fov_name' and 'id' coordinates.
3535
path : str
3636
Path to the CSV file containing annotations.

viscy/representation/evaluation/clustering.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Methods for evaluating clustering performance."""
2+
13
import numpy as np
24
from numpy.typing import ArrayLike, NDArray
35
from scipy.spatial.distance import cdist
@@ -10,12 +12,16 @@
1012
from sklearn.neighbors import KNeighborsClassifier
1113

1214

13-
def knn_accuracy(embeddings, annotations, k=5):
15+
def knn_accuracy(embeddings: NDArray, annotations: NDArray, k: int = 5) -> float:
1416
"""
1517
Evaluate the k-NN classification accuracy.
1618
1719
Parameters
1820
----------
21+
embeddings : NDArray
22+
Embeddings to cluster.
23+
annotations : NDArray
24+
Ground truth labels.
1925
k : int, optional
2026
Number of neighbors to use for k-NN. Default is 5.
2127

viscy/representation/evaluation/dimensionality_reduction.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
1+
from typing import TYPE_CHECKING
2+
13
import pandas as pd
24
import umap
3-
import xarray as xr
45
from numpy.typing import NDArray
56
from sklearn.decomposition import PCA
67
from sklearn.preprocessing import StandardScaler
8+
from xarray import Dataset
9+
10+
if TYPE_CHECKING:
11+
from phate import PHATE
712

813

914
def compute_phate(
10-
embedding_dataset: NDArray | xr.Dataset,
15+
embedding_dataset: NDArray | Dataset,
1116
n_components: int = 2,
1217
knn: int = 5,
1318
decay: int = 40,
1419
update_dataset: bool = False,
1520
**phate_kwargs,
16-
) -> tuple[object, NDArray]:
21+
) -> tuple[PHATE, NDArray]:
1722
"""
1823
Compute PHATE embeddings for features and optionally update dataset.
1924
2025
Parameters
2126
----------
22-
embedding_dataset : xr.Dataset | NDArray
27+
embedding_dataset : NDArray | Dataset
2328
The dataset containing embeddings, timepoints, fov_name, and track_id,
2429
or a numpy array of embeddings.
2530
n_components : int, optional
@@ -35,7 +40,7 @@ def compute_phate(
3540
3641
Returns
3742
-------
38-
tuple[object, NDArray]
43+
tuple[phate.PHATE, NDArray]
3944
PHATE model and PHATE embeddings
4045
4146
Raises
@@ -53,7 +58,7 @@ def compute_phate(
5358
# Get embeddings from dataset if needed
5459
embeddings = (
5560
embedding_dataset["features"].values
56-
if isinstance(embedding_dataset, xr.Dataset)
61+
if isinstance(embedding_dataset, Dataset)
5762
else embedding_dataset
5863
)
5964

@@ -64,7 +69,7 @@ def compute_phate(
6469
phate_embedding = phate_model.fit_transform(embeddings)
6570

6671
# Update dataset if requested
67-
if update_dataset and isinstance(embedding_dataset, xr.Dataset):
72+
if update_dataset and isinstance(embedding_dataset, Dataset):
6873
for i in range(
6974
min(2, phate_embedding.shape[1])
7075
): # Only update PHATE1 and PHATE2
@@ -73,12 +78,12 @@ def compute_phate(
7378
return phate_model, phate_embedding
7479

7580

76-
def compute_pca(embedding_dataset, n_components=None, normalize_features=True):
81+
def compute_pca(embedding_dataset: NDArray | Dataset, n_components=None, normalize_features=True):
7782
"""Compute PCA embeddings for features and optionally update dataset.
7883
7984
Parameters
8085
----------
81-
embedding_dataset : xr.Dataset or NDArray
86+
embedding_dataset : Dataset | NDArray
8287
The dataset containing embeddings, timepoints, fov_name, and track_id,
8388
or a numpy array of embeddings.
8489
n_components : int, optional
@@ -93,7 +98,7 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True):
9398
"""
9499
embeddings = (
95100
embedding_dataset["features"].values
96-
if isinstance(embedding_dataset, xr.Dataset)
101+
if isinstance(embedding_dataset, Dataset)
97102
else embedding_dataset
98103
)
99104

@@ -107,7 +112,7 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True):
107112
pc_features = PCA_features.fit_transform(scaled_features)
108113

109114
# Create base dictionary with id and fov_name
110-
if isinstance(embedding_dataset, xr.Dataset):
115+
if isinstance(embedding_dataset, Dataset):
111116
pca_dict = {
112117
"id": embedding_dataset["id"].values,
113118
"fov_name": embedding_dataset["fov_name"].values,
@@ -139,13 +144,13 @@ def _fit_transform_umap(
139144

140145

141146
def compute_umap(
142-
embedding_dataset: xr.Dataset, normalize_features: bool = True
147+
embedding_dataset: Dataset, normalize_features: bool = True
143148
) -> tuple[umap.UMAP, umap.UMAP, pd.DataFrame]:
144149
"""Compute UMAP embeddings for features and projections.
145150
146151
Parameters
147152
----------
148-
embedding_dataset : xr.Dataset
153+
embedding_dataset : Dataset
149154
Xarray dataset with features and projections.
150155
normalize_features : bool, optional
151156
Scale the input to zero mean and unit variance before fitting UMAP,

viscy/representation/evaluation/distance.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,32 @@
22
from typing import Literal
33

44
import numpy as np
5-
import xarray as xr
5+
from numpy.typing import NDArray
66
from sklearn.metrics.pairwise import cosine_similarity
7+
from xarray import Dataset
78

89

910
def calculate_cosine_similarity_cell(
10-
embedding_dataset: xr.Dataset, fov_name: str, track_id: int
11-
):
12-
"""Extract embeddings and calculate cosine similarities for a specific cell"""
11+
embedding_dataset: Dataset, fov_name: str, track_id: int
12+
) -> tuple[NDArray, NDArray]:
13+
"""
14+
15+
Extract embeddings and calculate cosine similarities for a specific cell
16+
17+
Parameters
18+
----------
19+
embedding_dataset : Dataset
20+
Dataset containing embeddings and metadata
21+
fov_name : str
22+
Field of view identifier
23+
track_id : int
24+
Track identifier for the specific cell
25+
26+
Returns
27+
-------
28+
tuple[NDArray, NDArray]
29+
Time points and cosine similarities for the specific cell
30+
"""
1331
filtered_data = embedding_dataset.where(
1432
(embedding_dataset["fov_name"] == fov_name)
1533
& (embedding_dataset["track_id"] == track_id),
@@ -25,7 +43,7 @@ def calculate_cosine_similarity_cell(
2543

2644

2745
def compute_displacement(
28-
embedding_dataset: xr.Dataset,
46+
embedding_dataset: Dataset,
2947
distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared",
3048
) -> dict[int, list[float]]:
3149
"""Compute the displacement or mean square displacement (MSD) of embeddings.
@@ -37,15 +55,13 @@ def compute_displacement(
3755
3856
Parameters
3957
----------
40-
embedding_dataset : xarray.Dataset
58+
embedding_dataset : Dataset
4159
Dataset containing embeddings and metadata
42-
distance_metric : str
60+
distance_metric : Literal["euclidean_squared", "cosine"]
4361
The metric to use for computing distances between embeddings.
4462
Valid options are:
45-
- "euclidean": Euclidean distance (L2 norm)
4663
- "euclidean_squared": Squared Euclidean distance (for MSD, default)
4764
- "cosine": Cosine similarity
48-
- "cosine_dissimilarity": 1 - cosine similarity
4965
5066
Returns
5167
-------
@@ -152,13 +168,13 @@ def compute_dynamic_range(mean_displacement_per_tau: dict[int, float]):
152168
return max(displacements) - min(displacements)
153169

154170

155-
def compute_rms_per_track(embedding_dataset: xr.Dataset):
171+
def compute_rms_per_track(embedding_dataset: Dataset):
156172
"""
157173
Compute RMS of the time derivative of embeddings per track.
158174
159175
Parameters
160176
----------
161-
embedding_dataset : xarray.Dataset
177+
embedding_dataset : Dataset
162178
The dataset containing embeddings, timepoints, fov_name, and track_id.
163179
164180
Returns
@@ -204,13 +220,13 @@ def compute_rms_per_track(embedding_dataset: xr.Dataset):
204220

205221

206222
def calculate_normalized_euclidean_distance_cell(
207-
embedding_dataset: xr.Dataset, fov_name: str, track_id: int
223+
embedding_dataset: Dataset, fov_name: str, track_id: int
208224
):
209225
"""Calculate normalized euclidean distance for a specific cell track.
210226
211227
Parameters
212228
----------
213-
embedding_dataset : xr.Dataset
229+
embedding_dataset : Dataset
214230
Dataset containing embedding data with fov_name and track_id coordinates
215231
fov_name : str
216232
Field of view identifier

viscy/representation/evaluation/lca.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55
import pandas as pd
66
import torch
77
import torch.nn as nn
8-
import xarray as xr
98
from captum.attr import IntegratedGradients, Occlusion
109
from numpy.typing import NDArray
1110
from sklearn.linear_model import LogisticRegression
1211
from sklearn.metrics import classification_report
1312
from sklearn.preprocessing import StandardScaler
1413
from torch import Tensor
1514
from viscy.representation.contrastive import ContrastiveEncoder
15+
from xarray import DataArray
1616

1717

1818
def fit_logistic_regression(
19-
features: xr.DataArray,
19+
features: DataArray,
2020
annotations: pd.Series,
2121
train_fovs: list[str],
2222
remove_background_class: bool = True,
@@ -32,7 +32,7 @@ def fit_logistic_regression(
3232
3333
Parameters
3434
----------
35-
features : xr.DataArray
35+
features : DataArray
3636
Xarray of features.
3737
annotations : pd.Series
3838
Categorical class annotations with label values starting from 0.

0 commit comments

Comments
 (0)