77import numpy as np
88import torch
99import scipy .sparse
10- import scanpy
10+ import scanpy as sc
1111
1212
1313from sklearn .preprocessing import StandardScaler
@@ -150,7 +150,12 @@ def factory(name, args):
150150 return EBData ("pcs" , max_dim = args .max_dim )
151151
152152 # If none of the above, we assume a path to a .h5ad or .npz file is supplied
153- return CustomData (name , args )
153+ if name .endswith (".h5ad" ):
154+ return CustomAnnDataFromFile (name , args )
155+ if name .endswith (".npz" ):
156+ return CustomData (name , args )
157+
158+ raise KeyError (f"Unknown dataset name { name } " )
154159
155160
156161def _get_data_points (adata , basis ) -> np .ndarray :
@@ -164,51 +169,24 @@ def _get_data_points(adata, basis) -> np.ndarray:
164169 else :
165170 raise KeyError (
166171 f"Could not find entry in `obsm` for '{ basis } '.\n "
167- f"Available keys are: { list (adata .obsm .keys ()} ."
172+ f"Available keys are: { list (adata .obsm .keys ()) } ."
168173 )
169174
170- return np .array (adata .obsm [basis_key ])[:, offset : offset + n_dims ]
171-
172-
173- class CustomAnnData (CustomData ):
174- def __init__ (self , name , args ):
175- super ().__init__ ()
176- self .args = args
177- self .embedding_name = args .embedding_name
178- self .load (name , args .max_dim )
175+ data_points = np .array (adata .obsm [basis_key ])
176+ velocity_points = None
179177
178+ if f"velocity_{ basis } " in adata .obsm .keys ():
179+ velocity_basis_key = f"velocity_{ basis } "
180+ velocity_points = np .array (adata .obsm [velocity_basis_key ])
181+ else :
182+ print (
183+ f"Could not find entry in `obsm` for 'velocity_{ basis } '.\n "
184+ f"Available keys are: { list (adata .obsm .keys ())} .\n "
185+ f"Assuming no velocity data."
186+ )
180187
181- def load (self , data_file , max_dim ):
182- self .adata = sc .read_h5ad (data_file )
183- self .labels = self .data_dict ["sample_labels" ]
184- self .data = _get_data_points (self .adata , self .embedding_name )
188+ return data_points , velocity_points
185189
186- if self .args .whiten :
187- scaler = StandardScaler ()
188- scaler .fit (self .data )
189- self .data = scaler .transform (self .data )
190-
191- self .ncells = self .data .shape [0 ]
192- assert self .labels .shape [0 ] == self .ncells
193-
194- delta_name = "delta_%s" % self .embedding_name
195- if delta_name not in self .data_dict .keys ():
196- print (
197- "No velocity found for embedding %s skipping velocity"
198- % self .embedding_name
199- )
200- self .use_velocity = False
201- else :
202- delta = self .data_dict [delta_name ]
203- assert delta .shape [0 ] == self .ncells
204- # Normalize ignoring mean from embedding
205- self .velocity = delta / scaler .scale_
206-
207- if max_dim is not None and self .data .shape [1 ] > max_dim :
208- print ("Warning: Clipping dimensionality to %d" % max_dim )
209- self .data = self .data [:, :max_dim ]
210- if self .use_velocity :
211- self .velocity = self .velocity [:, :max_dim ]
212190
213191
214192class CustomData (SCData ):
@@ -293,6 +271,42 @@ def sample_index(self, n, label_subset):
293271 return np .random .choice (arr , size = n )
294272
295273
class CustomAnnData(CustomData):
    """Dataset populated from an already-loaded AnnData-style object.

    Labels are read from ``adata.obs["sample_labels"]``; the data points and
    the optional velocities are obtained from ``_get_data_points`` using
    ``args.embedding_name``.
    """

    def __init__(self, adata, args):
        """Keep references to ``adata`` and ``args`` and populate fields via ``load``."""
        self.args = args
        self.adata = adata
        self.load()

    def load(self):
        """Set ``labels``, ``data``, ``velocity``, ``use_velocity`` and ``ncells``.

        If ``args.whiten`` is truthy, the points are standardized with a
        ``StandardScaler`` and any velocities are divided by the fitted
        ``scale_`` (the mean is not subtracted from them). If ``args.max_dim``
        is set and smaller than the number of columns, the points (and the
        velocities, when in use) are truncated to their first ``max_dim``
        columns.
        """
        self.labels = np.array(self.adata.obs["sample_labels"])
        points, velocities = _get_data_points(self.adata, self.args.embedding_name)
        self.data, self.velocity = points, velocities

        if self.args.whiten:
            standardizer = StandardScaler().fit(self.data)
            self.data = standardizer.transform(self.data)
            if self.velocity is not None:
                # Velocities are differences, so only the scale is applied.
                self.velocity = self.velocity / standardizer.scale_

        has_velocity = self.velocity is not None
        self.use_velocity = has_velocity

        self.ncells = self.data.shape[0]
        assert self.labels.shape[0] == self.ncells

        limit = self.args.max_dim
        if limit is None or self.data.shape[1] <= limit:
            return

        print(f"Warning: Clipping dimensionality from {self.data.shape[1]} to {limit}")
        self.data = self.data[:, :limit]
        if self.use_velocity:
            self.velocity = self.velocity[:, :limit]
301+
302+
class CustomAnnDataFromFile(CustomAnnData):
    """``CustomAnnData`` whose AnnData object is read from a file path."""

    def __init__(self, name, args):
        """Read ``name`` with ``sc.read_h5ad`` and pass the result to the parent initializer."""
        super().__init__(sc.read_h5ad(name), args)
307+
308+
309+
296310class EBData (SCData ):
297311 def __init__ (
298312 self , embedding_name = "phate" , max_dim = None , use_velocity = True , version = 5
@@ -486,7 +500,6 @@ def interpolate_with_ot(p0, p1, tmap, interp_frac, size):
486500 p1 = p1 .toarray () if scipy .sparse .isspmatrix (p1 ) else p1
487501 p0 = np .asarray (p0 , dtype = np .float64 )
488502 p1 = np .asarray (p1 , dtype = np .float64 )
489- print (p0 .shape , p1 .shape )
490503 tmap = np .asarray (tmap , dtype = np .float64 )
491504 if p0 .shape [1 ] != p1 .shape [1 ]:
492505 raise ValueError ("Unable to interpolate. Number of genes do not match" )
0 commit comments