Skip to content

Commit f1ce9ca

Browse files
committed
refactor check_free_recall.py
1 parent e2eb10a commit f1ce9ca

File tree

8 files changed

+505
-788
lines changed

8 files changed

+505
-788
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ statsmodels = "^0.14.2"
2424
matplotlib = "^3.9.1"
2525
seaborn = "^0.13.2"
2626
pyarrow = "^17.0.0"
27+
scipy = "^1.14.1"
2728

2829
[tool.poetry.dev-dependencies]
2930
black = "^23.0"

scripts/run_model.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import string
2+
from pathlib import Path
3+
4+
import wandb
5+
6+
from movie_decoding.main import pipeline
7+
from movie_decoding.utils.initializer import *
8+
9+
# for patient in ['562', '563', '566', 'i728', '567', '572']:
10+
patient_list = ["i728", "572", "567", "566", "563", "562"]
11+
sd_list = [4, 4, 3.5, 4, 4, 3.5]
12+
# data_list = ['notch CAR4.5', 'notch CAR3.5', 'notch CAR4.5', 'notch CAR4', 'notch CAR3.5', 'notch CAR3.5']
13+
data_list = [
14+
"notch CAR-quant-neg",
15+
"notch CAR-quant-neg",
16+
"notch CAR-quant-neg",
17+
"notch CAR-quant-neg",
18+
"notch CAR-quant-neg",
19+
"notch CAR-quant-neg",
20+
]
21+
early_stop = [100, 100, 100, 50, 50, 75]
22+
for patient, sd, dd in zip(patient_list, sd_list, data_list):
23+
print()
24+
print("start: ", patient)
25+
for data_type in ["clusterless"]:
26+
for run in range(5, 6):
27+
# root_path = os.path.dirname(os.path.abspath(__file__))
28+
root_path = Path(__file__).parent.parent
29+
# save the results
30+
letters = string.ascii_lowercase
31+
# suffix = ''.join(random.choice(letters) for i in range(3))
32+
suffix = f"test53_optimalX_CARX_{run}"
33+
if data_type == "clusterless":
34+
use_clusterless = True
35+
use_lfp = False
36+
use_combined = False
37+
model_architecture = "multi-vit" #'multi-vit'
38+
elif data_type == "lfp":
39+
use_clusterless = False
40+
use_lfp = True
41+
use_combined = False
42+
model_architecture = "multi-vit"
43+
elif data_type == "combined":
44+
use_clusterless = True
45+
use_lfp = True
46+
use_combined = True
47+
model_architecture = "multi-crossvit"
48+
else:
49+
ValueError(f"undefined data_type: {data_type}")
50+
51+
args = initialize_configs(architecture=model_architecture)
52+
args["seed"] = 42
53+
args["device"] = "cuda:1"
54+
args["patient"] = patient
55+
args["use_spike"] = use_clusterless
56+
args["use_lfp"] = use_lfp
57+
args["use_combined"] = use_combined
58+
args["use_spontaneous"] = False
59+
if use_clusterless:
60+
args["use_shuffle"] = True
61+
elif use_lfp:
62+
args["use_shuffle"] = False
63+
64+
args["use_bipolar"] = False
65+
args["use_sleep"] = False
66+
args["use_overlap"] = False
67+
args["model_architecture"] = model_architecture
68+
69+
args["spike_data_mode"] = dd
70+
args["spike_data_mode_inference"] = dd
71+
args["spike_data_sd"] = [sd]
72+
args["spike_data_sd_inference"] = sd
73+
args["use_augment"] = False
74+
args["use_long_input"] = False
75+
args["use_shuffle_diagnostic"] = False
76+
args["model_aggregate_type"] = "sum"
77+
78+
train_save_path = os.path.join(
79+
root_path,
80+
"results/8concepts/{}_{}_{}_{}/train".format(args["patient"], data_type, model_architecture, suffix),
81+
)
82+
valid_save_path = os.path.join(
83+
root_path,
84+
"results/8concepts/{}_{}_{}_{}/valid".format(args["patient"], data_type, model_architecture, suffix),
85+
)
86+
test_save_path = os.path.join(
87+
root_path,
88+
"results/8concepts/{}_{}_{}_{}/test".format(args["patient"], data_type, model_architecture, suffix),
89+
)
90+
memory_save_path = os.path.join(
91+
root_path,
92+
"results/8concepts/{}_{}_{}_{}/memory".format(args["patient"], data_type, model_architecture, suffix),
93+
)
94+
os.makedirs(train_save_path, exist_ok=True)
95+
os.makedirs(valid_save_path, exist_ok=True)
96+
os.makedirs(test_save_path, exist_ok=True)
97+
os.makedirs(memory_save_path, exist_ok=True)
98+
args["train_save_path"] = train_save_path
99+
args["valid_save_path"] = valid_save_path
100+
args["test_save_path"] = test_save_path
101+
args["memory_save_path"] = memory_save_path
102+
103+
os.environ["WANDB_MODE"] = "offline"
104+
# os.environ['WANDB_API_KEY'] = '5a6051ed615a193c44eb9f655b81703925460851'
105+
wandb.login()
106+
if use_lfp:
107+
run_name = "LFP Concept level {} MultiEncoder".format(args["patient"])
108+
else:
109+
run_name = "Clusterless Concept level {} MultiEncoder".format(args["patient"])
110+
wandb.init(project="24_Concepts", name=run_name, reinit=True, entity="24")
111+
112+
trainer = pipeline(args)
113+
114+
print("Start training")
115+
# start_time = time.time()
116+
117+
trainer.train(args["epochs"], 1)
118+
print("done: ", patient)
119+
print()
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from typing import Dict, List, Optional
2+
3+
import numpy as np
4+
from pydantic import BaseModel, Field
5+
6+
7+
# Define the Event model
8+
class Event(BaseModel):
9+
values: List[int] = Field(..., description="List of integer values for the event")
10+
11+
12+
# Define the Experiment model
13+
class Experiment(BaseModel):
14+
events: Dict[str, Event] = Field(default_factory=dict, description="Dictionary of events within the experiment")
15+
neural_data: Optional[np.ndarray] = Field(None, description="Neural recording data as a NumPy array")
16+
17+
def __getitem__(self, event_name: str) -> Event:
18+
if event_name not in self.events:
19+
self.events[event_name] = Event(values=[])
20+
return self.events[event_name]
21+
22+
def __setitem__(self, event_name: str, values: List[int]):
23+
self.events[event_name] = Event(values=values)
24+
25+
26+
# Define the Patient model
27+
class Patient(BaseModel):
28+
experiments: Dict[str, Experiment] = Field(
29+
default_factory=dict, description="Dictionary of experiments for the patient"
30+
)
31+
32+
def __getitem__(self, experiment_name: str) -> Experiment:
33+
if experiment_name not in self.experiments:
34+
self.experiments[experiment_name] = Experiment()
35+
return self.experiments[experiment_name]
36+
37+
def __setitem__(self, experiment_name: str, event_data: Dict[str, List[int]]):
38+
experiment = self.experiments.get(experiment_name, Experiment())
39+
for event_name, values in event_data.items():
40+
experiment[event_name] = values
41+
self.experiments[experiment_name] = experiment
42+
43+
44+
# Define the overall PatientsData model
45+
class PatientsData(BaseModel):
46+
patients: Dict[str, Patient] = Field(default_factory=dict, description="Dictionary of patients")
47+
48+
def __getitem__(self, patient_id: str) -> Patient:
49+
if patient_id not in self.patients:
50+
self.patients[patient_id] = Patient()
51+
return self.patients[patient_id]
52+
53+
def __setitem__(self, patient_id: str, experiment_data: Dict[str, Dict[str, List[int]]]):
54+
patient = self.patients.get(patient_id, Patient())
55+
for experiment_name, event_data in experiment_data.items():
56+
patient[experiment_name] = event_data
57+
self.patients[patient_id] = patient
58+
59+
60+
# Example usage within the module (can be removed or commented out for production use)
61+
if __name__ == "__main__":
62+
patients_data = PatientsData()
63+
64+
# Using direct assignment for events
65+
patients_data["567"]["free_recall1"]["LA"] = [1234, 23456]
66+
67+
# Adding neural data to the experiment
68+
patients_data["567"]["free_recall1"].neural_data = np.array([1.0, 2.0, 3.0])
69+
70+
print(patients_data["567"]["free_recall1"]["LA"].values)
71+
print(patients_data["567"]["free_recall1"].neural_data)

0 commit comments

Comments
 (0)