|
6 | 6 | from brain_decoding.config.config import ExperimentConfig, PipelineConfig |
7 | 7 | from brain_decoding.config.file_path import CONFIG_FILE_PATH, DATA_PATH, RESULT_PATH |
8 | 8 |
|
9 | | -if __name__ == "__main__": |
10 | | - experiment_config = ExperimentConfig(name="sleep", patient=562) |
| 9 | +# if __name__ == "__main__": |
| 10 | +experiment_config = ExperimentConfig(name="sleep", patient=562) |
11 | 11 |
|
12 | | - config = PipelineConfig(experiment=experiment_config) |
13 | | - config.model.architecture = "multi-vit" |
14 | | - config.model.learning_rate = 1e-4 |
15 | | - config.model.batch_size = 128 |
16 | | - config.model.weight_decay = 1e-4 |
17 | | - config.model.epochs = 40 |
18 | | - config.model.lr_drop = 50 |
19 | | - config.model.validation_step = 10 |
20 | | - config.model.early_stop = 75 |
21 | | - config.model.num_labels = 8 |
22 | | - config.model.merge_label = True |
23 | | - config.model.img_embedding_size = 192 |
24 | | - config.model.hidden_size = 256 |
25 | | - config.model.num_hidden_layers = 6 |
26 | | - config.model.num_attention_heads = 8 |
27 | | - config.model.patch_size = (1, 5) |
28 | | - config.model.intermediate_size = 192 * 2 |
29 | | - config.model.classifier_proj_size = 192 |
| 12 | +config = PipelineConfig(experiment=experiment_config) |
| 13 | +config.model.architecture = "multi-vit" |
| 14 | +config.model.learning_rate = 1e-4 |
| 15 | +config.model.batch_size = 128 |
| 16 | +config.model.weight_decay = 1e-4 |
| 17 | +config.model.epochs = 40 |
| 18 | +config.model.lr_drop = 50 |
| 19 | +config.model.validation_step = 10 |
| 20 | +config.model.early_stop = 75 |
| 21 | +config.model.num_labels = 8 |
| 22 | +config.model.merge_label = True |
| 23 | +config.model.img_embedding_size = 192 |
| 24 | +config.model.hidden_size = 256 |
| 25 | +config.model.num_hidden_layers = 6 |
| 26 | +config.model.num_attention_heads = 8 |
| 27 | +config.model.patch_size = (1, 5) |
| 28 | +config.model.intermediate_size = 192 * 2 |
| 29 | +config.model.classifier_proj_size = 192 |
30 | 30 |
|
31 | | - config.experiment.seed = 42 |
32 | | - config.experiment.use_spike = True |
33 | | - config.experiment.use_lfp = False |
34 | | - config.experiment.use_combined = False |
35 | | - config.experiment.use_shuffle = True |
36 | | - config.experiment.use_bipolar = False |
37 | | - config.experiment.use_sleep = ( |
38 | | - True # set true to use sleep data as inference dataset, otherwise use free recall, is this right? |
39 | | - ) |
40 | | - config.experiment.use_overlap = False |
41 | | - config.experiment.use_long_input = False |
42 | | - config.experiment.use_spontaneous = False |
43 | | - config.experiment.use_augment = False |
44 | | - config.experiment.use_shuffle_diagnostic = True |
45 | | - config.experiment.testing_mode = True # in testing mode, a maximum of 1e4 clusterless data will be loaded. |
46 | | - config.experiment.model_aggregate_type = "sum" |
47 | | - config.experiment.train_phase = ["movie_1"] |
48 | | - config.experiment.test_phase = ["sleep_2"] |
| 31 | +config.experiment.seed = 42 |
| 32 | +config.experiment.use_spike = True |
| 33 | +config.experiment.use_lfp = False |
| 34 | +config.experiment.use_combined = False |
| 35 | +config.experiment.use_shuffle = True |
| 36 | +config.experiment.use_bipolar = False |
| 37 | +# config.experiment.use_sleep = ( |
| 38 | +# True # set true to use sleep data as inference dataset, otherwise use free recall, is this right? |
| 39 | +# ) |
| 40 | +config.experiment.use_overlap = False |
| 41 | +config.experiment.use_long_input = False |
| 42 | +config.experiment.use_spontaneous = False |
| 43 | +config.experiment.use_augment = False |
| 44 | +config.experiment.use_shuffle_diagnostic = True |
| 45 | +config.experiment.testing_mode = False # in testing mode, a maximum of 1e4 clusterless data will be loaded. |
| 46 | +config.experiment.model_aggregate_type = "sum" |
| 47 | +config.experiment.train_phases = ["movie_1"] |
| 48 | +config.experiment.test_phases = ["sleep_2"] |
| 49 | +config.experiment.compute_accuracy = False |
49 | 50 |
|
50 | | - config.data.result_path = str(RESULT_PATH) |
51 | | - config.data.spike_path = str(DATA_PATH) |
52 | | - config.data.lfp_path = "undefined" |
53 | | - config.data.lfp_data_mode = "sf2000-bipolar-region-clean" |
54 | | - config.data.spike_data_mode = "notch CAR-quant-neg" |
55 | | - config.data.spike_data_mode_inference = "notch CAR-quant-neg" |
56 | | - config.data.spike_data_sd = [3.5] |
57 | | - config.data.spike_data_sd_inference = 3.5 |
58 | | - config.data.model_aggregate_type = "sum" |
59 | | - config.data.movie_label_path = str(DATA_PATH / "8concepts_merged.npy") |
60 | | - config.data.movie_sampling_rate = 30 |
| 51 | +config.experiment.ensure_list("train_phases") |
| 52 | +config.experiment.ensure_list("test_phases") |
61 | 53 |
|
62 | | - config.export_config(CONFIG_FILE_PATH) |
| 54 | +config.data.result_path = str(RESULT_PATH) |
| 55 | +config.data.spike_path = str(DATA_PATH) |
| 56 | +config.data.lfp_path = "undefined" |
| 57 | +config.data.lfp_data_mode = "sf2000-bipolar-region-clean" |
| 58 | +config.data.spike_data_mode = "notch CAR-quant-neg" |
| 59 | +config.data.spike_data_mode_inference = "notch CAR-quant-neg" |
| 60 | +config.data.spike_data_sd = [3.5] |
| 61 | +config.data.spike_data_sd_inference = 3.5 |
| 62 | +config.data.model_aggregate_type = "sum" |
| 63 | +config.data.movie_label_path = str(DATA_PATH / "8concepts_merged.npy") |
| 64 | +config.data.movie_sampling_rate = 30 |
| 65 | + |
| 66 | +# config.export_config(CONFIG_FILE_PATH) |
0 commit comments