Skip to content

Commit 9f83da5

Browse files
add changes to hyak
1 parent 73f9022 commit 9f83da5

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

src/model_fc/models.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from mpi4py import MPI
32
from nilearn.connectome import ConnectivityMeasure
43
from pyuoi.linear_model import UoI_Lasso
54
from sklearn.linear_model import ElasticNetCV, LassoCV, LassoLarsIC, RidgeCV
@@ -14,7 +13,6 @@ def run_model(train_ts, test_ts, n_rois, model, **kwargs):
1413
n_rois: number of rois in parcellation
1514
model: model object
1615
17-
1816
"""
1917
assert train_ts.shape[1] == n_rois == test_ts.shape[1]
2018
fc_mat = np.empty((n_rois, n_rois))
@@ -23,7 +21,7 @@ def run_model(train_ts, test_ts, n_rois, model, **kwargs):
2321

2422
for target_idx in range(train_ts.shape[1]):
2523
results_dict[f"node_{target_idx}"] = {}
26-
24+
print(f"*****ECHO {target_idx}***********")
2725
y_train = np.array(train_ts[:, target_idx])
2826
X_train = np.delete(train_ts, target_idx, axis=1)
2927

@@ -35,7 +33,6 @@ def run_model(train_ts, test_ts, n_rois, model, **kwargs):
3533
fc_mat[target_idx, :] = np.insert(model.coef_, target_idx, 1)
3634
test_rsq, train_rsq = eval_metrics(X_train, y_train, X_test, y_test, model)
3735

38-
results_dict[f"node_{target_idx}"]["model"] = model
3936
results_dict[f"node_{target_idx}"]["train_r2"] = train_rsq
4037
results_dict[f"node_{target_idx}"]["test_r2"] = test_rsq
4138

@@ -57,25 +54,19 @@ def init_model(
5754
model_str, max_iter, random_state, stability_selection=16, selection_frac=0.7
5855
):
5956
"""Initialize model object for FC calculations."""
60-
if model_str == "uoi-lasso":
57+
if model_str == "uoiLasso":
6158
uoi_lasso = UoI_Lasso(estimation_score="BIC")
62-
comm = MPI.COMM_WORLD
63-
6459
uoi_lasso.selection_frac = selection_frac
6560
uoi_lasso.stability_selection = stability_selection
6661
uoi_lasso.copy_X = True
6762
uoi_lasso.estimation_target = None
6863
uoi_lasso.logger = None
69-
uoi_lasso.warm_start = False
70-
uoi_lasso.comm = comm
71-
uoi_lasso.random_state = 1
64+
uoi_lasso.warm_start = True
7265
uoi_lasso.n_lambdas = 100
7366
uoi_lasso.max_iter = max_iter
74-
uoi_lasso.random_state = random_state
75-
7667
model = uoi_lasso
7768

78-
elif model_str == "lasso-cv":
69+
elif model_str == "lassoCV":
7970
lasso = LassoCV(
8071
fit_intercept=True,
8172
cv=5,
@@ -85,12 +76,12 @@ def init_model(
8576
)
8677

8778
model = lasso
88-
elif model_str == "ridge-cv":
89-
ridge = RidgeCV(fit_intercept=True, max_iter=max_iter)
79+
elif model_str == "ridgeCV":
80+
ridge = RidgeCV(fit_intercept=True)
9081

9182
model = ridge
9283

93-
elif model_str == "lasso-bic":
84+
elif model_str == "lassoBIC":
9485
lasso = LassoLarsIC(criterion="bic", fit_intercept=True, max_iter=max_iter)
9586

9687
model = lasso

0 commit comments

Comments
 (0)