Skip to content

Commit bfba0bf

Browse files
authored
Merge pull request #4 from NxNiki/dev
[refactor] bug fix on config
2 parents 6598189 + 6350dda commit bfba0bf

File tree

8 files changed

+127
-132
lines changed

8 files changed

+127
-132
lines changed

.run/main.run.xml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<component name="ProjectRunConfigurationManager">
2-
<configuration default="false" name="main" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
3-
<module name="movie_decoding" />
2+
<configuration default="false" name="main" type="PythonConfigurationType" factoryName="Python">
3+
<module name="brain_decoding" />
44
<option name="ENV_FILES" value="" />
55
<option name="INTERPRETER_OPTIONS" value="" />
66
<option name="PARENT_ENVS" value="true" />
@@ -9,11 +9,11 @@
99
</envs>
1010
<option name="SDK_HOME" value="" />
1111
<option name="SDK_NAME" value="movie_decoding" />
12-
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
12+
<option name="WORKING_DIRECTORY" value="" />
1313
<option name="IS_MODULE_SDK" value="false" />
1414
<option name="ADD_CONTENT_ROOTS" value="true" />
1515
<option name="ADD_SOURCE_ROOTS" value="true" />
16-
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/movie_decoding/main.py" />
16+
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/brain_decoding/main.py" />
1717
<option name="PARAMETERS" value="" />
1818
<option name="SHOW_COMMAND_LINE" value="false" />
1919
<option name="EMULATE_TERMINAL" value="false" />

.run/save_config.run.xml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<component name="ProjectRunConfigurationManager">
2+
<configuration default="false" name="save_config" type="PythonConfigurationType" factoryName="Python">
3+
<module name="brain_decoding" />
4+
<option name="ENV_FILES" value="" />
5+
<option name="INTERPRETER_OPTIONS" value="" />
6+
<option name="PARENT_ENVS" value="true" />
7+
<envs>
8+
<env name="PYTHONUNBUFFERED" value="1" />
9+
</envs>
10+
<option name="SDK_HOME" value="" />
11+
<option name="SDK_NAME" value="movie_decoding" />
12+
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
13+
<option name="IS_MODULE_SDK" value="false" />
14+
<option name="ADD_CONTENT_ROOTS" value="true" />
15+
<option name="ADD_SOURCE_ROOTS" value="true" />
16+
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/scripts/save_config.py" />
17+
<option name="PARAMETERS" value="" />
18+
<option name="SHOW_COMMAND_LINE" value="false" />
19+
<option name="EMULATE_TERMINAL" value="false" />
20+
<option name="MODULE_MODE" value="false" />
21+
<option name="REDIRECT_INPUT" value="false" />
22+
<option name="INPUT_FILE" value="" />
23+
<method v="2" />
24+
</configuration>
25+
</component>

scripts/save_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
config.model.weight_decay = 1e-4
1717
config.model.epochs = 5
1818
config.model.lr_drop = 50
19-
config.model.validation_step = 25
19+
config.model.validation_step = 2
2020
config.model.early_stop = 75
2121
config.model.num_labels = 8
2222
config.model.merge_label = True

src/brain_decoding/dataloader/free_recall.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,47 +26,49 @@ def __init__(self, config):
2626
self.lfp_channel_by_region = {}
2727

2828
spikes_data = None
29-
if self.config["use_spike"]:
30-
if self.config["use_sleep"]:
31-
config["spike_data_mode_inference"] = ""
32-
spikes_data = self.read_recording_data("spike_path", "time_sleep", "")
29+
if self.config.experiment["use_spike"]:
30+
data_path = "spike_path"
31+
if self.config.experiment["use_sleep"]:
32+
config.experiment["spike_data_mode_inference"] = ""
33+
spikes_data = self.read_recording_data(data_path, "time_sleep", "")
3334
else:
34-
if isinstance(self.config["free_recall_phase"], str) and "all" in self.config["free_recall_phase"]:
35-
if self.config["patient"] == "i728":
36-
phases = ["FR1a", "FR1b"]
37-
else:
38-
# phases = ["FR1", "FR2"]
39-
phases = ["FR1"]
35+
if (
36+
isinstance(self.config.experiment["free_recall_phase"], str)
37+
and "all" in self.config.experiment["free_recall_phase"]
38+
):
39+
phases = ["FR1"]
4040
for phase in phases:
41-
spikes_data = self.read_recording_data("spike_path", "time_recall", phase)
41+
spikes_data = self.read_recording_data(data_path, "time_recall", phase)
4242
elif (
43-
isinstance(self.config["free_recall_phase"], str) and "control" in self.config["free_recall_phase"]
43+
isinstance(self.config.experiment["free_recall_phase"], str)
44+
and "control" in self.config.experiment["free_recall_phase"]
45+
):
46+
spikes_data = self.read_recording_data(data_path, "time", None)
47+
elif (
48+
isinstance(self.config.experiment["free_recall_phase"], str)
49+
and "movie" in self.config.experiment["free_recall_phase"]
4450
):
45-
spikes_data = self.read_recording_data("spike_path", "time", None)
46-
elif isinstance(self.config["free_recall_phase"], str) and "movie" in self.config["free_recall_phase"]:
47-
spikes_data = self.read_recording_data("spike_path", "time", None)
51+
spikes_data = self.read_recording_data(data_path, "time", None)
4852
else:
49-
spikes_data = self.read_recording_data("spike_path", "time_recall", None)
53+
spikes_data = self.read_recording_data(data_path, "time_recall", None)
5054

5155
lfp_data = None
52-
if self.config["use_lfp"]:
53-
if self.use_sleep:
56+
if self.config.experiment["use_lfp"]:
57+
data_path = "lfp_path"
58+
if self.config.experiment.use_sleep:
5459
config["spike_data_mode_inference"] = ""
55-
lfp_data = self.read_recording_data("lfp_path", "spectrogram_sleep", "")
60+
lfp_data = self.read_recording_data(data_path, "spectrogram_sleep", "")
5661
else:
5762
if isinstance(self.config["free_recall_phase"], str) and "all" in self.config["free_recall_phase"]:
58-
if self.config["patient"] == "i728":
59-
phases = [1, 3]
60-
else:
61-
phases = [1, 2]
63+
phases = [1, 2]
6264
for phase in phases:
63-
lfp_data = self.read_recording_data("lfp_path", "spectrogram_recall", phase)
65+
lfp_data = self.read_recording_data(data_path, "spectrogram_recall", phase)
6466
elif (
6567
isinstance(self.config["free_recall_phase"], str) and "control" in self.config["free_recall_phase"]
6668
):
67-
lfp_data = self.read_recording_data("lfp_path", "spectrogram", None)
69+
lfp_data = self.read_recording_data(data_path, "spectrogram", None)
6870
else:
69-
lfp_data = self.read_recording_data("lfp_path", "spectrogram_recall", None)
71+
lfp_data = self.read_recording_data(data_path, "spectrogram_recall", None)
7072
# self.lfp_data = {key: np.concatenate(value_list, axis=0) for key, value_list in self.lfp_data.items()}
7173

7274
self.data = {"clusterless": spikes_data, "lfp": lfp_data}
@@ -85,14 +87,12 @@ def read_recording_data(self, root_path: str, file_path_prefix: str, phase: Opti
8587
if phase == "":
8688
exp_file_path = file_path_prefix
8789
else:
88-
if phase is None:
89-
phase = self.config["free_recall_phase"]
9090
exp_file_path = f"{file_path_prefix}_{phase}"
9191

9292
recording_file_path = os.path.join(
93-
self.config[root_path],
94-
self.config["patient"],
95-
self.config["spike_data_mode_inference"],
93+
self.config.data[root_path],
94+
str(self.config.experiment["patient"]),
95+
self.config.data["spike_data_mode_inference"],
9696
exp_file_path,
9797
)
9898
recording_files = glob.glob(os.path.join(recording_file_path, "*.npz"))
@@ -146,7 +146,7 @@ def load_clustless(self, files) -> np.ndarray[float]:
146146
# spike[spike < self.spike_data_sd] = 0
147147
# vmax, vmin = self.channel_max(spike)
148148
# normalized_spike = 2 * (spike - vmin[None, None, :, None]) / (vmax[None, None, :, None] - vmin[None, None, :, None]) - 1
149-
spike[spike < self.config["spike_data_sd_inference"]] = 0
149+
spike[spike < self.config.data["spike_data_sd_inference"]] = 0
150150
# spike[spike > 500] = 0
151151
vmax = np.max(spike)
152152
normalized_spike = spike / vmax
@@ -270,7 +270,7 @@ def load_pickle(self, fn):
270270
return lookup
271271

272272
def preprocess_data(self):
273-
if self.config["use_combined"]:
273+
if self.config.experiment["use_combined"]:
274274
assert self.data["clusterless"].shape[0] == self.data["lfp"].shape[0]
275275

276276
# self.label = np.array(self.ml_label).transpose()[:length, :].astype(np.float32)
@@ -352,8 +352,8 @@ def create_inference_combined_loaders(
352352
# np.random.seed(seed)
353353
np.random.shuffle(all_indices)
354354

355-
spike_inference = dataset.data["clusterless"][all_indices] if config["use_spike"] else None
356-
lfp_inference = dataset.data["lfp"][all_indices] if config["use_lfp"] else None
355+
spike_inference = dataset.data["clusterless"][all_indices] if config.experiment["use_spike"] else None
356+
lfp_inference = dataset.data["lfp"][all_indices] if config.experiment["use_lfp"] else None
357357

358358
# label_inference = dataset.smoothed_label[all_indices]
359359
label_inference = None

src/brain_decoding/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def pipeline(config: PipelineConfig) -> Trainer:
8787

8888
if __name__ == "__main__":
8989
patient = 562
90-
config_file = CONFIG_FILE_PATH / "config_test-None-None_2024-10-02-13:10:10.yaml"
90+
config_file = CONFIG_FILE_PATH / "config_test-None-None_2024-10-02-17:31:47.yaml"
9191

9292
config = set_config(
9393
config_file,

src/brain_decoding/trainer.py

Lines changed: 33 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def train(self, epochs, fold):
7171
best_f1 = -1
7272
self.model.train()
7373
os.makedirs(self.config.data["train_save_path"], exist_ok=True)
74+
os.makedirs(self.config.data["train_save_path"], exist_ok=True)
7475
for epoch in tqdm(range(epochs)):
7576
meter = Meter(fold)
7677

@@ -131,7 +132,7 @@ def train(self, epochs, fold):
131132
)
132133

133134
model_save_path = os.path.join(
134-
self.config["train_save_path"],
135+
self.config.data["train_save_path"],
135136
"best_weights_fold{}.tar".format(fold + 1),
136137
)
137138
torch.save(
@@ -356,15 +357,11 @@ def permutation_p(label, activation):
356357
df.to_csv(os.path.join(self.config["test_save_path"], "p_values.csv"))
357358

358359
def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
359-
torch.manual_seed(self.config["seed"])
360-
np.random.seed(self.config["seed"])
361-
random.seed(self.config["seed"])
362-
self.config["free_recall_phase"] = phase
363-
if self.config["patient"] == "i728" and "1" in phase:
364-
self.config["free_recall_phase"] = "free_recall1a"
365-
dataloaders = initialize_inference_dataloaders(self.config)
366-
else:
367-
dataloaders = initialize_inference_dataloaders(self.config)
360+
torch.manual_seed(self.config.experiment["seed"])
361+
np.random.seed(self.config.experiment["seed"])
362+
random.seed(self.config.experiment["seed"])
363+
self.config.experiment["free_recall_phase"] = phase
364+
dataloaders = initialize_inference_dataloaders(self.config)
368365
model = initialize_model(self.config)
369366
# model = torch.compile(model)
370367
model = model.to(device_name)
@@ -375,79 +372,52 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
375372

376373
# load the model with best F1-score
377374
# model_dir = os.path.join(self.config['train_save_path'], 'best_weights_fold{}.tar'.format(fold + 1))
378-
model_dir = os.path.join(self.config["train_save_path"], "model_weights_epoch{}.tar".format(epoch))
375+
model_dir = os.path.join(self.config.data["train_save_path"], "model_weights_epoch{}.tar".format(epoch))
379376
model.load_state_dict(torch.load(model_dir)["model_state_dict"])
380377
# print('Resume model: %s' % model_dir)
381378
model.eval()
382379

383-
predictions_all = np.empty((0, self.config["num_labels"]))
380+
predictions_all = np.empty((0, self.config.model["num_labels"]))
384381
predictions_length = {}
385382
with torch.no_grad():
386-
if self.config["patient"] == "i728" and "1" in phase:
387-
# load the best epoch number from the saved "model_results" structure
388-
for ph in ["FR1a", "FR1b"]:
389-
predictions = np.empty((0, self.config["num_labels"]))
390-
self.config["free_recall_phase"] = ph
391-
dataloaders = initialize_inference_dataloaders(self.config)
392-
# y_true = np.empty((0, self.config['num_labels']))
393-
for i, (feature, index) in enumerate(dataloaders["inference"]):
394-
# target = target.to(self.device)
395-
spike, lfp = self.extract_feature(feature)
396-
# forward pass
397-
398-
# start_time = time.time()
399-
spike_emb, lfp_emb, output = model(lfp, spike)
400-
# end_time = time.time()
401-
# print('inference time: ', end_time - start_time)
402-
output = torch.sigmoid(output)
403-
pred = output.cpu().detach().numpy()
404-
predictions = np.concatenate([predictions, pred], axis=0)
405-
406-
if self.config["use_overlap"]:
407-
fake_activation = np.mean(predictions, axis=0)
408-
predictions = np.vstack((fake_activation, predictions, fake_activation))
409-
410-
predictions_all = np.concatenate([predictions_all, predictions], axis=0)
411-
predictions_length[phase] = len(predictions_all)
412-
else:
413-
self.config["free_recall_phase"] = phase
414-
dataloaders = initialize_inference_dataloaders(self.config)
415-
predictions = np.empty((0, self.config["num_labels"]))
416-
# y_true = np.empty((0, self.config['num_labels']))
417-
for i, (feature, index) in enumerate(dataloaders["inference"]):
418-
# target = target.to(self.device)
419-
spike, lfp = self.extract_feature(feature)
420-
# forward pass
383+
self.config.experiment["free_recall_phase"] = phase
384+
dataloaders = initialize_inference_dataloaders(self.config)
385+
predictions = np.empty((0, self.config.model["num_labels"]))
386+
# y_true = np.empty((0, self.config['num_labels']))
387+
for i, (feature, index) in enumerate(dataloaders["inference"]):
388+
# target = target.to(self.device)
389+
spike, lfp = self.extract_feature(feature)
390+
# forward pass
421391

422-
# start_time = time.time()
423-
spike_emb, lfp_emb, output = model(lfp, spike)
424-
# end_time = time.time()
425-
# print('inference time: ', end_time - start_time)
426-
output = torch.sigmoid(output)
427-
pred = output.cpu().detach().numpy()
428-
predictions = np.concatenate([predictions, pred], axis=0)
392+
# start_time = time.time()
393+
spike_emb, lfp_emb, output = model(lfp, spike)
394+
# end_time = time.time()
395+
# print('inference time: ', end_time - start_time)
396+
output = torch.sigmoid(output)
397+
pred = output.cpu().detach().numpy()
398+
predictions = np.concatenate([predictions, pred], axis=0)
429399

430-
if self.config["use_overlap"]:
431-
fake_activation = np.mean(predictions, axis=0)
432-
predictions = np.vstack((fake_activation, predictions, fake_activation))
400+
if self.config.experiment["use_overlap"]:
401+
fake_activation = np.mean(predictions, axis=0)
402+
predictions = np.vstack((fake_activation, predictions, fake_activation))
433403

434-
predictions_length[phase] = len(predictions)
435-
predictions_all = np.concatenate([predictions_all, predictions], axis=0)
404+
predictions_length[phase] = len(predictions)
405+
predictions_all = np.concatenate([predictions_all, predictions], axis=0)
436406

437407
# np.save(os.path.join(self.config['memory_save_path'], 'free_recall_{}_results.npy'.format(phase)), predictions)
438-
save_path = os.path.join(self.config["memory_save_path"], "prediction")
408+
save_path = os.path.join(self.config.data["memory_save_path"], "prediction")
439409
os.makedirs(save_path, exist_ok=True)
440410
np.save(
441411
os.path.join(save_path, "epoch{}_free_recall_{}_results.npy".format(epoch, phase)),
442412
predictions_all,
443413
)
444414

445415
for ph in alongwith:
446-
self.config["free_recall_phase"] = ph
416+
self.config.experiment["free_recall_phase"] = ph
447417
dataloaders = initialize_inference_dataloaders(self.config)
448418
with torch.no_grad():
449419
# load the best epoch number from the saved "model_results" structure
450-
predictions = np.empty((0, self.config["num_labels"]))
420+
predictions = np.empty((0, self.config.model["num_labels"]))
451421
# y_true = np.empty((0, self.config['num_labels']))
452422
for i, (feature, index) in enumerate(dataloaders["inference"]):
453423
# target = target.to(self.device)
@@ -462,7 +432,7 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
462432
pred = output.cpu().detach().numpy()
463433
predictions = np.concatenate([predictions, pred], axis=0)
464434

465-
if self.config["use_overlap"]:
435+
if self.config.experiment["use_overlap"]:
466436
fake_activation = np.mean(predictions, axis=0)
467437
predictions = np.vstack((fake_activation, predictions, fake_activation))
468438

src/brain_decoding/utils/initializer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,23 @@ def initialize_configs(architecture) -> Dict:
5858
return args
5959

6060

61-
def initialize_inference_dataloaders(config):
62-
if config["use_sleep"]:
61+
def initialize_inference_dataloaders(config: PipelineConfig):
62+
if config.experiment["use_sleep"]:
6363
dataset = InferenceDataset(
64-
config["data_path"],
65-
config["patient"],
66-
config["use_lfp"],
67-
config["use_spike"],
68-
config["use_bipolar"],
69-
config["use_sleep"],
70-
config["free_recall_phase"],
71-
config["hour"],
64+
config.data["data_path"],
65+
config.experiment["patient"],
66+
config.experiment["use_lfp"],
67+
config.experiment["use_spike"],
68+
config.experiment["use_bipolar"],
69+
config.experiment["use_sleep"],
70+
config.experiment["free_recall_phase"],
71+
config.experiment["hour"],
7272
)
7373
else:
7474
dataset = InferenceDataset(config)
7575

76-
LFP_CHANNEL[config["patient"]] = dataset.lfp_channel_by_region
77-
test_loader = create_inference_combined_loaders(dataset, config, batch_size=config["batch_size"])
76+
LFP_CHANNEL[config.experiment["patient"]] = dataset.lfp_channel_by_region
77+
test_loader = create_inference_combined_loaders(dataset, config, batch_size=config.model["batch_size"])
7878

7979
dataloaders = {"train": None, "valid": None, "inference": test_loader}
8080
return dataloaders

0 commit comments

Comments
 (0)