Skip to content

Commit 52e7dcf

Browse files
Merge pull request #4 from mckenziephagen/save_bic
Save bic
2 parents 0df6c23 + 9f83da5 commit 52e7dcf

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*ipynb*
2+
*cache*
3+
*egg*

src/model_fc/models.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,27 @@
11
import numpy as np
2-
from mpi4py import MPI
32
from nilearn.connectome import ConnectivityMeasure
43
from pyuoi.linear_model import UoI_Lasso
5-
from sklearn.linear_model import ElasticNetCV, LassoCV, LassoLarsIC
4+
from sklearn.linear_model import ElasticNetCV, LassoCV, LassoLarsIC, RidgeCV
65
from sklearn.metrics import r2_score
76

87

98
def run_model(train_ts, test_ts, n_rois, model, **kwargs):
109
"""Calculate a model based functional connectivity matrix.
1110
12-
1311
train_ts: training timeseries
1412
test_ts: testing timeseries
1513
n_rois: number of rois in parcellation
1614
model: model object
1715
18-
1916
"""
2017
assert train_ts.shape[1] == n_rois == test_ts.shape[1]
21-
fc_mat = np.zeros((n_rois, n_rois))
18+
fc_mat = np.empty((n_rois, n_rois))
2219

23-
inner_rsq_dict = {"train": [], "test": []}
20+
results_dict = {}
2421

2522
for target_idx in range(train_ts.shape[1]):
23+
results_dict[f"node_{target_idx}"] = {}
24+
print(f"*****ECHO {target_idx}***********")
2625
y_train = np.array(train_ts[:, target_idx])
2726
X_train = np.delete(train_ts, target_idx, axis=1)
2827

@@ -31,13 +30,15 @@ def run_model(train_ts, test_ts, n_rois, model, **kwargs):
3130

3231
model.fit(X=X_train, y=y_train)
3332

34-
fc_mat[target_idx, :] = np.insert(model.coef_, target_idx, 0)
33+
fc_mat[target_idx, :] = np.insert(model.coef_, target_idx, 1)
3534
test_rsq, train_rsq = eval_metrics(X_train, y_train, X_test, y_test, model)
3635

37-
inner_rsq_dict["test"].append(test_rsq)
38-
inner_rsq_dict["train"].append(train_rsq)
36+
results_dict[f"node_{target_idx}"]["train_r2"] = train_rsq
37+
results_dict[f"node_{target_idx}"]["test_r2"] = test_rsq
38+
39+
results_dict["fc_matrix"] = fc_mat
3940

40-
return (fc_mat, inner_rsq_dict, model)
41+
return results_dict
4142

4243

4344
def eval_metrics(X_train, y_train, X_test, y_test, model):
@@ -53,25 +54,19 @@ def init_model(
5354
model_str, max_iter, random_state, stability_selection=16, selection_frac=0.7
5455
):
5556
"""Initialize model object for FC calculations."""
56-
if model_str == "uoi-lasso":
57+
if model_str == "uoiLasso":
5758
uoi_lasso = UoI_Lasso(estimation_score="BIC")
58-
comm = MPI.COMM_WORLD
59-
6059
uoi_lasso.selection_frac = selection_frac
6160
uoi_lasso.stability_selection = stability_selection
6261
uoi_lasso.copy_X = True
6362
uoi_lasso.estimation_target = None
6463
uoi_lasso.logger = None
65-
uoi_lasso.warm_start = False
66-
uoi_lasso.comm = comm
67-
uoi_lasso.random_state = 1
64+
uoi_lasso.warm_start = True
6865
uoi_lasso.n_lambdas = 100
6966
uoi_lasso.max_iter = max_iter
70-
uoi_lasso.random_state = random_state
71-
7267
model = uoi_lasso
7368

74-
elif model_str == "lasso-cv":
69+
elif model_str == "lassoCV":
7570
lasso = LassoCV(
7671
fit_intercept=True,
7772
cv=5,
@@ -81,8 +76,12 @@ def init_model(
8176
)
8277

8378
model = lasso
79+
elif model_str == "ridgeCV":
80+
ridge = RidgeCV(fit_intercept=True)
81+
82+
model = ridge
8483

85-
elif model_str == "lasso-bic":
84+
elif model_str == "lassoBIC":
8685
lasso = LassoLarsIC(criterion="bic", fit_intercept=True, max_iter=max_iter)
8786

8887
model = lasso

0 commit comments

Comments
 (0)