@@ -71,6 +71,7 @@ def train(self, epochs, fold):
7171 best_f1 = - 1
7272 self .model .train ()
7373 os .makedirs (self .config .data ["train_save_path" ], exist_ok = True )
74+ os .makedirs (self .config .data ["train_save_path" ], exist_ok = True )
7475 for epoch in tqdm (range (epochs )):
7576 meter = Meter (fold )
7677
@@ -131,7 +132,7 @@ def train(self, epochs, fold):
131132 )
132133
133134 model_save_path = os .path .join (
134- self .config ["train_save_path" ],
135+ self .config . data ["train_save_path" ],
135136 "best_weights_fold{}.tar" .format (fold + 1 ),
136137 )
137138 torch .save (
@@ -356,15 +357,11 @@ def permutation_p(label, activation):
356357 df .to_csv (os .path .join (self .config ["test_save_path" ], "p_values.csv" ))
357358
358359 def memory (self , epoch = - 1 , phase : str = "free_recall1" , alongwith = []):
359- torch .manual_seed (self .config ["seed" ])
360- np .random .seed (self .config ["seed" ])
361- random .seed (self .config ["seed" ])
362- self .config ["free_recall_phase" ] = phase
363- if self .config ["patient" ] == "i728" and "1" in phase :
364- self .config ["free_recall_phase" ] = "free_recall1a"
365- dataloaders = initialize_inference_dataloaders (self .config )
366- else :
367- dataloaders = initialize_inference_dataloaders (self .config )
360+ torch .manual_seed (self .config .experiment ["seed" ])
361+ np .random .seed (self .config .experiment ["seed" ])
362+ random .seed (self .config .experiment ["seed" ])
363+ self .config .experiment ["free_recall_phase" ] = phase
364+ dataloaders = initialize_inference_dataloaders (self .config )
368365 model = initialize_model (self .config )
369366 # model = torch.compile(model)
370367 model = model .to (device_name )
@@ -375,79 +372,52 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
375372
376373 # load the model with best F1-score
377374 # model_dir = os.path.join(self.config['train_save_path'], 'best_weights_fold{}.tar'.format(fold + 1))
378- model_dir = os .path .join (self .config ["train_save_path" ], "model_weights_epoch{}.tar" .format (epoch ))
375+ model_dir = os .path .join (self .config . data ["train_save_path" ], "model_weights_epoch{}.tar" .format (epoch ))
379376 model .load_state_dict (torch .load (model_dir )["model_state_dict" ])
380377 # print('Resume model: %s' % model_dir)
381378 model .eval ()
382379
383- predictions_all = np .empty ((0 , self .config ["num_labels" ]))
380+ predictions_all = np .empty ((0 , self .config . model ["num_labels" ]))
384381 predictions_length = {}
385382 with torch .no_grad ():
386- if self .config ["patient" ] == "i728" and "1" in phase :
387- # load the best epoch number from the saved "model_results" structure
388- for ph in ["FR1a" , "FR1b" ]:
389- predictions = np .empty ((0 , self .config ["num_labels" ]))
390- self .config ["free_recall_phase" ] = ph
391- dataloaders = initialize_inference_dataloaders (self .config )
392- # y_true = np.empty((0, self.config['num_labels']))
393- for i , (feature , index ) in enumerate (dataloaders ["inference" ]):
394- # target = target.to(self.device)
395- spike , lfp = self .extract_feature (feature )
396- # forward pass
397-
398- # start_time = time.time()
399- spike_emb , lfp_emb , output = model (lfp , spike )
400- # end_time = time.time()
401- # print('inference time: ', end_time - start_time)
402- output = torch .sigmoid (output )
403- pred = output .cpu ().detach ().numpy ()
404- predictions = np .concatenate ([predictions , pred ], axis = 0 )
405-
406- if self .config ["use_overlap" ]:
407- fake_activation = np .mean (predictions , axis = 0 )
408- predictions = np .vstack ((fake_activation , predictions , fake_activation ))
409-
410- predictions_all = np .concatenate ([predictions_all , predictions ], axis = 0 )
411- predictions_length [phase ] = len (predictions_all )
412- else :
413- self .config ["free_recall_phase" ] = phase
414- dataloaders = initialize_inference_dataloaders (self .config )
415- predictions = np .empty ((0 , self .config ["num_labels" ]))
416- # y_true = np.empty((0, self.config['num_labels']))
417- for i , (feature , index ) in enumerate (dataloaders ["inference" ]):
418- # target = target.to(self.device)
419- spike , lfp = self .extract_feature (feature )
420- # forward pass
383+ self .config .experiment ["free_recall_phase" ] = phase
384+ dataloaders = initialize_inference_dataloaders (self .config )
385+ predictions = np .empty ((0 , self .config .model ["num_labels" ]))
386+ # y_true = np.empty((0, self.config['num_labels']))
387+ for i , (feature , index ) in enumerate (dataloaders ["inference" ]):
388+ # target = target.to(self.device)
389+ spike , lfp = self .extract_feature (feature )
390+ # forward pass
421391
422- # start_time = time.time()
423- spike_emb , lfp_emb , output = model (lfp , spike )
424- # end_time = time.time()
425- # print('inference time: ', end_time - start_time)
426- output = torch .sigmoid (output )
427- pred = output .cpu ().detach ().numpy ()
428- predictions = np .concatenate ([predictions , pred ], axis = 0 )
392+ # start_time = time.time()
393+ spike_emb , lfp_emb , output = model (lfp , spike )
394+ # end_time = time.time()
395+ # print('inference time: ', end_time - start_time)
396+ output = torch .sigmoid (output )
397+ pred = output .cpu ().detach ().numpy ()
398+ predictions = np .concatenate ([predictions , pred ], axis = 0 )
429399
430- if self .config ["use_overlap" ]:
431- fake_activation = np .mean (predictions , axis = 0 )
432- predictions = np .vstack ((fake_activation , predictions , fake_activation ))
400+ if self .config . experiment ["use_overlap" ]:
401+ fake_activation = np .mean (predictions , axis = 0 )
402+ predictions = np .vstack ((fake_activation , predictions , fake_activation ))
433403
434- predictions_length [phase ] = len (predictions )
435- predictions_all = np .concatenate ([predictions_all , predictions ], axis = 0 )
404+ predictions_length [phase ] = len (predictions )
405+ predictions_all = np .concatenate ([predictions_all , predictions ], axis = 0 )
436406
437407 # np.save(os.path.join(self.config['memory_save_path'], 'free_recall_{}_results.npy'.format(phase)), predictions)
438- save_path = os .path .join (self .config ["memory_save_path" ], "prediction" )
408+ save_path = os .path .join (self .config . data ["memory_save_path" ], "prediction" )
439409 os .makedirs (save_path , exist_ok = True )
440410 np .save (
441411 os .path .join (save_path , "epoch{}_free_recall_{}_results.npy" .format (epoch , phase )),
442412 predictions_all ,
443413 )
444414
445415 for ph in alongwith :
446- self .config ["free_recall_phase" ] = ph
416+ self .config . experiment ["free_recall_phase" ] = ph
447417 dataloaders = initialize_inference_dataloaders (self .config )
448418 with torch .no_grad ():
449419 # load the best epoch number from the saved "model_results" structure
450- predictions = np .empty ((0 , self .config ["num_labels" ]))
420+ predictions = np .empty ((0 , self .config . model ["num_labels" ]))
451421 # y_true = np.empty((0, self.config['num_labels']))
452422 for i , (feature , index ) in enumerate (dataloaders ["inference" ]):
453423 # target = target.to(self.device)
@@ -462,7 +432,7 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
462432 pred = output .cpu ().detach ().numpy ()
463433 predictions = np .concatenate ([predictions , pred ], axis = 0 )
464434
465- if self .config ["use_overlap" ]:
435+ if self .config . experiment ["use_overlap" ]:
466436 fake_activation = np .mean (predictions , axis = 0 )
467437 predictions = np .vstack ((fake_activation , predictions , fake_activation ))
468438
0 commit comments