Skip to content

Commit ec65bdf

Browse files
committed
Fixes to eval.py to bring it up to new main.py format
1 parent f991696 commit ec65bdf

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

TrajectoryNet/eval.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def main(args):
309309
if args.use_cpu:
310310
device = torch.device("cpu")
311311

312-
data = dataset.SCData.factory(args.dataset, args.max_dim)
312+
data = dataset.SCData.factory(args.dataset, args)
313313

314314
args.timepoints = data.get_unique_times()
315315

@@ -321,9 +321,10 @@ def main(args):
321321
model = build_model_tabular(args, data.get_shape()[0], regularization_fns).to(
322322
device
323323
)
324-
growth_model_path = data.get_growth_net_path()
325-
#growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
326-
growth_model = torch.load(growth_model_path, map_location=device)
324+
if args.use_growth:
325+
growth_model_path = data.get_growth_net_path()
326+
#growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
327+
growth_model = torch.load(growth_model_path, map_location=device)
327328
if args.spectral_norm:
328329
add_spectral_norm(model)
329330
set_cnf_options(args, model)
@@ -340,8 +341,8 @@ def main(args):
340341
args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale
341342

342343
print('integrating backwards')
343-
end_time_data = data.data_dict['mphate_expression']
344-
#end_time_data = data.get_data()[args.data.get_times()==np.max(args.data.get_times())]
344+
#end_time_data = data.data_dict[args.embedding_name]
345+
end_time_data = data.get_data()[args.data.get_times()==np.max(args.data.get_times())]
345346
#np.random.permutation(end_time_data)
346347
#rand_idx = np.random.randint(end_time_data.shape[0], size=5000)
347348
#end_time_data = end_time_data[rand_idx,:]

0 commit comments

Comments
 (0)