Skip to content

Commit 8eb8831

Browse files
committed
Fix dataset whiten and add scanpy requirement
1 parent dfd6f68 commit 8eb8831

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

TrajectoryNet/dataset.py

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
import torch
99
import scipy.sparse
10-
import scanpy as sc
1110

1211

1312
from sklearn.preprocessing import StandardScaler
@@ -201,13 +200,14 @@ def load(self, data_file, max_dim):
201200
self.labels = self.data_dict["sample_labels"]
202201
if self.embedding_name not in self.data_dict.keys():
203202
raise ValueError("Unknown embedding name %s" % self.embedding_name)
204-
embedding = self.data_dict[self.embedding_name]
205-
scaler = StandardScaler()
206-
scaler.fit(embedding)
207-
self.ncells = embedding.shape[0]
203+
self.data = self.data_dict[self.embedding_name]
204+
if self.args.whiten:
205+
scaler = StandardScaler()
206+
scaler.fit(self.data)
207+
self.data = scaler.transform(self.data)
208+
self.ncells = self.data.shape[0]
208209
assert self.labels.shape[0] == self.ncells
209210
# Scale so that embedding is normally distributed
210-
self.data = scaler.transform(embedding)
211211

212212
delta_name = "delta_%s" % self.embedding_name
213213
if delta_name not in self.data_dict.keys():
@@ -217,10 +217,11 @@ def load(self, data_file, max_dim):
217217
)
218218
self.use_velocity = False
219219
else:
220-
delta = self.data_dict[delta_name]
221-
assert delta.shape[0] == self.ncells
220+
self.velocity = self.data_dict[delta_name]
221+
assert self.velocity.shape[0] == self.ncells
222222
# Normalize ignoring mean from embedding
223-
self.velocity = delta / scaler.scale_
223+
if self.args.whiten:
224+
self.velocity = self.velocity / scaler.scale_
224225

225226
if max_dim is not None and self.data.shape[1] > max_dim:
226227
print("Warning: Clipping dimensionality to %d" % max_dim)
@@ -302,6 +303,8 @@ def load(self):
302303

303304
class CustomAnnDataFromFile(CustomAnnData):
304305
def __init__(self, name, args):
306+
import scanpy as sc
307+
305308
adata = sc.read_h5ad(name)
306309
super().__init__(adata, args)
307310

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ argparse
22
matplotlib>=3.2.1
33
numpy>=1.18.4
44
POT>=0.7.0
5+
scanpy
56
scikit-learn>=0.23.1
67
scipy>=1.4.1
78
torch>=1.5.0

setup.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
"matplotlib>=3.2.1",
88
"numpy>=1.18.4",
99
"POT>=0.7.0",
10+
"scanpy",
1011
"scikit-learn>=0.23.1",
1112
"scipy>=1.4.1",
1213
"torch>=1.5.0",

0 commit comments

Comments
 (0)