
Commit ae63fd0

[wip] replace Brain_Data.regress with nilearn and use new ResultsContainer class
1 parent faa825b commit ae63fd0

4 files changed: +111 additions, -50 deletions

nltools/data/__init__.py (2 additions, 0 deletions)

@@ -5,11 +5,13 @@
 from .brain_data import Brain_Data, Groupby
 from .adjacency import Adjacency
 from .design_matrix import Design_Matrix, Design_Matrix_Series
+from .results import ResultsContainer
 
 __all__ = [
     "Brain_Data",
     "Adjacency",
     "Groupby",
     "Design_Matrix",
     "Design_Matrix_Series",
+    "ResultsContainer",
 ]

nltools/data/brain_data.py (43 additions, 49 deletions)

@@ -10,6 +10,7 @@
 # Need to figure out how to speed up loading and resampling of data
 
 from nilearn.signal import clean
+from nilearn.glm.first_level import FirstLevelModel
 from scipy.stats import ttest_1samp, pearsonr, spearmanr
 from scipy.spatial.distance import cdist
 from scipy.stats import t as t_dist
@@ -63,7 +64,6 @@
     find_spikes,
     regress_permutation,
 )
-from nltools.stats import regress as regression
 from .adjacency import Adjacency
 from nltools.prefs import MNI_Template
 from nilearn.decoding import SearchLight
@@ -72,6 +72,9 @@
 from contextlib import redirect_stdout
 
 
+warnings.filterwarnings("ignore", category=UserWarning, module="nilearn")
+warnings.filterwarnings("ignore", category=RuntimeWarning, module="nilearn")
+
 # Optional dependencies
 nx = attempt_to_import("networkx", "nx")
 tables = attempt_to_import("tables")
@@ -879,26 +882,26 @@ def iplot(self, threshold=0, surface=False, anatomical=None, **kwargs):
             self, threshold=threshold, surface=surface, anatomical=anatomical, **kwargs
         )
 
-    def regress(self, mode="ols", **kwargs):
-        """Run a mass-univariate regression across voxels. Three types of regressions can be run:
-        1) Standard OLS (default)
-        2) Robust OLS (heteroscedasticity and/or auto-correlation robust errors), i.e. OLS with "sandwich estimators"
-        3) ARMA (auto-regressive and moving-average lags = 1 by default; experimental)
+    def regress(self, noise_model="ols", **kwargs):
+        """Run a mass-univariate GLM analysis using the `Design_Matrix` supplied to `.X`.
 
-        For more information see the help for nltools.stats.regress
+        This is a wrapper around [`nilearn.glm.first_level.FirstLevelModel`](https://nilearn.github.io/stable/modules/generated/nilearn.glm.first_level.FirstLevelModel.html#nilearn.glm.first_level.FirstLevelModel), which you can reference for additional information about what `**kwargs` are supported.
 
-        ARMA notes: This experimental mode is similar to AFNI's 3dREMLFit but without spatial smoothing of voxel auto-correlation estimates. It can be **very computationally intensive** so parallelization is used by default to try to speed things up. Speed is limited because a unique ARMA model is fit to *each voxel* (like AFNI/FSL), unlike SPM, which assumes the same AR parameters (~0.2) at each voxel. While coefficient results are typically very similar to OLS, std-errors and so t-stats, dfs, and p-vals can differ greatly depending on how much auto-correlation is explaining the response in a voxel
-        relative to other regressors in the design matrix.
+        However, we override some defaults:
+        - no smoothing (use `.smooth()`)
+        - no scaling (use `.scale()`)
+        - no drift model (should already be in the `Design_Matrix` set to `.X`)
+        - OLS noise model (use `noise_model='ar1'` to switch, but it takes more time)
 
         Args:
-            mode (str): kind of model to fit; must be one of 'ols' (default), 'robust', or 'arma'
-            kwargs (dict): keyword arguments to nltools.stats.regress
+            noise_model (str, optional): temporal variance model. Defaults to "ols"
 
-        Returns:
-            out: dictionary of regression statistics in Brain_Data instances
-            {'beta','t','p','df','residual'}
 
+        Returns:
+            ResultsContainer: with keys for each convolved column of `.X` and values as `Brain_Data` objects of the GLM statistics
         """
+        # Avoid circular import
+        from .results import ResultsContainer
 
         if not isinstance(self.X, pd.DataFrame):
            raise ValueError("Make sure self.X is a pandas DataFrame.")
@@ -909,43 +912,34 @@ def regress(self, mode="ols", **kwargs):
         if self.data.shape[0] != self.X.shape[0]:
             raise ValueError("self.X does not match the correct size of self.data")
 
-        b, se, t, p, df, res = regression(self.X, self.data, mode=mode, **kwargs)
-
-        # Prevent copy of all data in self multiple times; instead start with an empty instance and copy only needed attributes from self, and use this as a template for other outputs
-        b_out = self.__class__()
-        b_out.mask = deepcopy(self.mask)
-        b_out.nifti_masker = deepcopy(self.nifti_masker)
-
-        # Use this as template for other outputs before setting data
-        se_out = b_out.copy()
-        t_out = b_out.copy()
-        p_out = b_out.copy()
-        df_out = b_out.copy()
-        res_out = b_out.copy()
-        (
-            b_out.data,
-            se_out.data,
-            t_out.data,
-            p_out.data,
-            df_out.data,
-            res_out.data,
-        ) = (
-            b,
-            se,
-            t,
-            p,
-            df,
-            res,
+        # Nilearn FirstLevelModel default overrides
+        smoothing_fwhm = kwargs.get("smoothing_fwhm", None)
+        drift_model = kwargs.get("drift_model", None)
+        signal_scaling = kwargs.get("signal_scaling", False)
+        stat_type = kwargs.get("stat_type", "t")
+        output_type = kwargs.get("output_type", "all")
+
+        # Run GLM
+        glm = FirstLevelModel(
+            t_r=1 / self.X.sampling_freq,
+            mask_img=self.mask,
+            smoothing_fwhm=smoothing_fwhm,
+            drift_model=drift_model,
+            signal_scaling=signal_scaling,
+            noise_model=noise_model,
         )
-
-        return {
-            "beta": b_out,
-            "t": t_out,
-            "p": p_out,
-            "df": df_out,
-            "sigma": se_out,
-            "residual": res_out,
+        glm.fit(self.to_nifti(), design_matrices=self.X)
+        self.glm = glm
+
+        # Assemble results
+        regressors_of_interest = self.X.convolved
+        results = {
+            r: ResultsContainer(
+                glm.compute_contrast(r, stat_type=stat_type, output_type=output_type)
+            )
+            for r in regressors_of_interest
        }
+        return results
 
     def randomise(
         self, n_permute=5000, threshold_dict=None, return_mask=False, **kwargs
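
For context, here is a minimal usage sketch of the reworked regress() API as it appears in this diff. The image path, the design matrix contents, and the regressor name "stim_c0" are hypothetical; only the call pattern, the noise_model argument, and the per-regressor ResultsContainer results come from the code above.

import numpy as np
import pandas as pd
from nltools.data import Brain_Data, Design_Matrix

data = Brain_Data("sub-01_task-stim_bold.nii.gz")  # hypothetical 4D functional image

# Hypothetical design matrix: one convolved regressor, TR = 2s (sampling_freq = 0.5 Hz)
dm = Design_Matrix(
    pd.DataFrame({"stim_c0": np.random.rand(len(data))}),
    sampling_freq=0.5,
    convolved=["stim_c0"],
)
data.X = dm

# noise_model defaults to "ols"; "ar1" uses nilearn's AR(1) model but is slower
results = data.regress(noise_model="ar1")
results["stim_c0"].beta  # Brain_Data map of the effect size for that regressor

The remaining attributes produced by the ResultsContainer key remapping are shown in the sketch after the results.py diff below.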

nltools/data/results.py (63 additions, 0 deletions)

@@ -0,0 +1,63 @@
+from typing import Dict, Any
+import nibabel as nib
+from nltools.data import Brain_Data
+
+
+class ResultsContainer(object):
+    """A generic container that dynamically creates attributes from a dictionary of string: Nifti1Image entries, where each attribute is a Brain_Data instance initialized from the corresponding Nifti image.
+
+    Args:
+        images_dict (Dict[str, nib.Nifti1Image]): Dictionary mapping attribute names to Nifti images.
+
+    Example:
+        >>> rc = ResultsContainer({'foo': img1, 'bar': img2})
+        >>> rc.foo  # Brain_Data instance
+        >>> rc.bar  # Brain_Data instance
+    """
+
+    def __init__(self, images=None):
+        self._is_single = True
+        self.data = []
+        if isinstance(images, dict):
+            for key, img in images.items():
+                if not isinstance(img, nib.Nifti1Image):
+                    raise TypeError(
+                        f"Value for key '{key}' is not a Nifti1Image, it's a {type(img)}."
+                    )
+                if key == "stat":
+                    new_key = "t"
+                elif key == "p_value":
+                    new_key = "p"
+                elif key == "effect_size":
+                    new_key = "beta"
+                elif key == "effect_variance":
+                    new_key = "se"
+                else:
+                    new_key = key
+                setattr(self, new_key, Brain_Data(img))
+        elif isinstance(images, list):
+            self._is_single = False
+            for img in images:
+                self.data.append(ResultsContainer(img))
+
+    def __repr__(self):
+        attr_names = [k for k in self.__dict__.keys()]
+        return f"ResultsContainer(attributes={attr_names})"
+
+    def append(self, image_dict):
+        self._is_single = False
+        # image_dict = (
+        #     ResultsContainer(image_dict)
+        #     if not isinstance(image_dict, ResultsContainer)
+        #     else image_dict
+        # )
+        self.data.append(image_dict)
+
+    def __getitem__(self, index):
+        return self.data[index]
+
+    def __len__(self):
+        return len(self.data)
+
+    def __iter__(self):
+        return iter(self.data)
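
Continuing the hypothetical sketch after the brain_data.py diff above (it assumes the `results` dict produced there), this shows how ResultsContainer renames the keys that nilearn's compute_contrast(..., output_type="all") returns, and how it doubles as a light list when containers are appended. The loop variable names are made up; the attribute mapping comes from the __init__ above.

from nltools.data import ResultsContainer

stats = results["stim_c0"]  # a single ResultsContainer from the earlier sketch
stats.t     # built from the "stat" image
stats.p     # built from the "p_value" image
stats.beta  # built from the "effect_size" image
stats.se    # built from the "effect_variance" image
# other keys (e.g. "z_score") keep their original names as attributes

# Stack per-regressor containers into a list-like collection
group = ResultsContainer([])
for name, sub in results.items():
    group.append(sub)
len(group)  # number of convolved regressors in .X
group[0]    # first appended ResultsContainer
for sub in group:
    print(sub)  # ResultsContainer(attributes=[...])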

nltools/file_reader.py (3 additions, 1 deletion)

@@ -86,7 +86,9 @@ def onsets_to_dm(
            polys = [c for c in dm.columns if "drift" in c or "constant" in c]
        else:
            convolved, polys = [], []
-        dm = Design_Matrix(dm, convolved=convolved, sampling_freq=1 / TR, polys=polys)
+        dm = Design_Matrix(
+            dm, convolved=convolved, sampling_freq=1 / TR, polys=polys
+        ).reset_index(drop=True)
         out.append(dm)
 
     return out if len(out) > 1 else out[0]
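
A minimal illustration (made-up values) of what the added .reset_index(drop=True) guards against: if the onsets for a run arrive with a non-sequential or carried-over index, the resulting Design_Matrix keeps it, and its rows no longer line up positionally with the TR-by-TR rows expected when the matrix is later assigned to Brain_Data.X.

import pandas as pd
from nltools.data import Design_Matrix

raw = pd.DataFrame({"stim": [0, 1, 0, 1]}, index=[4, 5, 6, 7])  # hypothetical stale index
dm = Design_Matrix(raw, sampling_freq=0.5, convolved=[], polys=[])
dm = dm.reset_index(drop=True)  # clean RangeIndex: one row per TR, 0..n-1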
