77import numpy as np
88import torch
99import scipy .sparse
10- import scanpy as sc
1110
1211
1312from sklearn .preprocessing import StandardScaler
@@ -201,13 +200,14 @@ def load(self, data_file, max_dim):
201200 self .labels = self .data_dict ["sample_labels" ]
202201 if self .embedding_name not in self .data_dict .keys ():
203202 raise ValueError ("Unknown embedding name %s" % self .embedding_name )
204- embedding = self .data_dict [self .embedding_name ]
205- scaler = StandardScaler ()
206- scaler .fit (embedding )
207- self .ncells = embedding .shape [0 ]
203+ self .data = self .data_dict [self .embedding_name ]
204+ if self .args .whiten :
205+ scaler = StandardScaler ()
206+ scaler .fit (self .data )
207+ self .data = scaler .transform (self .data )
208+ self .ncells = self .data .shape [0 ]
208209 assert self .labels .shape [0 ] == self .ncells
209210 # Scale so that embedding is normally distributed
210- self .data = scaler .transform (embedding )
211211
212212 delta_name = "delta_%s" % self .embedding_name
213213 if delta_name not in self .data_dict .keys ():
@@ -217,10 +217,11 @@ def load(self, data_file, max_dim):
217217 )
218218 self .use_velocity = False
219219 else :
220- delta = self .data_dict [delta_name ]
221- assert delta .shape [0 ] == self .ncells
220+ self . velocity = self .data_dict [delta_name ]
221+ assert self . velocity .shape [0 ] == self .ncells
222222 # Normalize ignoring mean from embedding
223- self .velocity = delta / scaler .scale_
223+ if self .args .whiten :
224+ self .velocity = self .velocity / scaler .scale_
224225
225226 if max_dim is not None and self .data .shape [1 ] > max_dim :
226227 print ("Warning: Clipping dimensionality to %d" % max_dim )
@@ -302,6 +303,8 @@ def load(self):
302303
303304class CustomAnnDataFromFile (CustomAnnData ):
304305 def __init__ (self , name , args ):
306+ import scanpy as sc
307+
305308 adata = sc .read_h5ad (name )
306309 super ().__init__ (adata , args )
307310
0 commit comments