Skip to content

Commit 9baa70b

Browse files
authored
Merge pull request #3 from NxNiki/dev
Add config
2 parents 77fa1d0 + c43fe46 commit 9baa70b

File tree

15 files changed

+447
-333
lines changed

15 files changed

+447
-333
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,6 @@ src/movie_decoding/__pycache__/
3434

3535
data/
3636
._data
37+
config/*.yaml
3738
results/
3839
wandb/

scripts/save_config.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""
2+
This script is used to define the basic config parameters for a movie decoding project.
3+
Custom parameters can be added to any of the three fields of config (experiment, model, data).
4+
"""
5+
6+
from movie_decoding.config.config import ExperimentConfig, PipelineConfig
7+
from movie_decoding.config.file_path import CONFIG_FILE_PATH, DATA_PATH, RESULT_PATH
8+
9+
if __name__ == "__main__":
    # Settings are grouped per config section; each group is applied with
    # setattr in the order listed, so the result matches one assignment per key.
    model_settings = {
        "architecture": "multi-vit",
        "learning_rate": 1e-4,
        "batch_size": 128,
        "weight_decay": 1e-4,
        "epochs": 5,
        "lr_drop": 50,
        "validation_step": 25,
        "early_stop": 75,
        "num_labels": 8,
        "merge_label": True,
        "img_embedding_size": 192,
        "hidden_size": 256,
        "num_hidden_layers": 6,
        "num_attention_heads": 8,
        "patch_size": (1, 5),
        "intermediate_size": 192 * 2,
        "classifier_proj_size": 192,
    }

    experiment_settings = {
        "seed": 42,
        "use_spike": True,
        "use_lfp": False,
        "use_combined": False,
        "use_shuffle": True,
        "use_bipolar": False,
        "use_sleep": False,
        "use_overlap": False,
        "use_long_input": False,
        "use_spontaneous": False,
        "use_augment": False,
        "use_shuffle_diagnostic": True,
        "model_aggregate_type": "sum",
    }

    data_settings = {
        "result_path": str(RESULT_PATH),
        "spike_path": str(DATA_PATH),
        "lfp_path": "undefined",
        "lfp_data_mode": "sf2000-bipolar-region-clean",
        "spike_data_mode": "notch CAR-quant-neg",
        "spike_data_mode_inference": "notch CAR-quant-neg",
        "spike_data_sd": [3.5],
        "spike_data_sd_inference": 3.5,
        "use_augment": False,
        "use_long_input": False,
        "use_shuffle_diagnostic": False,
        "model_aggregate_type": "sum",
        "movie_label_path": str(DATA_PATH / "8concepts_merged.npy"),
        "movie_sampling_rate": 30,
    }

    config = PipelineConfig(experiment=ExperimentConfig(name="test", patient=562))

    # Model first, then experiment, then data.
    for section, settings in (
        (config.model, model_settings),
        (config.experiment, experiment_settings),
        (config.data, data_settings),
    ):
        for key, value in settings.items():
            setattr(section, key, value)

    # Write the assembled configuration below the project's config directory.
    config.export_config(CONFIG_FILE_PATH)

src/movie_decoding/dataloader/save_patients.py renamed to scripts/save_patients.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55
import numpy as np
66
import pandas as pd
77

8+
from movie_decoding.config.file_path import PATIENTS_FILE_PATH, SURROGATE_FILE_PATH
89
from movie_decoding.dataloader.patients import Patients
910

10-
PATIENTS_FILE_PATH = Path(__file__).resolve().parents[3] / "data/patients"
11-
SURROGATE_FILE_PATH = Path(__file__).resolve().parents[3] / "data/surrogate_windows"
12-
1311

1412
def read_annotation(annotation_file: str) -> List[int]:
1513
"""

src/movie_decoding/batch_main.sh

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
### common_average_job.sh START ###
21
#!/bin/bash
32
#$ -cwd
43
# error = Merged with joblog
@@ -30,7 +29,7 @@ conda activate movie_decoding
3029

3130
# in the following two lines substitute the command with the
3231
# needed command below:
33-
python main.py
32+
python main.py
3433

3534
# echo job info on joblog:
3635
echo "Job $JOB_ID ended on: " `hostname -s`
@@ -39,6 +38,3 @@ echo " "
3938
### extract_clusterless_parallel.job STOP ###
4039
# this site shows how to do array jobs: https://info.hpc.sussex.ac.uk/hpc-guide/how-to/array.html
4140
# (better than the Hoffman site https://www.hoffman2.idre.ucla.edu/Using-H2/Computing/Computing.html#how-to-build-a-submission-script)
42-
43-
44-

src/movie_decoding/config/__init__.py

Whitespace-only changes.
src/movie_decoding/config/config.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from datetime import datetime
2+
from pathlib import Path
3+
from typing import Any, Dict, List, Optional, Tuple, Union
4+
5+
import yaml
6+
from pydantic import BaseModel, Field
7+
8+
9+
class BaseConfig(BaseModel):
    """Base class for config sections with aliases and free-form parameters.

    Declared model fields are stored as regular pydantic fields. Any other
    name that is assigned ends up in ``param``. Names listed in ``alias`` are
    redirected to the field they map to. Values can be read and written with
    attribute syntax (``cfg.x``) or item syntax (``cfg["x"]``).
    """

    # Maps a short name to the name of the field it stands for.
    alias: Dict[str, str] = {}
    # Custom parameters that are not declared as model fields.
    param: Dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any:
        """Return the custom parameter, field or aliased field named ``key``.

        Custom parameters take precedence over attributes.

        Raises:
            AttributeError: if ``key`` is neither a custom parameter, a field
                nor an alias.
        """
        if key in self.param:
            return self.param[key]
        return getattr(self, key)

    def __setitem__(self, key: str, value: Any) -> None:
        """Assign ``value`` to ``key`` exactly as ``self.key = value`` would.

        Delegating to ``__setattr__`` keeps item and attribute assignment
        consistent: aliases are resolved, declared fields are set on the model
        and everything else is stored in ``param``.
        """
        setattr(self, key, value)

    def __getattr__(self, name: str) -> Any:
        """Resolve aliases and custom parameters when normal lookup fails.

        Raises:
            AttributeError: if ``name`` is neither an alias nor a custom
                parameter.
        """
        # Read the two lookup tables straight from the instance dict. Using
        # ``self.alias`` here would call this method again whenever the fields
        # are not populated yet and recurse until the interpreter gives up.
        state = self.__dict__
        aliases = state.get("alias") or {}
        if name in aliases:
            return getattr(self, aliases[name])
        params = state.get("param") or {}
        if name in params:
            return params[name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """Set a field (directly or through an alias) or store a custom parameter."""
        name = self.alias.get(name, name)

        # ``model_fields`` is read from the class; instance access is
        # deprecated in recent pydantic releases.
        if name in type(self).model_fields:
            super().__setattr__(name, value)
        else:
            # Anything that is not a declared field is a custom parameter.
            self.param[name] = value

    def __contains__(self, key: str) -> bool:
        """Return True if ``key`` is a custom parameter or a readable attribute."""
        return key in self.param or hasattr(self, key)
46+
47+
48+
class ExperimentConfig(BaseConfig):
    """Settings that describe the experiment itself."""

    # Experiment label.
    name: Union[str, None] = None
    # A single patient id or a list of patient ids.
    patient: Union[List[int], int, None] = None
55+
56+
57+
class ModelConfig(BaseConfig):
    """Model and optimisation settings.

    ``learning_rate`` and ``learning_rate_drop`` can also be addressed by the
    short names ``lr`` and ``lr_drop``, both when constructing the model and
    when reading or assigning attributes.
    """

    # Accept the field name as well as the pydantic alias on input. Without
    # this, ``ModelConfig(learning_rate=...)`` is ignored, and a YAML file
    # written from ``model_dump()`` (which uses field names) loads back with
    # the default learning rate.
    model_config = {"populate_by_name": True}

    name: Optional[str] = None
    learning_rate: Optional[float] = Field(1e-4, alias="lr")
    learning_rate_drop: Optional[int] = Field(50, alias="lr_drop")
    batch_size: Optional[int] = 128
    epochs: Optional[int] = 100
    hidden_size: Optional[int] = 192
    num_hidden_layers: Optional[int] = 4
    num_attention_heads: Optional[int] = 6
    patch_size: Optional[Tuple[int, int]] = None

    # Attribute-level aliases resolved by BaseConfig.__getattr__/__setattr__.
    alias: Dict[str, str] = {
        "lr": "learning_rate",
        "lr_drop": "learning_rate_drop",
    }
72+
73+
74+
class DataConfig(BaseConfig):
    """Settings that describe the input data; every field defaults to None."""

    data_type: Union[str, None] = None
    sd: Union[float, None] = None
    # Both locations may be given as a plain string or as a Path.
    root_path: Union[str, Path, None] = None
    data_path: Union[str, Path, None] = None
79+
80+
81+
def _to_yaml_safe(value: Any) -> Any:
    """Return ``value`` with tuples turned into lists and paths into strings.

    ``yaml.safe_dump`` has no representer for ``tuple`` or ``pathlib.Path``
    and raises on them. Dicts, lists and tuples are walked recursively; all
    other values are returned unchanged.
    """
    if isinstance(value, dict):
        return {key: _to_yaml_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_yaml_safe(item) for item in value]
    if isinstance(value, Path):
        return str(value)
    return value


class PipelineConfig(BaseModel):
    """Top-level configuration made of experiment, model and data sections."""

    experiment: Optional[ExperimentConfig] = ExperimentConfig()
    model: Optional[ModelConfig] = ModelConfig()
    data: Optional[DataConfig] = DataConfig()

    @classmethod
    def read_config(cls, config_file: Union[str, Path]) -> "PipelineConfig":
        """Read a YAML configuration file and return a PipelineConfig.

        An empty file yields a configuration with all defaults.
        """
        with open(config_file, "r") as file:
            # safe_load returns None for an empty document.
            config_dict = yaml.safe_load(file) or {}
        return cls(**config_dict)

    def export_config(self, output_file: Union[str, Path] = "config.yaml") -> None:
        """Write the current configuration to a YAML file.

        Args:
            output_file: Target file. A path without a suffix is treated as a
                directory and ``config.yaml`` is used as the file name inside
                it. A tag made of the experiment name, model name, data type
                and a timestamp is inserted before the extension. Missing
                parent directories are created.
        """
        output_file = Path(output_file)

        if not output_file.suffix:
            output_file = output_file / "config.yaml"

        # Insert the tag between the stem and the extension.
        output_file = output_file.with_name(f"{output_file.stem}{self._file_tag}{output_file.suffix}")

        output_file.parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, "w") as file:
            # Tuples (e.g. patch_size) and Path values would make safe_dump fail.
            yaml.safe_dump(_to_yaml_safe(self.model_dump()), file)

    @property
    def _file_tag(self) -> str:
        """Tag appended to exported file names: section names plus a timestamp."""
        # Use only dashes in the timestamp; ':' is not allowed in Windows file names.
        formatted_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        return f"_{self.experiment.name}-{self.model.name}-{self.data.data_type}_{formatted_time}"
118+
119+
120+
if __name__ == "__main__":
    # Small demonstration of the config API.
    cfg = PipelineConfig()
    model_cfg = cfg.model
    experiment_cfg = cfg.experiment

    model_cfg.name = "vit"
    model_cfg.learning_rate = 0.001
    experiment_cfg.name = "movie-decoding"

    # Declared fields, read back through attribute access.
    print(f"Experiment Name: {experiment_cfg.name}")
    print(f"Patient ID: {experiment_cfg.patient}")
    print(f"Model Name: {model_cfg.name}")
    print(f"Learning Rate: {model_cfg.learning_rate}")
    print(f"Batch Size: {model_cfg.batch_size}")

    # The same field through its alias, by item and by attribute.
    print(f"Learning Rate (alias 'lr'): {model_cfg['lr']}")
    print(f"Learning Rate (alias 'lr'): {model_cfg.lr}")

    # Custom parameters can be added by item or by attribute assignment.
    model_cfg["new_param"] = "custom_value"
    print(f"Custom Parameter 'new_param': {model_cfg['new_param']}")
    model_cfg.new_param2 = "custom_value"
    print(f"Custom Parameter 'new_param2': {model_cfg.new_param2}")

    # Reading an unknown name raises AttributeError, which is printed here.
    try:
        missing_value = model_cfg.some_non_existent_field
    except AttributeError as err:
        print(err)
    else:
        print(missing_value)

    # Write the configuration using the default output file name.
    cfg.export_config()
src/movie_decoding/config/file_path.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pathlib import Path
2+
3+
ROOT_PATH = Path(__file__).resolve().parents[3]
4+
DATA_PATH = ROOT_PATH / "data"
5+
PATIENTS_FILE_PATH = ROOT_PATH / "data/patients"
6+
SURROGATE_FILE_PATH = ROOT_PATH / "data/surrogate_windows"
7+
CONFIG_FILE_PATH = ROOT_PATH / "config"
8+
RESULT_PATH = ROOT_PATH / "results"

0 commit comments

Comments
 (0)