1
+ from typing import TYPE_CHECKING
2
+
1
3
import pandas as pd
2
4
import umap
3
- import xarray as xr
4
5
from numpy .typing import NDArray
5
6
from sklearn .decomposition import PCA
6
7
from sklearn .preprocessing import StandardScaler
8
+ from xarray import Dataset
9
+
10
+ if TYPE_CHECKING :
11
+ from phate import PHATE
7
12
8
13
9
14
def compute_phate (
10
- embedding_dataset : NDArray | xr . Dataset ,
15
+ embedding_dataset : NDArray | Dataset ,
11
16
n_components : int = 2 ,
12
17
knn : int = 5 ,
13
18
decay : int = 40 ,
14
19
update_dataset : bool = False ,
15
20
** phate_kwargs ,
16
- ) -> tuple [object , NDArray ]:
21
+ ) -> tuple [PHATE , NDArray ]:
17
22
"""
18
23
Compute PHATE embeddings for features and optionally update dataset.
19
24
20
25
Parameters
21
26
----------
22
- embedding_dataset : xr.Dataset | NDArray
27
+ embedding_dataset : NDArray | Dataset
23
28
The dataset containing embeddings, timepoints, fov_name, and track_id,
24
29
or a numpy array of embeddings.
25
30
n_components : int, optional
@@ -35,7 +40,7 @@ def compute_phate(
35
40
36
41
Returns
37
42
-------
38
- tuple[object , NDArray]
43
+ tuple[PHATE, NDArray]
39
44
PHATE model and PHATE embeddings
40
45
41
46
Raises
@@ -53,7 +58,7 @@ def compute_phate(
53
58
# Get embeddings from dataset if needed
54
59
embeddings = (
55
60
embedding_dataset ["features" ].values
56
- if isinstance (embedding_dataset , xr . Dataset )
61
+ if isinstance (embedding_dataset , Dataset )
57
62
else embedding_dataset
58
63
)
59
64
@@ -64,7 +69,7 @@ def compute_phate(
64
69
phate_embedding = phate_model .fit_transform (embeddings )
65
70
66
71
# Update dataset if requested
67
- if update_dataset and isinstance (embedding_dataset , xr . Dataset ):
72
+ if update_dataset and isinstance (embedding_dataset , Dataset ):
68
73
for i in range (
69
74
min (2 , phate_embedding .shape [1 ])
70
75
): # Only update PHATE1 and PHATE2
@@ -73,12 +78,12 @@ def compute_phate(
73
78
return phate_model , phate_embedding
74
79
75
80
76
- def compute_pca (embedding_dataset , n_components = None , normalize_features = True ):
81
+ def compute_pca (embedding_dataset : NDArray | Dataset , n_components = None , normalize_features = True ):
77
82
"""Compute PCA embeddings for features and optionally update dataset.
78
83
79
84
Parameters
80
85
----------
81
- embedding_dataset : xr. Dataset or NDArray
86
+ embedding_dataset : Dataset | NDArray
82
87
The dataset containing embeddings, timepoints, fov_name, and track_id,
83
88
or a numpy array of embeddings.
84
89
n_components : int, optional
@@ -93,7 +98,7 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True):
93
98
"""
94
99
embeddings = (
95
100
embedding_dataset ["features" ].values
96
- if isinstance (embedding_dataset , xr . Dataset )
101
+ if isinstance (embedding_dataset , Dataset )
97
102
else embedding_dataset
98
103
)
99
104
@@ -107,7 +112,7 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True):
107
112
pc_features = PCA_features .fit_transform (scaled_features )
108
113
109
114
# Create base dictionary with id and fov_name
110
- if isinstance (embedding_dataset , xr . Dataset ):
115
+ if isinstance (embedding_dataset , Dataset ):
111
116
pca_dict = {
112
117
"id" : embedding_dataset ["id" ].values ,
113
118
"fov_name" : embedding_dataset ["fov_name" ].values ,
@@ -139,13 +144,13 @@ def _fit_transform_umap(
139
144
140
145
141
146
def compute_umap (
142
- embedding_dataset : xr . Dataset , normalize_features : bool = True
147
+ embedding_dataset : Dataset , normalize_features : bool = True
143
148
) -> tuple [umap .UMAP , umap .UMAP , pd .DataFrame ]:
144
149
"""Compute UMAP embeddings for features and projections.
145
150
146
151
Parameters
147
152
----------
148
- embedding_dataset : xr. Dataset
153
+ embedding_dataset : Dataset
149
154
Xarray dataset with features and projections.
150
155
normalize_features : bool, optional
151
156
Scale the input to zero mean and unit variance before fitting UMAP,
0 commit comments