Skip to content

Commit ebf6ab4

Browse files
authored
Merge pull request #11 from NxNiki/dev
Check how to load the whole sleep data.
2 parents aeaadbc + 1dc9521 commit ebf6ab4

File tree

10 files changed

+567
-182
lines changed

10 files changed

+567
-182
lines changed

scripts/plot_activation.ipynb

Lines changed: 143 additions & 38 deletions
Large diffs are not rendered by default.

scripts/save_config.py

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,57 +6,61 @@
66
from brain_decoding.config.config import ExperimentConfig, PipelineConfig
77
from brain_decoding.config.file_path import CONFIG_FILE_PATH, DATA_PATH, RESULT_PATH
88

9-
if __name__ == "__main__":
10-
experiment_config = ExperimentConfig(name="sleep", patient=562)
9+
# if __name__ == "__main__":
10+
experiment_config = ExperimentConfig(name="sleep", patient=562)
1111

12-
config = PipelineConfig(experiment=experiment_config)
13-
config.model.architecture = "multi-vit"
14-
config.model.learning_rate = 1e-4
15-
config.model.batch_size = 128
16-
config.model.weight_decay = 1e-4
17-
config.model.epochs = 40
18-
config.model.lr_drop = 50
19-
config.model.validation_step = 10
20-
config.model.early_stop = 75
21-
config.model.num_labels = 8
22-
config.model.merge_label = True
23-
config.model.img_embedding_size = 192
24-
config.model.hidden_size = 256
25-
config.model.num_hidden_layers = 6
26-
config.model.num_attention_heads = 8
27-
config.model.patch_size = (1, 5)
28-
config.model.intermediate_size = 192 * 2
29-
config.model.classifier_proj_size = 192
12+
config = PipelineConfig(experiment=experiment_config)
13+
config.model.architecture = "multi-vit"
14+
config.model.learning_rate = 1e-4
15+
config.model.batch_size = 128
16+
config.model.weight_decay = 1e-4
17+
config.model.epochs = 40
18+
config.model.lr_drop = 50
19+
config.model.validation_step = 10
20+
config.model.early_stop = 75
21+
config.model.num_labels = 8
22+
config.model.merge_label = True
23+
config.model.img_embedding_size = 192
24+
config.model.hidden_size = 256
25+
config.model.num_hidden_layers = 6
26+
config.model.num_attention_heads = 8
27+
config.model.patch_size = (1, 5)
28+
config.model.intermediate_size = 192 * 2
29+
config.model.classifier_proj_size = 192
3030

31-
config.experiment.seed = 42
32-
config.experiment.use_spike = True
33-
config.experiment.use_lfp = False
34-
config.experiment.use_combined = False
35-
config.experiment.use_shuffle = True
36-
config.experiment.use_bipolar = False
37-
config.experiment.use_sleep = (
38-
True # set true to use sleep data as inference dataset, otherwise use free recall, is this right?
39-
)
40-
config.experiment.use_overlap = False
41-
config.experiment.use_long_input = False
42-
config.experiment.use_spontaneous = False
43-
config.experiment.use_augment = False
44-
config.experiment.use_shuffle_diagnostic = True
45-
config.experiment.testing_mode = True # in testing mode, a maximum of 1e4 clusterless data will be loaded.
46-
config.experiment.model_aggregate_type = "sum"
47-
config.experiment.train_phase = ["movie_1"]
48-
config.experiment.test_phase = ["sleep_2"]
31+
config.experiment.seed = 42
32+
config.experiment.use_spike = True
33+
config.experiment.use_lfp = False
34+
config.experiment.use_combined = False
35+
config.experiment.use_shuffle = True
36+
config.experiment.use_bipolar = False
37+
# config.experiment.use_sleep = (
38+
# True # set true to use sleep data as inference dataset, otherwise use free recall, is this right?
39+
# )
40+
config.experiment.use_overlap = False
41+
config.experiment.use_long_input = False
42+
config.experiment.use_spontaneous = False
43+
config.experiment.use_augment = False
44+
config.experiment.use_shuffle_diagnostic = True
45+
config.experiment.testing_mode = False # in testing mode, a maximum of 1e4 clusterless data will be loaded.
46+
config.experiment.model_aggregate_type = "sum"
47+
config.experiment.train_phases = ["movie_1"]
48+
config.experiment.test_phases = ["sleep_2"]
49+
config.experiment.compute_accuracy = False
4950

50-
config.data.result_path = str(RESULT_PATH)
51-
config.data.spike_path = str(DATA_PATH)
52-
config.data.lfp_path = "undefined"
53-
config.data.lfp_data_mode = "sf2000-bipolar-region-clean"
54-
config.data.spike_data_mode = "notch CAR-quant-neg"
55-
config.data.spike_data_mode_inference = "notch CAR-quant-neg"
56-
config.data.spike_data_sd = [3.5]
57-
config.data.spike_data_sd_inference = 3.5
58-
config.data.model_aggregate_type = "sum"
59-
config.data.movie_label_path = str(DATA_PATH / "8concepts_merged.npy")
60-
config.data.movie_sampling_rate = 30
51+
config.experiment.ensure_list("train_phases")
52+
config.experiment.ensure_list("test_phases")
6153

62-
config.export_config(CONFIG_FILE_PATH)
54+
config.data.result_path = str(RESULT_PATH)
55+
config.data.spike_path = str(DATA_PATH)
56+
config.data.lfp_path = "undefined"
57+
config.data.lfp_data_mode = "sf2000-bipolar-region-clean"
58+
config.data.spike_data_mode = "notch CAR-quant-neg"
59+
config.data.spike_data_mode_inference = "notch CAR-quant-neg"
60+
config.data.spike_data_sd = [3.5]
61+
config.data.spike_data_sd_inference = 3.5
62+
config.data.model_aggregate_type = "sum"
63+
config.data.movie_label_path = str(DATA_PATH / "8concepts_merged.npy")
64+
config.data.movie_sampling_rate = 30
65+
66+
# config.export_config(CONFIG_FILE_PATH)

src/brain_decoding/config/config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,19 @@ def __setattr__(self, name, value):
3737
def __contains__(self, key: str) -> bool:
3838
return hasattr(self, key)
3939

40+
def __repr__(self):
41+
attrs = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
42+
attr_str = "\n".join(f" {key}: {value!r}" for key, value in attrs.items())
43+
return f"{self.__class__.__name__}(\n{attr_str}\n)"
44+
4045
def set_alias(self, name: str, alias: str) -> None:
4146
self.__dict__["_alias"][alias] = name
4247

4348
def ensure_list(self, name: str):
49+
"""Mark the field to always be treated as a list"""
4450
value = getattr(self, name, None)
4551
if value is not None and not isinstance(value, list):
4652
setattr(self, name, [value])
47-
# Mark the field to always be treated as a list
4853
self._list_fields.add(name)
4954

5055

@@ -112,8 +117,9 @@ def export_config(self, output_file: Union[str, Path] = "config.yaml") -> None:
112117
dir_path = output_file.parent
113118
dir_path.mkdir(parents=True, exist_ok=True)
114119

120+
config_data = self.model_dump()
115121
with open(output_file, "w") as file:
116-
yaml.safe_dump(self.model_dump(), file)
122+
yaml.safe_dump(config_data, file)
117123

118124
@property
119125
def _file_tag(self) -> str:

src/brain_decoding/dataloader/save_clusterless.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
from brain_decoding.config.file_path import DATA_PATH
2828

29+
SECONDS_PER_HOUR = 3600
30+
2931
OFFSET = {
3032
"555_1": 4.58,
3133
"562_1": 0,
@@ -115,8 +117,9 @@
115117

116118
# is there a way to select the whole duration?
117119
SLEEP_TIME = {
118-
"562_1": (0, 2 * 3600), # memory test
119-
"562_2": (0, 5 * 3600), # memory test
120+
"562_1": (0, 2 * SECONDS_PER_HOUR), # memory test
121+
"562_2": (0, 5 * SECONDS_PER_HOUR), # memory test
122+
"562_3": (0, 10 * SECONDS_PER_HOUR), # memory test
120123
}
121124

122125
CONTROL = {
@@ -820,7 +823,7 @@ def sort_filename(filename):
820823
if __name__ == "__main__":
821824
version = "notch CAR-quant-neg"
822825
SPIKE_ROOT_PATH = "/Users/XinNiuAdmin/Library/CloudStorage/Box-Box/Vwani_Movie/Clusterless/"
823-
get_oneshot_clean("562", 2000, "Experiment6_MovieParadigm_notch", category="sleep", phase=2, version=version)
826+
get_oneshot_clean("562", 2000, "Experiment6_MovieParadigm_notch", category="sleep", phase=3, version=version)
824827
# get_oneshot_clean("562", 2000, "presleep", category="movie", phase=1, version=version)
825828
# get_oneshot_clean("562", 2000, "presleep", category="recall", phase="FR1", version=version)
826829
# get_oneshot_clean("562", 2000, "postsleep", category="recall", phase="FR2", version=version)

src/brain_decoding/dataloader/test_data.py

Lines changed: 7 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,56 +27,23 @@ class InferenceDataset(Dataset):
2727
def __init__(self, config):
2828
self.config = config
2929
self.lfp_channel_by_region = {}
30-
phases = config.data.phases
30+
3131
spikes_data = None
3232
if self.config.experiment["use_spike"]:
3333
data_path = "spike_path"
34-
if self.config.experiment["use_sleep"]:
35-
config.experiment["spike_data_mode_inference"] = ""
36-
spikes_data = self.read_recording_data(data_path, "time_sleep", phases[0])
37-
else:
38-
if (
39-
isinstance(self.config.experiment["free_recall_phase"], str)
40-
and "all" in self.config.experiment["free_recall_phase"]
41-
):
42-
for phase in phases:
43-
spikes_data = self.read_recording_data(data_path, "time_recall", phase)
44-
elif (
45-
isinstance(self.config.experiment["free_recall_phase"], str)
46-
and "control" in self.config.experiment["free_recall_phase"]
47-
):
48-
spikes_data = self.read_recording_data(data_path, "time", None)
49-
elif (
50-
isinstance(self.config.experiment["free_recall_phase"], str)
51-
and "movie" in self.config.experiment["free_recall_phase"]
52-
):
53-
spikes_data = self.read_recording_data(data_path, "time", None)
54-
else:
55-
spikes_data = self.read_recording_data(data_path, "time_recall", None)
34+
spikes_data = self.read_recording_data(data_path, "time", self.config.experiment.test_phases[0])
5635

5736
lfp_data = None
5837
if self.config.experiment["use_lfp"]:
5938
data_path = "lfp_path"
60-
if self.config.experiment.use_sleep:
61-
config["spike_data_mode_inference"] = ""
62-
lfp_data = self.read_recording_data(data_path, "spectrogram_sleep", "")
63-
else:
64-
if isinstance(self.config["free_recall_phase"], str) and "all" in self.config["free_recall_phase"]:
65-
for phase in phases:
66-
lfp_data = self.read_recording_data(data_path, "spectrogram_recall", phase)
67-
elif (
68-
isinstance(self.config["free_recall_phase"], str) and "control" in self.config["free_recall_phase"]
69-
):
70-
lfp_data = self.read_recording_data(data_path, "spectrogram", None)
71-
else:
72-
lfp_data = self.read_recording_data(data_path, "spectrogram_recall", None)
39+
lfp_data = self.read_recording_data(data_path, "spectrogram_recall", self.config.experiment.test_phases[0])
7340
# self.lfp_data = {key: np.concatenate(value_list, axis=0) for key, value_list in self.lfp_data.items()}
7441

7542
self.data = {"clusterless": spikes_data, "lfp": lfp_data}
7643
self.data_length = self.get_data_length()
7744
self.preprocess_data()
7845

79-
def read_recording_data(self, root_path: str, file_path_prefix: str, phase: Optional[str]) -> np.ndarray[float]:
46+
def read_recording_data(self, root_path: str, file_path_prefix: str, phase: str) -> np.ndarray[float]:
8047
"""
8148
read spike or lfp data.
8249
@@ -85,10 +52,7 @@ def read_recording_data(self, root_path: str, file_path_prefix: str, phase: Opti
8552
:param phase:
8653
:return:
8754
"""
88-
if phase == "":
89-
exp_file_path = file_path_prefix
90-
else:
91-
exp_file_path = f"{file_path_prefix}_{phase}"
55+
exp_file_path = f"{file_path_prefix}_{phase}"
9256

9357
recording_file_path = os.path.join(
9458
self.config.data[root_path],
@@ -100,7 +64,8 @@ def read_recording_data(self, root_path: str, file_path_prefix: str, phase: Opti
10064
recording_files = sorted(recording_files, key=sort_file_name)
10165

10266
if not recording_files:
103-
raise ValueError(f"not files found in: {recording_files}")
67+
error_msg = f"not files found in: {recording_files}"
68+
raise ValueError(error_msg)
10469

10570
if root_path == "spike_path":
10671
data = self.load_clustless(recording_files)

src/brain_decoding/dataloader/train_data.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,13 @@ def __init__(self, config: PipelineConfig):
5858
self.smoothed_label = []
5959
self.lfp_channel_by_region = {}
6060

61-
if self.patient in ["564", "565"]:
62-
categories = ["Movie_1", "Movie_2"]
63-
else:
64-
categories = ["Movie_1"]
65-
66-
if self.use_spontaneous:
67-
categories.append("Control1")
68-
categories.append("Control2")
69-
7061
# create spike data
7162
if self.use_spike:
72-
self.data["clusterless"] = self.load_data(config.data["spike_path"], categories)
63+
self.data["clusterless"] = self.load_data(config.data["spike_path"], config.experiment.train_phases)
7364

7465
# create lfp data
7566
if self.use_lfp:
76-
self.data["lfp"] = self.load_data(config.data["lfp_path"], categories)
67+
self.data["lfp"] = self.load_data(config.data["lfp_path"], config.experiment.train_phases)
7768

7869
# for c, category in enumerate(categories):
7970
# size = sample_size[c]

src/brain_decoding/main.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from brain_decoding.config.config import PipelineConfig
1818
from brain_decoding.config.file_path import CONFIG_FILE_PATH
1919
from brain_decoding.param.base_param import device
20+
from scripts.save_config import config
2021

2122
# torch.autograd.set_detect_anomaly(True)
2223
# torch.backends.cuda.matmul.allow_tf32=True
@@ -28,17 +29,19 @@
2829

2930

3031
def set_config(
31-
config_file: Union[str, Path],
32+
config_file: Union[str, Path, PipelineConfig],
3233
patient_id: int,
33-
phases: Union[List[str], str],
34+
train_phases: Union[List[str], str],
35+
test_phases: Union[List[str], str],
3436
spike_data_sd: Union[List[float], float, None] = None,
3537
spike_data_sd_inference: Optional[float] = None,
3638
) -> PipelineConfig:
3739
"""
3840
set parameters based on config file.
3941
:param config_file:
4042
:param patient_id:
41-
:param phases:
43+
:param train_phases:
44+
:param test_phases:
4245
:param spike_data_sd:
4346
:param spike_data_sd_inference:
4447
:return:
@@ -47,15 +50,18 @@ def set_config(
4750
if isinstance(spike_data_sd, float):
4851
spike_data_sd = [spike_data_sd]
4952

50-
config = PipelineConfig.read_config(config_file)
53+
if isinstance(config_file, PipelineConfig):
54+
config = config_file
55+
else:
56+
config = PipelineConfig.read_config(config_file)
5157

5258
config.experiment["patient"] = patient_id
5359
config.experiment.name = "8concepts"
5460

55-
if isinstance(phases, str):
56-
config.data.phases = [phases]
57-
else:
58-
config.data.phases = phases
61+
config.experiment.train_phases = [train_phases]
62+
63+
config.experiment.test_phases = test_phases
64+
config.experiment.ensure_list("test_phases")
5965

6066
if spike_data_sd is not None:
6167
config.data.spike_data_sd = spike_data_sd
@@ -110,13 +116,16 @@ def pipeline(config: PipelineConfig) -> Trainer:
110116

111117
if __name__ == "__main__":
112118
patient = 562
113-
phase = "2"
119+
phase_train = "movie_1"
120+
phase_test = "sleep_3"
114121
CONFIG_FILE = CONFIG_FILE_PATH / "config_sleep-None-None_2024-10-16-19:17:43.yaml"
115122

116123
config = set_config(
117-
CONFIG_FILE,
124+
# CONFIG_FILE,
125+
config,
118126
patient,
119-
phase,
127+
phase_train,
128+
phase_test,
120129
)
121130

122131
print("start: ", patient)

src/brain_decoding/trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ def train(self, epochs, fold):
144144
)
145145
print()
146146
print("WELCOME MEMORY TEST at: ", epoch)
147-
stats_m = self.memory(epoch=epoch + 1, phase=self.config.data.phases[0], alongwith=[])
148-
# self.memory(1, epoch=epoch+1, phase='all')
147+
stats_m = self.memory(epoch=epoch + 1, phase=self.config.experiment.test_phases[0], alongwith=[])
148+
149149
if stats_m is not None:
150150
overall_p = list(stats_m.values())
151151
print("P: ", overall_p)
@@ -360,7 +360,7 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
360360
torch.manual_seed(self.config.experiment["seed"])
361361
np.random.seed(self.config.experiment["seed"])
362362
random.seed(self.config.experiment["seed"])
363-
self.config.experiment["free_recall_phase"] = phase
363+
# self.config.experiment["test_phase"] = phase
364364
dataloaders = initialize_inference_dataloaders(self.config)
365365
model = initialize_model(self.config)
366366
# model = torch.compile(model)
@@ -448,7 +448,7 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
448448
predictions = predictions[:, 0:8]
449449

450450
# Perform Statistic Method
451-
if not self.config.experiment["use_sleep"]:
451+
if self.config.experiment["compute_accuracy"]:
452452
sts = Permutate(
453453
config=self.config,
454454
phase=phase,

0 commit comments

Comments
 (0)