Commit eb95256

Author: dPys (committed)

[DOC] Update pynets_bids config templates to reflect the -em flag
1 parent: b55186b

pynets/stats/prediction.py

Lines changed: 78 additions & 74 deletions
@@ -775,7 +775,8 @@ def bootstrapped_nested_cv(X, y, n_boots=10, var_thr=.8, k_folds=10,
     return grand_mean_best_estimator, grand_mean_best_Rsquared, grand_mean_best_MSE, mega_feat_imp_dict


-def make_subject_dict(modalities, base_dir, thr_type, mets, embedding_types):
+def make_subject_dict(modalities, base_dir, thr_type, mets, embedding_types,
+                      template):
     from joblib import Parallel, delayed

     subject_dict = {}
@@ -821,7 +822,8 @@ def make_subject_dict(modalities, base_dir, thr_type, mets, embedding_types):

             outs = Parallel(n_jobs=-1)(
                 delayed(populate_subject_dict)(id, modality, grid,
-                                               subject_dict, alg, mets=mets,
+                                               subject_dict, alg, base_dir,
+                                               template, thr_type, mets=mets,
                                                df_top=df_top) for id in ids)
             for d in outs:
                 subject_dict.update(d)
@@ -830,8 +832,9 @@ def make_subject_dict(modalities, base_dir, thr_type, mets, embedding_types):
     return subject_dict, modality_grids


-def populate_subject_dict(id, modality, grid, subject_dict, alg, mets=None,
-                          df_top=None):
+def populate_subject_dict(id, modality, grid, subject_dict, alg, base_dir,
+                          template, thr_type, mets=None, df_top=None):
+
     def filter_cols_from_targets(df_top, targets):
         base = r'^{}'
         expr = '(?=.*{})'
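The three hunks above belong to one change: make_subject_dict now takes a template argument and forwards base_dir, template, and thr_type as positional arguments to populate_subject_dict through joblib. A minimal, runnable sketch of that dispatch pattern follows; populate_stub, the subject IDs, and the paths are hypothetical stand-ins, not pynets code.

from joblib import Parallel, delayed

# Hypothetical stand-in for populate_subject_dict; it only mirrors the new
# argument order (positional base_dir/template/thr_type ahead of the
# mets/df_top keywords) and records what each subject was run with.
def populate_stub(id, modality, grid, subject_dict, alg, base_dir,
                  template, thr_type, mets=None, df_top=None):
    return {id: (modality, alg, template, thr_type, mets)}

ids = ['sub-01', 'sub-02']          # hypothetical subject IDs
base_dir = '/tmp/outputs'           # hypothetical output directory
subject_dict = {}

outs = Parallel(n_jobs=-1)(
    delayed(populate_stub)(id, 'func', None, subject_dict, 'OMNI', base_dir,
                           'MNI152_T1', 'MST', mets=None, df_top=None)
    for id in ids)
for d in outs:
    subject_dict.update(d)
print(subject_dict)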
@@ -1292,7 +1295,9 @@ def _run_interface(self, runtime):
     return runtime


-def create_wf(base_dir, dict_file_path, modality_grids, drop_cols):
+def create_wf(base_dir, dict_file_path, modality_grids, drop_cols,
+              target_vars, embedding_types):
+
     ml_wf = pe.Workflow(name="ensemble_connectometry")
     ml_wf.base_dir = f"{base_dir}/pynets_ml"
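create_wf likewise gains target_vars and embedding_types. For orientation only, here is a hypothetical call consistent with the new signature, filled with the values assigned in the __main__ block further down (as committed, that block still calls create_wf with the original four arguments); the dict_file_path value is a placeholder, since the real one is produced by make_feature_space_dict.

from pynets.stats.prediction import create_wf

base_dir = '/working/tuning_set/outputs_shaeffer'
dict_file_path = f"{base_dir}/feature_space_dict.pkl"   # hypothetical placeholder
modality_grids = {}                                     # normally returned by make_subject_dict
drop_cols = ['rum_persist', 'dep_1', 'age', 'sex']
target_vars = ['rum_persist', 'dep_1', 'age']
embedding_types = ['topology']

ml_wf = create_wf(base_dir, dict_file_path, modality_grids, drop_cols,
                  target_vars, embedding_types)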

@@ -1420,73 +1425,72 @@ def create_wf(base_dir, dict_file_path, modality_grids, drop_cols):
     return ml_wf


-# if __name__ == "__main__":
-#     __spec__ = "ModuleSpec(name='builtins', loader=<class '_" \
-#                "frozen_importlib.BuiltinImporter'>)"
-#
-#     base_dir = '/working/tuning_set/outputs_shaeffer'
-#     df = pd.read_csv(
-#         '/working/tuning_set/outputs_shaeffer/df_rum_persist_all.csv',
-#         index_col=False)
-#
-#     # target_vars = ['rum_persist', 'dep_1', 'age']
-#     target_vars = ['rum_persist']
-#     thr_type = 'MST'
-#     drop_cols = ['rum_persist', 'dep_1', 'age', 'sex']
-#     # embedding_types = ['OMNI', 'ASE']
-#     embedding_types = ['OMNI']
-#     modalities = ['func', 'dwi']
-#     template = 'MNI152_T1'
-#     mets = ["global_efficiency", "average_clustering",
-#             "average_shortest_path_length", "average_betweenness_centrality",
-#             "average_eigenvector_centrality", "average_degree_centrality",
-#             "average_diversity_coefficient",
-#             "average_participation_coefficient"]
-#
-#     hyperparams_func = ["rsn", "res", "model", 'hpass', 'extract', 'smooth']
-#     hyperparams_dwi = ["rsn", "res", "model", 'directget', 'minlength']
-#
-#     ses = 1
-#
-#     subject_dict, modality_grids = make_subject_dict(modalities, base_dir,
-#                                                      thr_type)
-#     sub_dict_clean = cleanNullTerms(subject_dict)
-#
-#     subject_dict_file_path = f"{base_dir}/pynets_subject_dict.pkl"
-#     with open(subject_dict_file_path, 'wb') as f:
-#         pickle.dump(sub_dict_clean, f, protocol=2)
-#     f.close()
-#
-#     # Subset only those participants which have usable data
-#     df = df[df['participant_id'].isin(list(subject_dict.keys()))]
-#     df = df[['participant_id', 'rum_persist', 'dep_1', 'age', 'sex']]
-#
-#     dict_file_path = make_feature_space_dict(df, modalities, subject_dict,
-#                                              ses, base_dir)
-#
-#     ml_wf = create_wf(base_dir, dict_file_path, modality_grids, drop_cols)
-#
-#     execution_dict = {}
-#     execution_dict["crashdump_dir"] = str(ml_wf.base_dir)
-#     execution_dict["poll_sleep_duration"] = 1
-#     execution_dict["crashfile_format"] = 'txt'
-#     execution_dict['local_hash_check'] = False
-#     execution_dict['hash_method'] = 'timestamp'
-#
-#     cfg = dict(execution=execution_dict)
-#
-#     for key in cfg.keys():
-#         for setting, value in cfg[key].items():
-#             ml_wf.config[key][setting] = value
-#
-#     nthreads = psutil.cpu_count()
-#     procmem = [int(nthreads),
-#                int(list(psutil.virtual_memory())[4]/1000000000) - 2]
-#     plugin_args = {
-#         "n_procs": int(procmem[0]),
-#         "memory_gb": int(procmem[1]),
-#         "scheduler": "mem_thread",
-#     }
-#     # out = ml_wf.run(plugin='MultiProc', plugin_args=plugin_args)
-#     out = ml_wf.run(plugin='Linear', plugin_args=plugin_args)
+if __name__ == "__main__":
+    __spec__ = "ModuleSpec(name='builtins', loader=<class '_" \
+               "frozen_importlib.BuiltinImporter'>)"
+
+    base_dir = '/working/tuning_set/outputs_shaeffer'
+    df = pd.read_csv(
+        '/working/tuning_set/outputs_shaeffer/df_rum_persist_all.csv',
+        index_col=False)
+
+    target_vars = ['rum_persist', 'dep_1', 'age']
+    # target_vars = ['rum_persist']
+    thr_type = 'MST'
+    drop_cols = ['rum_persist', 'dep_1', 'age', 'sex']
+    # embedding_types = ['OMNI', 'ASE']
+    embedding_types = ['topology']
+    modalities = ['func', 'dwi']
+    template = 'MNI152_T1'
+    mets = ["global_efficiency", "average_clustering",
+            "average_shortest_path_length", "average_betweenness_centrality",
+            "average_eigenvector_centrality", "average_degree_centrality",
+            "average_diversity_coefficient",
+            "average_participation_coefficient"]
+
+    hyperparams_func = ["rsn", "res", "model", 'hpass', 'extract', 'smooth']
+    hyperparams_dwi = ["rsn", "res", "model", 'directget', 'minlength']
+
+    ses = 1
+
+    subject_dict, modality_grids = make_subject_dict(modalities, base_dir, thr_type, mets, embedding_types, template)
+    sub_dict_clean = cleanNullTerms(subject_dict)
+
+    subject_dict_file_path = f"{base_dir}/pynets_subject_dict.pkl"
+    with open(subject_dict_file_path, 'wb') as f:
+        pickle.dump(sub_dict_clean, f, protocol=2)
+    f.close()
+
+    # Subset only those participants which have usable data
+    df = df[df['participant_id'].isin(list(subject_dict.keys()))]
+    df = df[['participant_id', 'rum_persist', 'dep_1', 'age', 'sex']]
+
+    dict_file_path = make_feature_space_dict(df, modalities, subject_dict,
+                                             ses, base_dir)
+
+    ml_wf = create_wf(base_dir, dict_file_path, modality_grids, drop_cols)
+
+    execution_dict = {}
+    execution_dict["crashdump_dir"] = str(ml_wf.base_dir)
+    execution_dict["poll_sleep_duration"] = 1
+    execution_dict["crashfile_format"] = 'txt'
+    execution_dict['local_hash_check'] = False
+    execution_dict['hash_method'] = 'timestamp'
+
+    cfg = dict(execution=execution_dict)
+
+    for key in cfg.keys():
+        for setting, value in cfg[key].items():
+            ml_wf.config[key][setting] = value
+
+    nthreads = psutil.cpu_count()
+    procmem = [int(nthreads),
+               int(list(psutil.virtual_memory())[4]/1000000000) - 2]
+    plugin_args = {
+        "n_procs": int(procmem[0]),
+        "memory_gb": int(procmem[1]),
+        "scheduler": "mem_thread",
+    }
+    # out = ml_wf.run(plugin='MultiProc', plugin_args=plugin_args)
+    out = ml_wf.run(plugin='Linear', plugin_args=plugin_args)
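A side note on the resource estimate near the end of the new __main__ block: list(psutil.virtual_memory())[4] picks out the free field of psutil's memory tuple, so memory_gb works out to free system memory in GB minus a 2 GB head-room. A small sketch of the same arithmetic using the named attribute, assuming psutil is installed:

import psutil

# Same estimate as in the __main__ block: all logical CPUs, and free system
# memory converted to GB with 2 GB held back for the OS.
nthreads = psutil.cpu_count()
free_gb = psutil.virtual_memory().free / 1e9   # .free == list(psutil.virtual_memory())[4]
procmem = [int(nthreads), int(free_gb) - 2]
plugin_args = {
    "n_procs": int(procmem[0]),
    "memory_gb": int(procmem[1]),
    "scheduler": "mem_thread",
}
print(plugin_args)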
