Skip to content

Commit 73fe93b

Browse files
committed
Working anndata loader
1 parent 1424031 commit 73fe93b

File tree

1 file changed

+56
-43
lines changed

1 file changed

+56
-43
lines changed

TrajectoryNet/dataset.py

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import torch
99
import scipy.sparse
10-
import scanpy
10+
import scanpy as sc
1111

1212

1313
from sklearn.preprocessing import StandardScaler
@@ -150,7 +150,12 @@ def factory(name, args):
150150
return EBData("pcs", max_dim=args.max_dim)
151151

152152
# If none of the above, we assume a path to a .npz file is supplied
153-
return CustomData(name, args)
153+
if name.endswith(".h5ad"):
154+
return CustomAnnDataFromFile(name, args)
155+
if name.endswith(".npz"):
156+
return CustomData(name, args)
157+
158+
raise KeyError(f"Unknown dataset name {name}")
154159

155160

156161
def _get_data_points(adata, basis) -> np.ndarray:
@@ -164,51 +169,24 @@ def _get_data_points(adata, basis) -> np.ndarray:
164169
else:
165170
raise KeyError(
166171
f"Could not find entry in `obsm` for '{basis}'.\n"
167-
f"Available keys are: {list(adata.obsm.keys()}."
172+
f"Available keys are: {list(adata.obsm.keys())}."
168173
)
169174

170-
return np.array(adata.obsm[basis_key])[:, offset : offset + n_dims]
171-
172-
173-
class CustomAnnData(CustomData):
174-
def __init__(self, name, args):
175-
super().__init__()
176-
self.args = args
177-
self.embedding_name = args.embedding_name
178-
self.load(name, args.max_dim)
175+
data_points = np.array(adata.obsm[basis_key])
176+
velocity_points = None
179177

178+
if f"velocity_{basis}" in adata.obsm.keys():
179+
velocity_basis_key = f"velocity_{basis}"
180+
velocity_points = np.array(adata.obsm[velocity_basis_key])
181+
else:
182+
print(
183+
f"Could not find entry in `obsm` for 'velocity_{basis}'.\n"
184+
f"Available keys are: {list(adata.obsm.keys())}.\n"
185+
f"Assuming no velocity data."
186+
)
180187

181-
def load(self, data_file, max_dim):
182-
self.adata = sc.read_h5ad(data_file)
183-
self.labels = self.data_dict["sample_labels"]
184-
self.data = _get_data_points(self.adata, self.embedding_name)
188+
return data_points, velocity_points
185189

186-
if self.args.whiten:
187-
scaler = StandardScaler()
188-
scaler.fit(self.data)
189-
self.data = scaler.transform(self.data)
190-
191-
self.ncells = self.data.shape[0]
192-
assert self.labels.shape[0] == self.ncells
193-
194-
delta_name = "delta_%s" % self.embedding_name
195-
if delta_name not in self.data_dict.keys():
196-
print(
197-
"No velocity found for embedding %s skipping velocity"
198-
% self.embedding_name
199-
)
200-
self.use_velocity = False
201-
else:
202-
delta = self.data_dict[delta_name]
203-
assert delta.shape[0] == self.ncells
204-
# Normalize ignoring mean from embedding
205-
self.velocity = delta / scaler.scale_
206-
207-
if max_dim is not None and self.data.shape[1] > max_dim:
208-
print("Warning: Clipping dimensionality to %d" % max_dim)
209-
self.data = self.data[:, :max_dim]
210-
if self.use_velocity:
211-
self.velocity = self.velocity[:, :max_dim]
212190

213191

214192
class CustomData(SCData):
@@ -293,6 +271,42 @@ def sample_index(self, n, label_subset):
293271
return np.random.choice(arr, size=n)
294272

295273

274+
class CustomAnnData(CustomData):
    """Dataset backed by an in-memory AnnData object.

    Pulls the cell embedding (and, when present, matching velocity vectors)
    out of ``adata.obsm`` via ``_get_data_points`` and the timepoint labels
    out of ``adata.obs["sample_labels"]``.
    """

    def __init__(self, adata, args):
        # NOTE: intentionally does not call CustomData.__init__ — loading
        # happens from the AnnData object rather than an .npz file.
        self.args = args
        self.adata = adata
        self.load()

    def load(self):
        """Populate ``data``/``labels``/``velocity`` from ``self.adata``."""
        self.labels = np.array(self.adata.obs["sample_labels"])
        self.data, self.velocity = _get_data_points(
            self.adata, self.args.embedding_name
        )

        if self.args.whiten:
            # Standardize the embedding. Velocity is a displacement, so it
            # is rescaled by the per-dimension std but NOT mean-shifted.
            scaler = StandardScaler()
            scaler.fit(self.data)
            self.data = scaler.transform(self.data)
            if self.velocity is not None:
                self.velocity = self.velocity / scaler.scale_

        # Velocity availability is decided by what _get_data_points found.
        self.use_velocity = self.velocity is not None

        self.ncells = self.data.shape[0]
        assert self.labels.shape[0] == self.ncells

        # Optionally keep only the leading max_dim embedding dimensions,
        # clipping velocity in lockstep so the two stay aligned.
        max_dim = self.args.max_dim
        if max_dim is not None and self.data.shape[1] > max_dim:
            print(f"Warning: Clipping dimensionality from {self.data.shape[1]} to {max_dim}")
            self.data = self.data[:, :max_dim]
            if self.use_velocity:
                self.velocity = self.velocity[:, :max_dim]
301+
302+
303+
class CustomAnnDataFromFile(CustomAnnData):
    """CustomAnnData variant constructed from an ``.h5ad`` file path.

    Reads the file with ``sc.read_h5ad`` and hands the resulting AnnData
    object to :class:`CustomAnnData` for all further processing.
    """

    def __init__(self, name, args):
        adata = sc.read_h5ad(name)
        super().__init__(adata, args)
307+
308+
309+
296310
class EBData(SCData):
297311
def __init__(
298312
self, embedding_name="phate", max_dim=None, use_velocity=True, version=5
@@ -486,7 +500,6 @@ def interpolate_with_ot(p0, p1, tmap, interp_frac, size):
486500
p1 = p1.toarray() if scipy.sparse.isspmatrix(p1) else p1
487501
p0 = np.asarray(p0, dtype=np.float64)
488502
p1 = np.asarray(p1, dtype=np.float64)
489-
print(p0.shape, p1.shape)
490503
tmap = np.asarray(tmap, dtype=np.float64)
491504
if p0.shape[1] != p1.shape[1]:
492505
raise ValueError("Unable to interpolate. Number of genes do not match")

0 commit comments

Comments
 (0)