diff --git a/nltools/__init__.py b/nltools/__init__.py index acf5af8d..7f5c4463 100644 --- a/nltools/__init__.py +++ b/nltools/__init__.py @@ -18,7 +18,7 @@ from .cross_validation import set_cv from .data import Brain_Data, Adjacency, Groupby, Design_Matrix, Design_Matrix_Series from .simulator import Simulator -from .prefs import MNI_Template, resolve_mni_path +from .prefs import MNI_Template from .version import __version__ from .mask import expand_mask, collapse_mask, create_sphere from .external import SRM, DetSRM diff --git a/nltools/data/brain_data.py b/nltools/data/brain_data.py index 274becd0..e99453b2 100644 --- a/nltools/data/brain_data.py +++ b/nltools/data/brain_data.py @@ -68,7 +68,7 @@ ) from nltools.stats import regress as regression from .adjacency import Adjacency -from nltools.prefs import MNI_Template, resolve_mni_path +from nltools.prefs import MNI_Template from nilearn.decoding import SearchLight from pathlib import Path import warnings @@ -82,7 +82,6 @@ class Brain_Data(object): - """ Brain_Data is a class to represent neuroimaging data in python as a vector rather than a 3-dimensional matrix.This makes it easier to perform data @@ -106,7 +105,7 @@ def __init__(self, data=None, Y=None, X=None, mask=None, **kwargs): # Setup default or specified nifti masker if mask is None: # Load default mask - self.mask = nib.load(resolve_mni_path(MNI_Template)["mask"]) + self.mask = nib.load(MNI_Template.mask) elif isinstance(mask, (str, Path)): self.mask = nib.load(str(mask)) elif isinstance(mask, nib.Nifti1Image): @@ -653,9 +652,8 @@ def plot( else: raise ValueError("anatomical is not a nibabel instance") else: - # anatomical = nib.load(resolve_mni_path(MNI_Template)['plot']) - anatomical = get_mni_from_img_resolution(self, img_type="plot") + anatomical = get_mni_from_img_resolution(self, img_type="plot") if self.data.ndim == 1: if axes is None: _, axes = plt.subplots(nrows=1, figsize=figsize) @@ -728,7 +726,6 @@ def iplot(self, threshold=0, surface=False, anatomical=None, **kwargs): else: raise ValueError("anatomical is not a nibabel instance") else: - # anatomical = nib.load(resolve_mni_path(MNI_Template)['brain']) anatomical = get_mni_from_img_resolution(self, img_type="brain") return plot_interactive_brain( self, threshold=threshold, surface=surface, anatomical=anatomical, **kwargs @@ -1269,9 +1266,9 @@ def predict(self, algorithm=None, cv_dict=None, plot=True, verbose=True, **kwarg self.data[test] ) else: - output["dist_from_hyperplane_xval"][ - test - ] = predictor_cv.decision_function(self.data[test]) + output["dist_from_hyperplane_xval"][test] = ( + predictor_cv.decision_function(self.data[test]) + ) if ( predictor_settings["algorithm"] == "svm" and predictor_cv.probability diff --git a/nltools/mask.py b/nltools/mask.py index 6ea42498..19432bfd 100644 --- a/nltools/mask.py +++ b/nltools/mask.py @@ -12,7 +12,7 @@ import os import nibabel as nib -from nltools.prefs import MNI_Template, resolve_mni_path +from nltools.prefs import MNI_Template import pandas as pd import numpy as np import warnings @@ -42,7 +42,7 @@ def create_sphere(coordinates, radius=5, mask=None): ) else: - mask = nib.load(resolve_mni_path(MNI_Template)["mask"]) + mask = nib.load(MNI_Template.mask) def sphere(r, p, mask): """create a sphere of given radius at some point p in the brain mask diff --git a/nltools/plotting.py b/nltools/plotting.py index ebbad826..5248d786 100644 --- a/nltools/plotting.py +++ b/nltools/plotting.py @@ -29,8 +29,8 @@ from numpy.fft import fft, fftfreq from nltools.stats import two_sample_permutation, one_sample_permutation from nilearn.plotting import plot_glass_brain, plot_stat_map, view_img, view_img_on_surf -from nltools.prefs import MNI_Template, resolve_mni_path -from nltools.utils import attempt_to_import +from nltools.prefs import MNI_Template +from nltools.utils import attempt_to_import, get_mni_from_img_resolution import warnings import sklearn import os @@ -220,7 +220,7 @@ def plot_t_brain( cut_coords=c, display_mode=v, cmap=cmap, - bg_img=resolve_mni_path(MNI_Template)["brain"], + bg_img=MNI_Template.brain, **kwargs, ) elif how == "glass": @@ -239,7 +239,7 @@ def plot_t_brain( cut_coords=c, display_mode=v, cmap=cmap, - bg_img=resolve_mni_path(MNI_Template)["brain"], + bg_img=MNI_Template.brain, **kwargs, ) del obj @@ -314,8 +314,7 @@ def plot_brain(objIn, how="full", thr_upper=None, thr_lower=None, save=False, ** cut_coords=c, display_mode=v, cmap=cmap, - bg_img=resolve_mni_path(MNI_Template)["brain"], - **kwargs, + bg_img=get_mni_from_img_resolution(obj, img_type="brain") ** kwargs, ) if save: plt.savefig(savefile, bbox_inches="tight") @@ -337,8 +336,7 @@ def plot_brain(objIn, how="full", thr_upper=None, thr_lower=None, save=False, ** cut_coords=c, display_mode=v, cmap=cmap, - bg_img=resolve_mni_path(MNI_Template)["brain"], - **kwargs, + bg_img=get_mni_from_img_resolution(obj, img_type="brain") ** kwargs, ) if save: plt.savefig(savefile, bbox_inches="tight") diff --git a/nltools/prefs.py b/nltools/prefs.py index c8bdedd5..1fe72797 100644 --- a/nltools/prefs.py +++ b/nltools/prefs.py @@ -1,11 +1,10 @@ import os -from nltools.utils import get_resource_path -__all__ = ["MNI_Template", "resolve_mni_path"] +__all__ = ["MNI_Template"] -class MNI_Template_Factory(dict): - """Class to build the default MNI_Template dictionary. This should never be used +class MNI_Template_Factory(object): + """Class to build the default MNI_Template instance. This should never be used directly, instead just `from nltools.prefs import MNI_Template` and update that object's attributes to change MNI templates.""" @@ -13,25 +12,45 @@ def __init__( self, resolution="2mm", mask_type="with_ventricles", - mask=os.path.join(get_resource_path(), "MNI152_T1_2mm_brain_mask.nii.gz"), - plot=os.path.join(get_resource_path(), "MNI152_T1_2mm.nii.gz"), - brain=os.path.join(get_resource_path(), "MNI152_T1_2mm_brain.nii.gz"), + mni_version="nonlin_6thgen", ): + self._supported_resolutions = ["2mm", "3mm"] + self._supported_mni_versions = ["nonlin_6thgen", "2009a", "2009c"] + # Only applies to nonlin_6thgen + self._supported_mask_types = ["with_ventricles", "no_ventricles"] + self._resolution = resolution self._mask_type = mask_type - self._mask = mask - self._plot = plot - self._brain = brain - - self.update( - { - "resolution": self.resolution, - "mask_type": self.mask_type, - "mask": self.mask, - "plot": self.plot, - "brain": self.brain, - } - ) + self._mni_version = mni_version + + # Auto-populated (derive) mask, brain, plot + # This also always called on attribute access so the latest paths + # are resolved and can be safely used like nib.load(MNI_Template.mask) + # after updating an attribute, e.g. MNI_Template.resolution = 3 + # Avoids having to do nib.load(resolve_mni_path(MNI_Template).mask) like + # we were previously + self.resolve_paths() + + def __repr__(self) -> str: + if self.mni_version == "nonlin_6thgen": + return f"Current global template:\nresolution={self.resolution}\nmni_version={self.mni_version}\nmask_type={self.mask_type}\nmask={self.mask}\nbrain={self.brain}\nplot={self.plot}" + else: + return f"Current global template:\nresolution={self.resolution}\nmni_version={self.mni_version}\nmask={self.mask}\nbrain={self.brain}\nplot={self.plot}" + + @property + def mask(self): + self.resolve_paths() + return self._mask + + @property + def plot(self): + self.resolve_paths() + return self._plot + + @property + def brain(self): + self.resolve_paths() + return self._brain @property def resolution(self): @@ -41,12 +60,12 @@ def resolution(self): def resolution(self, resolution): if isinstance(resolution, (int, float)): resolution = f"{int(resolution)}mm" - if resolution not in ["2mm", "3mm"]: + if resolution not in self._supported_resolutions: raise NotImplementedError( - "Only 2mm and 3mm resolutions are currently supported" + f"Nltools currently supports the following MNI template resolutions: {self._supported_resolutions}" ) self._resolution = resolution - self.update({"resolution": self._resolution}) + self.resolve_paths() @property def mask_type(self): @@ -54,95 +73,59 @@ def mask_type(self): @mask_type.setter def mask_type(self, mask_type): - if mask_type not in ["with_ventricles", "no_ventricles"]: + if mask_type not in self._supported_mask_types: raise NotImplementedError( - "Only 'with_ventricles' and 'no_ventricles' mask_types are currently supported" + f"Nltools currently supports the following MNI mask_types (only applies to nonlin_6thgen): {self._supported_mask_types}" ) self._mask_type = mask_type - self.update({"mask_type": self._mask_type}) - - @property - def mask(self): - return self._mask - - @mask.setter - def mask(self, mask): - self._mask = mask - self.update({"mask": self._mask}) + self.resolve_paths() @property - def plot(self): - return self._plot + def mni_version(self): + return self._mni_version - @plot.setter - def plot(self, plot): - self._plot = plot - self.update({"plot": self._plot}) - - @property - def brain(self): - return self._brain - - @brain.setter - def brain(self, brain): - self._brain = brain - self.update({"brain": self._brain}) + @mni_version.setter + def mni_version(self, mni_version): + if mni_version not in self._supported_mni_versions: + raise ValueError( + f"Nltools currently supports the following MNI template versions: {self._supported_mni_versions}" + ) + self._mni_version = mni_version + self.resolve_paths() + + def resolve_paths(self): + base_path = os.path.join(os.path.dirname(__file__), "resources") + os.path.sep + if self._mni_version == "nonlin_6thgen": + if self._resolution == "3mm": + if self._mask_type == "with_ventricles": + self._mask = os.path.join( + base_path, "MNI152_T1_3mm_brain_mask.nii.gz" + ) + elif self._mask_type == "no_ventricles": + self._mask = os.path.join( + base_path, + "MNI152_T1_3mm_brain_mask_no_ventricles.nii.gz", + ) + self._plot = os.path.join(base_path, "MNI152_T1_3mm.nii.gz") + self._brain = os.path.join(base_path, "MNI152_T1_3mm_brain.nii.gz") + elif self._resolution == "2mm": + if self._mask_type == "with_ventricles": + self._mask = os.path.join( + base_path, "MNI152_T1_2mm_brain_mask.nii.gz" + ) + elif self._mask_type == "no_ventricles": + self._mask = os.path.join( + base_path, + "MNI152_T1_2mm_brain_mask_no_ventricles.nii.gz", + ) + self._plot = os.path.join(base_path, "MNI152_T1_2mm.nii.gz") + self._brain = os.path.join(base_path, "MNI152_T1_2mm_brain.nii.gz") + elif self._mni_version == "2009c": + pass + elif self._mni_version == "2009a": + pass # NOTE: We export this from the module and expect users to interact with it instead of # the class constructor above MNI_Template = MNI_Template_Factory() - - -def resolve_mni_path(MNI_Template): - """Helper function to resolve MNI path based on MNI_Template prefs setting.""" - - res = MNI_Template["resolution"] - m = MNI_Template["mask_type"] - if not isinstance(res, str): - raise ValueError("resolution must be provided as a string!") - if not isinstance(m, str): - raise ValueError("mask_type must be provided as a string!") - - if res == "3mm": - if m == "with_ventricles": - MNI_Template["mask"] = os.path.join( - get_resource_path(), "MNI152_T1_3mm_brain_mask.nii.gz" - ) - elif m == "no_ventricles": - MNI_Template["mask"] = os.path.join( - get_resource_path(), "MNI152_T1_3mm_brain_mask_no_ventricles.nii.gz" - ) - else: - raise ValueError( - "Available mask_types are 'with_ventricles' or 'no_ventricles'" - ) - - MNI_Template["plot"] = os.path.join(get_resource_path(), "MNI152_T1_3mm.nii.gz") - - MNI_Template["brain"] = os.path.join( - get_resource_path(), "MNI152_T1_3mm_brain.nii.gz" - ) - - elif res == "2mm": - if m == "with_ventricles": - MNI_Template["mask"] = os.path.join( - get_resource_path(), "MNI152_T1_2mm_brain_mask.nii.gz" - ) - elif m == "no_ventricles": - MNI_Template["mask"] = os.path.join( - get_resource_path(), "MNI152_T1_2mm_brain_mask_no_ventricles.nii.gz" - ) - else: - raise ValueError( - "Available mask_types are 'with_ventricles' or 'no_ventricles'" - ) - - MNI_Template["plot"] = os.path.join(get_resource_path(), "MNI152_T1_2mm.nii.gz") - - MNI_Template["brain"] = os.path.join( - get_resource_path(), "MNI152_T1_2mm_brain.nii.gz" - ) - else: - raise ValueError("Available templates are '2mm' or '3mm'") - return MNI_Template diff --git a/nltools/resources/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz b/nltools/resources/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz new file mode 100644 index 00000000..2b53aaac Binary files /dev/null and b/nltools/resources/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz differ diff --git a/nltools/resources/tpl-MNI152NLin2009cAsym_res-01_desc-brain_T1w.nii.gz b/nltools/resources/tpl-MNI152NLin2009cAsym_res-01_desc-brain_T1w.nii.gz new file mode 100644 index 00000000..a318e996 Binary files /dev/null and b/nltools/resources/tpl-MNI152NLin2009cAsym_res-01_desc-brain_T1w.nii.gz differ diff --git a/nltools/resources/tpl-MNI152NLin2009cAsym_res-01_desc-brain_mask.nii.gz b/nltools/resources/tpl-MNI152NLin2009cAsym_res-01_desc-brain_mask.nii.gz new file mode 100644 index 00000000..4b9ecdb8 Binary files /dev/null and b/nltools/resources/tpl-MNI152NLin2009cAsym_res-01_desc-brain_mask.nii.gz differ diff --git a/nltools/resources/tpl-MNI152NLin2009cAsym_res-02_T1w.nii.gz b/nltools/resources/tpl-MNI152NLin2009cAsym_res-02_T1w.nii.gz new file mode 100644 index 00000000..653fd149 Binary files /dev/null and b/nltools/resources/tpl-MNI152NLin2009cAsym_res-02_T1w.nii.gz differ diff --git a/nltools/resources/tpl-MNI152NLin2009cAsym_res-02_desc-brain_T1w.nii.gz b/nltools/resources/tpl-MNI152NLin2009cAsym_res-02_desc-brain_T1w.nii.gz new file mode 100644 index 00000000..2bf51d47 Binary files /dev/null and b/nltools/resources/tpl-MNI152NLin2009cAsym_res-02_desc-brain_T1w.nii.gz differ diff --git a/nltools/resources/tpl-MNI152NLin2009cAsym_res-02_desc-brain_mask.nii.gz b/nltools/resources/tpl-MNI152NLin2009cAsym_res-02_desc-brain_mask.nii.gz new file mode 100644 index 00000000..cbf14add Binary files /dev/null and b/nltools/resources/tpl-MNI152NLin2009cAsym_res-02_desc-brain_mask.nii.gz differ diff --git a/nltools/simulator.py b/nltools/simulator.py index a2cf8be0..ee531428 100755 --- a/nltools/simulator.py +++ b/nltools/simulator.py @@ -20,7 +20,7 @@ from scipy.stats import multivariate_normal, binom, ttest_1samp from nltools.data import Brain_Data from nltools.stats import fdr, one_sample_permutation -from nltools.prefs import MNI_Template, resolve_mni_path +from nltools.prefs import MNI_Template import csv from copy import deepcopy from sklearn.utils import check_random_state @@ -39,7 +39,7 @@ def __init__( if isinstance(brain_mask, str): brain_mask = nib.load(brain_mask) elif brain_mask is None: - brain_mask = nib.load(resolve_mni_path(MNI_Template)["mask"]) + brain_mask = nib.load(MNI_Template.mask) elif ~isinstance(brain_mask, nib.nifti1.Nifti1Image): raise ValueError("brain_mask is not a string or a nibabel instance") self.brain_mask = brain_mask diff --git a/nltools/tests/test_brain_data.py b/nltools/tests/test_brain_data.py index c93dc794..6ac66f5c 100644 --- a/nltools/tests/test_brain_data.py +++ b/nltools/tests/test_brain_data.py @@ -8,6 +8,7 @@ from nltools.stats import threshold, align from nltools.mask import create_sphere, roi_to_brain from pathlib import Path +import matplotlib.pyplot as plt from nltools.prefs import MNI_Template @@ -24,13 +25,6 @@ def test_load(tmpdir): output_dir = str(tmpdir) dat = sim.create_data(y, sigma, reps=n_reps, output_dir=output_dir) - # if MNI_Template["resolution"] == '2mm': - # shape_3d = (91, 109, 91) - # shape_2d = (6, 238955) - # elif MNI_Template["resolution"] == '3mm': - # shape_3d = (60, 72, 60) - # shape_2d = (6, 71020) - y = pd.read_csv( os.path.join(str(tmpdir.join("y.csv"))), header=None, index_col=None ) @@ -73,9 +67,9 @@ def test_load(tmpdir): # case the mask argument takes precedence so we warn the user with pytest.warns(UserWarning): bb = Brain_Data( - os.path.join(tmpdir.join("test_write.h5")), mask=MNI_Template["mask"] + os.path.join(tmpdir.join("test_write.h5")), mask=MNI_Template.mask ) - assert bb.mask.get_filename() == MNI_Template["mask"] + assert bb.mask.get_filename() == MNI_Template.mask def test_shape(sim_brain_data): @@ -730,3 +724,16 @@ def test_load_legacy_h5(old_h5_brain, new_h5_brain, tmpdir): assert b_new.shape() == b_new_written.shape() assert np.allclose(b_new.data, b_new_written.data) new_file.unlink() + + +def test_plot(sim_brain_data): + # Plotting smoke tests + sim_brain_data.plot() + + # Can't plot 4d glass brain + with pytest.raises(ValueError): + sim_brain_data.plot(view="glass") + + sim_brain_data[0].plot(view="glass") + + plt.close("all") diff --git a/nltools/tests/test_prefs.py b/nltools/tests/test_prefs.py index 381f22f4..f31612e6 100644 --- a/nltools/tests/test_prefs.py +++ b/nltools/tests/test_prefs.py @@ -1,25 +1,48 @@ from nltools.prefs import MNI_Template from nltools.data import Brain_Data +import matplotlib.pyplot as plt import pytest -def test_change_mni_resolution(): +def test_change_mni_attribute(): # Defaults brain = Brain_Data() assert brain.mask.affine[1, 1] == 2.0 - assert MNI_Template["resolution"] == "2mm" + assert "2mm" in MNI_Template.brain - # -> 3mm - MNI_Template["resolution"] = "3mm" + # Change global -> 3mm + MNI_Template.resolution = "3mm" + # Now default is 3mm new_brain = Brain_Data() assert new_brain.mask.affine[1, 1] == 3.0 + assert "3mm" in MNI_Template.brain - # switch back and test attribute setting + # switch back MNI_Template.resolution = 2.0 # floats are cool - assert MNI_Template["resolution"] == "2mm" + assert MNI_Template.resolution == "2mm" + assert "2mm" in MNI_Template.brain + # Back to 2mm newer_brain = Brain_Data() assert newer_brain.mask.affine[1, 1] == 2.0 with pytest.raises(NotImplementedError): MNI_Template.resolution = 1 + + +def test_pref_and_plotting(sim_brain_data): + # Smoke tests to make sure updating templates doesn't cause plotting issues + # plot methods always refer to the resolution of the Brain_Data + # instance *itself* + + # Should have no effect as simulated data is in 2mm space + MNI_Template.resolution = "3mm" + sim_brain_data.plot() + + MNI_Template.resolution = "2mm" + sim_brain_data.plot() + + # But they do refer to the currently loaded MNI_Template to get the mni_version + # resolution + # TODO: A a test for using a different mni_version (e.g. 2009c) via MNI_Template and making suring plotting still works + plt.close("all") diff --git a/nltools/utils.py b/nltools/utils.py index 2a5e3e46..4f64888a 100644 --- a/nltools/utils.py +++ b/nltools/utils.py @@ -7,7 +7,6 @@ """ __all__ = [ "get_resource_path", - "get_anatomical", "set_algorithm", "attempt_to_import", "all_same", @@ -21,7 +20,6 @@ from os.path import dirname, join, sep as pathsep import nibabel as nib import importlib -import os from sklearn.pipeline import Pipeline from sklearn.utils import check_random_state from scipy.spatial.distance import squareform @@ -30,6 +28,8 @@ import collections from types import GeneratorType from h5py import File as h5File +from nltools.prefs import MNI_Template +import re def to_h5(obj, file_name, obj_type="brain_data", h5_compression="gzip"): @@ -70,13 +70,6 @@ def get_resource_path(): return join(dirname(__file__), "resources") + pathsep -def get_anatomical(): - """Get nltools default anatomical image. - DEPRECATED. See MNI_Template and resolve_mni_path from nltools.prefs - """ - return nib.load(os.path.join(get_resource_path(), "MNI152_T1_2mm.nii.gz")) - - def get_mni_from_img_resolution(brain, img_type="plot"): """ Get the path to the resolution MNI anatomical image that matches the resolution of a Brain_Data instance. Used by Brain_Data.plot() and .iplot() to set backgrounds appropriately. @@ -98,12 +91,11 @@ def get_mni_from_img_resolution(brain, img_type="plot"): "Voxels are not isometric and cannot be visualized in standard space" ) else: - dim = str(int(voxel_dims[0])) + "mm" + dim = f"_{str(int(voxel_dims[0]))}mm_" if img_type == "brain": - mni = f"MNI152_T1_{dim}_brain.nii.gz" + return re.sub(r"_[0-9]+mm_", dim, MNI_Template.brain) else: - mni = f"MNI152_T1_{dim}.nii.gz" - return os.path.join(get_resource_path(), mni) + return re.sub(r"_[0-9]+mm_", dim, MNI_Template.plot) def set_algorithm(algorithm, *args, **kwargs):