@@ -309,7 +309,7 @@ def main(args):
309309 if args .use_cpu :
310310 device = torch .device ("cpu" )
311311
312- data = dataset .SCData .factory (args .dataset , args . max_dim )
312+ data = dataset .SCData .factory (args .dataset , args )
313313
314314 args .timepoints = data .get_unique_times ()
315315
@@ -321,9 +321,10 @@ def main(args):
321321 model = build_model_tabular (args , data .get_shape ()[0 ], regularization_fns ).to (
322322 device
323323 )
324- growth_model_path = data .get_growth_net_path ()
325- #growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
326- growth_model = torch .load (growth_model_path , map_location = device )
324+ if args .use_growth :
325+ growth_model_path = data .get_growth_net_path ()
326+ #growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
327+ growth_model = torch .load (growth_model_path , map_location = device )
327328 if args .spectral_norm :
328329 add_spectral_norm (model )
329330 set_cnf_options (args , model )
@@ -340,8 +341,8 @@ def main(args):
340341 args .int_tps = (np .arange (max (args .timepoints ) + 1 ) + 1.0 ) * args .time_scale
341342
342343 print ('integrating backwards' )
343- end_time_data = data .data_dict ['mphate_expression' ]
344- # end_time_data = data.get_data()[args.data.get_times()==np.max(args.data.get_times())]
344+ # end_time_data = data.data_dict[args.embedding_name ]
345+ end_time_data = data .get_data ()[args .data .get_times ()== np .max (args .data .get_times ())]
345346 #np.random.permutation(end_time_data)
346347 #rand_idx = np.random.randint(end_time_data.shape[0], size=5000)
347348 #end_time_data = end_time_data[rand_idx,:]
0 commit comments