diff --git a/.travis.yml b/.travis.yml index 4261a18..9660f4e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -53,6 +53,7 @@ install: - source ${MNE_ROOT}/bin/mne_setup_sh; - conda install --yes --quiet $ENSURE_PACKAGES pandas$PANDAS scikit-learn$SKLEARN patsy h5py pillow; - pip install -q joblib nibabel; + - pip install -q scot==0.2.1 - if [ "${PYTHON}" == "3.5" ]; then conda install --yes --quiet $ENSURE_PACKAGES ipython; else diff --git a/examples/connectivity/plot_mne_inverse_label_connectivity.py b/examples/connectivity/plot_mne_inverse_label_connectivity.py new file mode 100644 index 0000000..3d99419 --- /dev/null +++ b/examples/connectivity/plot_mne_inverse_label_connectivity.py @@ -0,0 +1,164 @@ +""" +========================================================================= +Compute source space connectivity and visualize it using a circular graph +========================================================================= + +This example computes connectivity between 68 regions in source space based on +dSPM inverse solutions and a FreeSurfer cortical parcellation. All-to-all +functional and effective connectivity measures are obtained from two different +methods: non-parametric spectral estimates and multivariate autoregressive +(MVAR) models. The connectivity is visualized using a circular graph which is +ordered based on the locations of the regions. + +MVAR connectivity is computed with the Source Connectivity Toolbox (SCoT), see +http://scot-dev.github.io/scot-doc/index.html for details. +""" +# Authors: Martin Luessi +# Alexandre Gramfort +# Martin Billinger +# Nicolas P. Rougier (graph code borrowed from his matplotlib gallery) +# +# License: BSD (3-clause) + +import numpy as np +import matplotlib.pyplot as plt + +import mne +from mne.datasets import sample +from mne.minimum_norm import apply_inverse_epochs, read_inverse_operator +from mne.connectivity import spectral_connectivity +from mne_sandbox.connectivity import mvar_connectivity +from mne.viz import circular_layout +from mne_sandbox.viz import (plot_connectivity_circle, + plot_connectivity_inoutcircles) +from scot.connectivity_statistics import significance_fdr + +print(__doc__) + +data_path = sample.data_path() +subjects_dir = data_path + '/subjects' +fname_inv = data_path + '/MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif' +fname_raw = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif' +fname_event = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif' + +# Load data +inverse_operator = read_inverse_operator(fname_inv) +raw = mne.io.read_raw_fif(fname_raw) +events = mne.read_events(fname_event) + +# Add a bad channel +raw.info['bads'] += ['MEG 2443'] + +# Pick MEG channels +picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, + exclude='bads') + +# Define epochs for left-auditory condition +event_id, tmin, tmax = 1, -0.2, 0.5 +epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, + baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13, + eog=150e-6)) + +# Compute inverse solution and for each epoch. By using "return_generator=True" +# stcs will be a generator object instead of a list. +snr = 1.0 # use lower SNR for single epochs +lambda2 = 1.0 / snr ** 2 +method = "dSPM" # use dSPM method (could also be MNE or sLORETA) +stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, method, + pick_ori="normal", return_generator=True) + +# Get labels for FreeSurfer 'aparc' cortical parcellation with 34 labels/hemi +labels = mne.read_labels_from_annot('sample', parc='aparc', + subjects_dir=subjects_dir) +label_colors = [label.color for label in labels] + +# Average the source estimates within each label using sign-flips to reduce +# signal cancellations. We do not return a generator, because we want to use +# the estimates repeatedly. +src = inverse_operator['src'] +label_ts = mne.extract_label_time_course(stcs, labels, src, mode='mean_flip', + return_generator=False) + +# First, compute connectivity from spectral estimates in the alpha band. +fmin = 8. +fmax = 13. +sfreq = raw.info['sfreq'] # the sampling frequency +con_methods = ['wpli2_debiased', 'coh'] +con, freqs, times, n_epochs, n_tapers = spectral_connectivity( + label_ts, method=con_methods, mode='multitaper', sfreq=sfreq, fmin=fmin, + fmax=fmax, faverage=True, mt_adaptive=True, n_jobs=1) + +# con is a 3D array, get the connectivity for the first (and only) freq. band +# for each method +con_spec = dict() +for method, c in zip(con_methods, con): + con_spec[method] = c[:, :, 0] + +# Second, compute connectivity from multivariate autoregressive models. +mvar_methods = ['PDC', 'COH'] +con, freqs, order, p_vals = mvar_connectivity(label_ts, mvar_methods, + sfreq=sfreq, fmin=fmin, + fmax=fmax, ridge=10, + n_surrogates=100, n_jobs=1) + +# Get connectivity for the first frequency band. Set connectivity to 0 if not +# significant, while compensating for multiple testing by controlling the false +# discovery rate. +con_mvar = dict() +for method, c, p in zip(mvar_methods, con, p_vals): + con_mvar[method] = c[:, :, 0] * significance_fdr(p[:, :, 0], 0.01) + +# Now, we visualize the connectivity using a circular graph layout + +# First, we reorder the labels based on their location in the left hemi +label_names = [label.name for label in labels] + +lh_labels = [name for name in label_names if name.endswith('lh')] + +# Get the y-location of the label +label_ypos = list() +for name in lh_labels: + idx = label_names.index(name) + ypos = np.mean(labels[idx].pos[:, 1]) + label_ypos.append(ypos) + +# Reorder the labels based on their location +lh_labels = [label for (yp, label) in sorted(zip(label_ypos, lh_labels))] + +# For the right hemi +rh_labels = [label[:-2] + 'rh' for label in lh_labels] + +# Save the plot order and create a circular layout +node_order = list() +node_order.extend(lh_labels[::-1]) # reverse the order +node_order.extend(rh_labels) + +node_angles = circular_layout(label_names, node_order, start_pos=90, + group_boundaries=[0, len(label_names) / 2]) + +# Plot the graph using node colors from the FreeSurfer parcellation. We only +# show the 300 strongest connections. +plot_connectivity_circle(con_spec['wpli2_debiased'], label_names, n_lines=300, + node_angles=node_angles, node_colors=label_colors, + title='All-to-All Connectivity left-Auditory ' + 'Condition (WPLI^2, debiased)', show=False) +plt.savefig('circle.png', facecolor='black') + +# Compare coherence from both estimation methods +fig = plt.figure(num=None, figsize=(8, 4), facecolor='black') +for ii, (con, method) in enumerate(zip([con_spec['coh'], con_mvar['COH']], + ['Spectral', 'MVAR'])): + plot_connectivity_circle(con, label_names, n_lines=300, + node_angles=node_angles, node_colors=label_colors, + title=method, padding=0, fontsize_colorbar=6, + fig=fig, subplot=(1, 2, ii + 1), plot_names=False, + show=False) +plt.suptitle('All-to-all coherence', color='white', fontsize=14) + +# Show effective (directed) connectivity for one node +plot_connectivity_inoutcircles(con_mvar['PDC'], 'superiortemporal-lh', + label_names, node_angles=node_angles, padding=0, + node_colors=label_colors, plot_names=False, + title='Effective connectivity (PDC)', show=False) + +plt.show() diff --git a/mne_sandbox/connectivity/__init__.py b/mne_sandbox/connectivity/__init__.py new file mode 100644 index 0000000..0290e17 --- /dev/null +++ b/mne_sandbox/connectivity/__init__.py @@ -0,0 +1,4 @@ +""" Connectivity Analysis Tools +""" + +from .mvar import mvar_connectivity diff --git a/mne_sandbox/connectivity/mvar.py b/mne_sandbox/connectivity/mvar.py new file mode 100644 index 0000000..88e2af1 --- /dev/null +++ b/mne_sandbox/connectivity/mvar.py @@ -0,0 +1,277 @@ +# Authors: Martin Billinger +# +# License: BSD (3-clause) + +from __future__ import division +import numpy as np +import logging +from types import GeneratorType + +from scot.varbase import VARBase +from scot.var import VAR +from scot.connectivity import connectivity +from scot.connectivity_statistics import surrogate_connectivity +from scot.xvschema import make_nfold +from scot.utils import acm + +from mne.parallel import parallel_func +from mne.utils import logger, verbose + + +def _autocorrelations(epochs, max_lag): + return [acm(epochs, l) for l in range(max_lag + 1)] + + +def _get_n_epochs(epochs, n): + """Generator that returns lists with at most n epochs""" + epochs_out = [] + for e in epochs: + epochs_out.append(e) + if len(epochs_out) >= n: + yield epochs_out + epochs_out = [] + if len(epochs_out) > 0: + yield epochs_out + + +def _get_n_epochblocks(epochs, n_blocks, blocksize): + """Generator that returns lists with at most n_blocks of epochs""" + blocks_out = [] + for block in _get_n_epochs(epochs, blocksize): + blocks_out.append(block) + if len(blocks_out) >= n_blocks: + yield blocks_out + blocks_out = [] + if len(blocks_out) > 0: + yield blocks_out + + +def _fit_mvar_lsq(data, pmin, pmax, delta, n_jobs, verbose): + var = VAR(pmin, delta, xvschema=make_nfold(10)) + if pmin != pmax: + logger.info('MVAR order selection...') + var.optimize_order(data, pmin, pmax, n_jobs=n_jobs, verbose=verbose) + # todo: only convert to list if data is a generator + data = np.asarray(list(data)) + var.fit(data) + return var + + +def _fit_mvar_yw(data, pmin, pmax, n_jobs, blocksize, verbose=None): + if pmin != pmax: + raise NotImplementedError('Yule-Walker fitting does not support ' + 'automatic model order selection.') + order = pmin + + if not isinstance(data, GeneratorType): + blocksize = int(np.ceil(len(data) / n_jobs)) + + parallel, block_autocorrelations, _ = \ + parallel_func(_autocorrelations, n_jobs, + verbose=verbose) + n_blocks = 0 + for blocks in _get_n_epochblocks(data, n_jobs, blocksize): + acms = parallel(block_autocorrelations(block, order) + for block in blocks) + if n_blocks == 0: + acm_estimates = np.sum(acms, 0) + else: + acm_estimates += np.sum(acms, 0) + n_blocks += len(blocks) + acm_estimates /= n_blocks + + var = VARBase(order) + var.from_yw(acm_estimates) + + return var + + +@verbose +def mvar_connectivity(data, method, order=(1, None), fitting_mode='lsq', + ridge=0, sfreq=2 * np.pi, fmin=0, fmax=np.inf, n_fft=64, + n_surrogates=None, buffer_size=8, n_jobs=1, blocksize=10, + verbose=None): + """Estimate connectivity from multivariate autoregressive (MVAR) models. + + This function uses routines from SCoT [1] to fit MVAR models and compute + connectivity measures. + + Parameters + ---------- + data : array, shape=(n_epochs, n_signals, n_times) + or list/generator of array, shape =(n_signals, n_times) + The data from which to compute connectivity. + method : string | list of string + Connectivity measure(s) to compute. Supported measures: + 'COH' : coherence [2] + 'pCOH' : partial coherence [3] + 'PDC' : partial directed coherence [4] + 'PDCF' : partial directed coherence factor [4] + 'GPDC' : generalized partial directed coherence [5] + 'DTF' : directed transfer function [6] + 'ffDTF' : full-frequency directed transfer function [7] + 'dDTF' : "direct" directed transfer function [7] + 'GDTF' : generalized directed transfer function [5] + order : int | (int, int) + Order (length) of the underlying MVAR model. If order is a tuple + (p0, p1) of two ints, the function selects the best model order between + p0 and p1. p1 can be None, which causes the order selection to stop at + the lowest candidate. + fitting_mode : str + Determines how to fit the MVAR model. + 'lsq' : Least-Squares fitting + 'yw' : Solve Yule-Walker equations + Yule-Walker equations can utilize data generators, which makes them + more memory efficient than least-squares. However, yw-estimation may + fail if `order` or `n_signals` is too high for the amount of data + available. + ridge : float + Ridge-regression coefficient (l2 penalty) for least-squares fitting. + This parameter is ignored for Yule-Walker fitting. + sfreq : float + The sampling frequency. + fmin : float | tuple of floats + The lower frequency of interest. Multiple bands are defined using + a tuple, e.g., (8., 16.) for two bands with 8Hz and 16Hz lower freq. + fmax : float | tuple of floats + The upper frequency of interest. Multiple bands are defined using + a tuple, e.g. (12., 24.) for two band with 12Hz and 24Hz upper freq. + n_fft : int + Number of FFT bins to calculate. + n_surrogates : int | None + If set to None, no statistics are calculated. Otherwise, `surrogates` + is the number of surrogate datasets on which the chance level is + calculated. In this case the *p*-values are returned, which are related + to the probability that the observed connectivity is not caused by + chance. See scot.connectivity_statistics.surrogate_connectivity for + details on the procedure. + **Warning**: Correction for multiple testing is required if the + *p*-values are used as basis for significance testing. + buffer_size : int + Surrogates are calculated in `n_surrogates // buffer_size` blocks. + Lower buffer_size takes less memory but has more computational + overhead than higher buffer_size. + n_jobs : int + Number of jobs to run in parallel. This is used for model order + selection and statistics calculations. + blocksize : int + Epochs are prozessed in batches of size blocksize. For best performance + set blocksize so that `n_epochs == n_jobs * blocksize`. + verbose : bool, str, int, or None + If not None, override default verbose level (see mne.verbose). + + Returns + ------- + con : array | list of arrays + Computed connectivity measure(s). The shape of each array is + (n_signals, n_signals, n_frequencies) + freqs : array + Frequency points at which the connectivity was computed. + var_order : int + MVAR model order that was used for fitting the model. + p_values : array | list of arrays | None + *p*-values of connectivity measure(s). The shape of each array is + (n_signals, n_signals, n_frequencies). `p_values` is returned as None + if no statistics are calculated (i.e. `n_surrogates` evaluates to + False). + + References + ---------- + [1] M. Billinger, C.Brunner, G. R. Mueller-Putz. "SCoT: a Python toolbox + for EEG source connectivity", Frontiers in Neuroinformatics, 8:22, 2014 + + [2] P. L. Nunez, R. Srinivasan, A. F. Westdorp, R. S. Wijesinghe, + D. M. Tucker, R. B. Silverstein, P. J. Cadusch. EEG coherency: I: + statistics, reference electrode, volume conduction, Laplacians, + cortical imaging, and interpretation at multiple scales. Electroenceph. + Clin. Neurophysiol. 103(5): 499-515, 1997. + + [3] P. J. Franaszczuk, K. J. Blinowska, M. Kowalczyk. The application of + parametric multichannel spectral estimates in the study of electrical + brain activity. Biol. Cybernetics 51(4): 239-247, 1985. + + [4] L. A. Baccala, K. Sameshima. Partial directed coherence: a new concept + in neural structure determination. Biol. Cybernetics 84(6):463-474, + 2001. + + [5] L. Faes, S. Erla, G. Nollo. Measuring Connectivity in Linear + Multivariate Processes: Definitions, Interpretation, and Practical + Analysis. Comput. Math. Meth. Med. 2012:140513, 2012. + + [6] M. J. Kaminski, K. J. Blinowska. A new method of the description of the + information flow in the brain structures. Biol. Cybernetics 65(3): + 203-210, 1991. + + [7] A. Korzeniewska, M. Manczak, M. Kaminski, K. J. Blinowska, S. Kasicki. + Determination of information flow direction among brain structures by a + modified directed transfer function (dDTF) method. J. Neurosci. Meth. + 125(1-2): 195-207, 2003. + """ + scot_verbosity = 5 if logger.level <= logging.INFO else 0 + + if not isinstance(method, (list, tuple)): + method = [method] + + fmin = np.asarray((fmin,)).ravel() + fmax = np.asarray((fmax,)).ravel() + if len(fmin) != len(fmax): + raise ValueError('fmin and fmax must have the same length') + if np.any(fmin > fmax): + raise ValueError('fmax must be larger than fmin') + + try: + pmin, pmax = order[0], order[1] + except TypeError: + pmin, pmax = order, order + + logger.info('MVAR fitting...') + if fitting_mode == 'yw': + var = _fit_mvar_yw(data, pmin, pmax, n_jobs=n_jobs, + blocksize=blocksize, verbose=verbose) + elif fitting_mode == 'lsq': + var = _fit_mvar_lsq(data, pmin, pmax, ridge, n_jobs=n_jobs, + verbose=scot_verbosity) + else: + raise ValueError('Unknown fitting mode: %s' % fitting_mode) + + freqs, fmask = [], [] + freq_range = np.linspace(0, sfreq / 2, n_fft) + for fl, fh in zip(fmin, fmax): + fmask.append(np.logical_and(fl <= freq_range, freq_range <= fh)) + freqs.append(freq_range[fmask[-1]]) + + logger.info('Connectivity computation...') + results = [] + con = connectivity(method, var.coef, var.rescov, n_fft) + for mth in method: + bands = [np.mean(np.abs(con[mth][:, :, fm]), axis=2) for fm in fmask] + results.append(np.transpose(bands, (1, 2, 0))) + + if n_surrogates is not None and n_surrogates > 0: + logger.info('Computing connectivity statistics...') + data = np.asarray(list(data)).transpose([2, 1, 0]) + + n_blocks = n_surrogates // buffer_size + + p_vals = [] + # do them in junks, in order to save memory + for i in range(n_blocks): + scon = surrogate_connectivity(method, data, var, nfft=n_fft, + repeats=buffer_size, n_jobs=n_jobs, + verbose=scot_verbosity) + + for m, mth in enumerate(method): + c, sc = np.abs(con[mth]), np.abs(scon[mth]) + bands = [np.mean(c[:, :, fm], axis=-1) for fm in fmask] + sbands = [np.mean(sc[:, :, :, fm], axis=-1) for fm in fmask] + + p = [np.sum(bs >= b, axis=0) for b, bs in zip(bands, sbands)] + p = np.array(p).transpose(1, 2, 0) / (n_blocks * buffer_size) + if i == 0: + p_vals.append(p) + else: + p_vals[m] += p + else: + p_vals = None + + return results, freqs, var.p, p_vals diff --git a/mne_sandbox/connectivity/tests/test_mvar.py b/mne_sandbox/connectivity/tests/test_mvar.py new file mode 100644 index 0000000..0dd1b86 --- /dev/null +++ b/mne_sandbox/connectivity/tests/test_mvar.py @@ -0,0 +1,178 @@ +# Authors: Martin Billinger +# +# License: BSD (3-clause) + +import numpy as np +from numpy.testing import (assert_array_equal, assert_array_almost_equal, + assert_array_less) +from nose.tools import assert_raises, assert_equal +from copy import deepcopy + +from mne_sandbox.connectivity import mvar_connectivity +from mne_sandbox.connectivity.mvar import _fit_mvar_lsq, _fit_mvar_yw + + +def _make_data(var_coef, n_samples, n_epochs): + var_order = var_coef.shape[0] + n_signals = var_coef.shape[1] + + x = np.random.randn(n_signals, n_epochs * n_samples + 10 * var_order) + for i in range(var_order, x.shape[1]): + for k in range(var_order): + x[:, [i]] += np.dot(var_coef[k], x[:, [i - k - 1]]) + + x = x[:, -n_epochs * n_samples:] + + win = np.arange(0, n_samples) + return [x[:, i + win] for i in range(0, n_epochs * n_samples, n_samples)] + + +def _data_generator(data): + for d in data: + yield d + + +def test_mvar_connectivity(): + """Test MVAR connectivity estimation""" + # Use a case known to have no spurious correlations (it would bad if + # nosetests could randomly fail): + np.random.seed(0) + + n_sigs = 3 + n_epochs = 100 + n_samples = 500 + + # test invalid fmin fmax settings + assert_raises(ValueError, mvar_connectivity, [], 'S', 5, fmin=10, fmax=5) + assert_raises(ValueError, mvar_connectivity, [], 'DTF', 1, fmin=(0, 11), + fmax=(5, 10)) + assert_raises(ValueError, mvar_connectivity, [], 'PDC', 99, fmin=(11,), + fmax=(12, 15)) + assert_raises(ValueError, mvar_connectivity, [], 'S', fitting_mode='') + assert_raises(NotImplementedError, mvar_connectivity, [], 'H', + fitting_mode='yw') + + methods = ['S', 'COH', 'DTF', 'PDC', 'ffDTF', 'GPDC', 'GDTF', 'A'] + + # generate data without connectivity + var_coef = np.zeros((1, n_sigs, n_sigs)) + data = _make_data(var_coef, n_samples, n_epochs) + + con, freqs, p, p_vals = mvar_connectivity(data, methods, order=1, + fitting_mode='yw') + con = dict((m, c) for m, c in zip(methods, con)) + assert_equal(p, 1) + + assert_array_almost_equal(con['S'][:, :, 0], np.eye(n_sigs), decimal=2) + assert_array_almost_equal(con['COH'][:, :, 0], np.eye(n_sigs), decimal=2) + assert_array_almost_equal(con['COH'][:, :, 0].diagonal(), np.ones(n_sigs)) + assert_array_almost_equal(con['DTF'][:, :, 0], np.eye(n_sigs), decimal=2) + assert_array_almost_equal(con['PDC'][:, :, 0], np.eye(n_sigs), decimal=2) + assert_array_almost_equal(con['ffDTF'][:, :, 0] / np.sqrt(len(freqs[0])), + np.eye(n_sigs), decimal=2) + assert_array_almost_equal(con['GPDC'][:, :, 0], np.eye(n_sigs), decimal=2) + assert_array_almost_equal(con['GDTF'][:, :, 0], np.eye(n_sigs), decimal=2) + + # generate data with strong directed connectivity + f = 1e3 + var_coef = np.zeros((1, n_sigs, n_sigs)) + var_coef[:, 1, 0] = f + data = _make_data(var_coef, n_samples, n_epochs) + + con, freqs, p, p_vals = mvar_connectivity(data, methods, order=(2, 5)) + con = dict((m, c) for m, c in zip(methods, con)) + + h = var_coef.squeeze() + np.eye(n_sigs) + + assert_array_almost_equal(con['S'][:, :, 0] / f**2, np.dot(h, h.T) / f**2, + decimal=2) + assert_array_almost_equal(con['COH'][:, :, 0], np.dot(h, h.T) > 0, + decimal=2) + assert_array_almost_equal(con['DTF'][:, :, 0], + h / np.sum(h, 1, keepdims=True), decimal=2) + assert_array_almost_equal(con['ffDTF'][:, :, 0] / np.sqrt(len(freqs[0])), + h / np.sum(h, 1, keepdims=True), decimal=2) + assert_array_almost_equal(con['GDTF'][:, :, 0], + h / np.sum(h, 1, keepdims=True), decimal=2) + assert_array_almost_equal(con['PDC'][:, :, 0], + h / np.sum(h, 0, keepdims=True), decimal=2) + assert_array_almost_equal(con['GPDC'][:, :, 0], + h / np.sum(h, 0, keepdims=True), decimal=2) + + # generate data with strong cascaded directed connectivity + f = 1e3 + var_coef = np.zeros((1, n_sigs, n_sigs)) + var_coef[:, 1, 0] = f + var_coef[:, 2, 1] = f + data = _make_data(var_coef, n_samples, n_epochs) + + con, freqs, p, p_vals = mvar_connectivity(data, methods, order=(1, None)) + con = dict((m, c) for m, c in zip(methods, con)) + + assert_array_almost_equal(con['S'][:, :, 0] / f**4, [[f**-4, f**-3, f**-2], + [f**-3, f**-2, f**-1], + [f**-2, f**-1, f**0]], + decimal=2) + assert_array_almost_equal(con['COH'][:, :, 0], np.ones((n_sigs, n_sigs)), + decimal=2) + assert_array_almost_equal(con['DTF'][:, :, 0], [[1, 0, 0], + [1, 0, 0], + [1, 0, 0]], decimal=2) + assert_array_almost_equal(con['ffDTF'][:, :, 0] / np.sqrt(len(freqs[0])), + [[1, 0, 0], [1, 0, 0], [1, 0, 0]], decimal=2) + assert_array_almost_equal(con['GDTF'], con['DTF'], decimal=2) + + h = var_coef.squeeze() + np.eye(n_sigs) + assert_array_almost_equal(con['PDC'][:, :, 0], + h / np.sum(h, 0, keepdims=True), decimal=2) + assert_array_almost_equal(con['GPDC'], con['PDC'], decimal=2) + + # generate data with some directed connectivity + # check if statistics report only significant connectivity where the + # original coefficients were non-zero + var_coef = np.zeros((1, n_sigs, n_sigs)) + var_coef[:, 1, 0] = 1 + var_coef[:, 2, 1] = 1 + data = _make_data(var_coef, n_samples, n_epochs) + + con, freqs, p, p_vals = mvar_connectivity(data, 'PDC', order=(1, None), + n_surrogates=20) + + for i in range(n_sigs): + for j in range(n_sigs): + if var_coef[0, i, j] > 0: + assert_array_less(p_vals[0][i, j, 0], 0.05) + else: + assert_array_less(0.05, p_vals[0][i, j, 0]) + + +def test_fit_mvar(): + """Test MVAR model fitting""" + np.random.seed(42) + + n_sigs = 3 + n_epochs = 65 + n_samples = 200 + + var_coef = np.zeros((1, n_sigs, n_sigs)) + var_coef[0, :, :] = [[0.9, 0, 0], + [1, 0.5, 0], + [2, 0, -0.5]] + data = _make_data(var_coef, n_samples, n_epochs) + data0 = deepcopy(data) + + var = _fit_mvar_lsq(data, pmin=1, pmax=1, delta=0, n_jobs=1, verbose=0) + assert_array_equal(data, data0) + assert_array_almost_equal(var_coef[0], var.coef, decimal=2) + + var = _fit_mvar_yw(data, pmin=1, pmax=1, n_jobs=1, blocksize=7, verbose=0) + assert_array_equal(data, data0) + assert_array_almost_equal(var_coef[0], var.coef, decimal=2) + + data = _data_generator(data0) + var = _fit_mvar_lsq(data, pmin=1, pmax=1, delta=0, n_jobs=2, verbose=0) + assert_array_almost_equal(var_coef[0], var.coef, decimal=2) + + data = _data_generator(data0) + var = _fit_mvar_yw(data, pmin=1, pmax=1, n_jobs=2, blocksize=9, verbose=0) + assert_array_almost_equal(var_coef[0], var.coef, decimal=2) diff --git a/mne_sandbox/viz/__init__.py b/mne_sandbox/viz/__init__.py new file mode 100644 index 0000000..ee30e0e --- /dev/null +++ b/mne_sandbox/viz/__init__.py @@ -0,0 +1,6 @@ +"""Visualization routines +""" + + +from .connectivity import (plot_connectivity_circle, plot_connectivity_matrix, + plot_connectivity_inoutcircles) diff --git a/mne_sandbox/viz/connectivity.py b/mne_sandbox/viz/connectivity.py new file mode 100644 index 0000000..537df41 --- /dev/null +++ b/mne_sandbox/viz/connectivity.py @@ -0,0 +1,586 @@ +"""Functions to plot directed connectivity +""" +from __future__ import print_function + +# Authors: Martin Billinger +# +# License: Simplified BSD + + +from itertools import cycle +from functools import partial + +import numpy as np + +from mne.viz.utils import plt_show +from mne.externals.six import string_types +from mne.fixes import tril_indices, normalize_colors +from mne.viz.circle import _plot_connectivity_circle_onpick + + +# copied from mne.viz.circle to add optional `plot_names` argument +def plot_connectivity_circle(con, node_names, indices=None, n_lines=None, + node_angles=None, node_width=None, + node_colors=None, facecolor='black', + textcolor='white', node_edgecolor='black', + linewidth=1.5, colormap='hot', vmin=None, + vmax=None, colorbar=True, title=None, + colorbar_size=0.2, colorbar_pos=(-0.3, 0.1), + fontsize_title=12, fontsize_names=8, + fontsize_colorbar=8, padding=6., + fig=None, subplot=111, interactive=True, + node_linewidth=2., plot_names=True, show=True): + """Visualize connectivity as a circular graph. + + Note: This code is based on the circle graph example by Nicolas P. Rougier + http://www.labri.fr/perso/nrougier/coding/. + + Parameters + ---------- + con : array + Connectivity scores. Can be a square matrix, or a 1D array. If a 1D + array is provided, "indices" has to be used to define the connection + indices. + node_names : list of str + Node names. The order corresponds to the order in con. + indices : tuple of arrays | None + Two arrays with indices of connections for which the connections + strenghts are defined in con. Only needed if con is a 1D array. + n_lines : int | None + If not None, only the n_lines strongest connections (strength=abs(con)) + are drawn. + node_angles : array, shape=(len(node_names,)) | None + Array with node positions in degrees. If None, the nodes are equally + spaced on the circle. See mne.viz.circular_layout. + node_width : float | None + Width of each node in degrees. If None, the minimum angle between any + two nodes is used as the width. + node_colors : list of tuples | list of str + List with the color to use for each node. If fewer colors than nodes + are provided, the colors will be repeated. Any color supported by + matplotlib can be used, e.g., RGBA tuples, named colors. + facecolor : str + Color to use for background. See matplotlib.colors. + textcolor : str + Color to use for text. See matplotlib.colors. + node_edgecolor : str + Color to use for lines around nodes. See matplotlib.colors. + linewidth : float + Line width to use for connections. + colormap : str + Colormap to use for coloring the connections. + vmin : float | None + Minimum value for colormap. If None, it is determined automatically. + vmax : float | None + Maximum value for colormap. If None, it is determined automatically. + colorbar : bool + Display a colorbar or not. + title : str + The figure title. + colorbar_size : float + Size of the colorbar. + colorbar_pos : 2-tuple + Position of the colorbar. + fontsize_title : int + Font size to use for title. + fontsize_names : int + Font size to use for node names. + fontsize_colorbar : int + Font size to use for colorbar. + padding : float + Space to add around figure to accommodate long labels. + fig : None | instance of matplotlib.pyplot.Figure + The figure to use. If None, a new figure with the specified background + color will be created. + subplot : int | 3-tuple + Location of the subplot when creating figures with multiple plots. E.g. + 121 or (1, 2, 1) for 1 row, 2 columns, plot 1. See + matplotlib.pyplot.subplot. + interactive : bool + When enabled, left-click on a node to show only connections to that + node. Right-click shows all connections. + node_linewidth : float + Line with for nodes. + plot_names : bool + Draw node names if True. + show : bool + Show figure if True. + + Returns + ------- + fig : instance of matplotlib.pyplot.Figure + The figure handle. + axes : instance of matplotlib.axes.PolarAxesSubplot + The subplot handle. + """ + import matplotlib.pyplot as plt + import matplotlib.path as m_path + import matplotlib.patches as m_patches + + n_nodes = len(node_names) + + if node_angles is not None: + if len(node_angles) != n_nodes: + raise ValueError('node_angles has to be the same length ' + 'as node_names') + # convert it to radians + node_angles = node_angles * np.pi / 180 + else: + # uniform layout on unit circle + node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False) + + if node_width is None: + # widths correspond to the minimum angle between two nodes + dist_mat = node_angles[None, :] - node_angles[:, None] + dist_mat[np.diag_indices(n_nodes)] = 1e9 + node_width = np.min(np.abs(dist_mat)) + else: + node_width = node_width * np.pi / 180 + + if node_colors is not None: + if len(node_colors) < n_nodes: + node_colors = cycle(node_colors) + else: + # assign colors using colormap + node_colors = [plt.cm.spectral(i / float(n_nodes)) + for i in range(n_nodes)] + + # handle 1D and 2D connectivity information + if con.ndim == 1: + if indices is None: + raise ValueError('indices has to be provided if con.ndim == 1') + elif con.ndim == 2: + if con.shape[0] != n_nodes or con.shape[1] != n_nodes: + raise ValueError('con has to be 1D or a square matrix') + # we use the lower-triangular part + indices = tril_indices(n_nodes, -1) + con = con[indices] + else: + raise ValueError('con has to be 1D or a square matrix') + + # get the colormap + if isinstance(colormap, string_types): + colormap = plt.get_cmap(colormap) + + # Make figure background the same colors as axes + if fig is None: + fig = plt.figure(figsize=(8, 8), facecolor=facecolor) + + # Use a polar axes + if not isinstance(subplot, tuple): + subplot = (subplot,) + axes = plt.subplot(*subplot, polar=True, axisbg=facecolor) + + # No ticks, we'll put our own + plt.xticks([]) + plt.yticks([]) + + # Set y axes limit, add additonal space if requested + plt.ylim(0, 10 + padding) + + # Remove the black axes border which may obscure the labels + axes.spines['polar'].set_visible(False) + + # Draw lines between connected nodes, only draw the strongest connections + if n_lines is not None and len(con) > n_lines: + con_thresh = np.sort(np.abs(con).ravel())[-n_lines] + else: + con_thresh = 0. + + # get the connections which we are drawing and sort by connection strength + # this will allow us to draw the strongest connections first + con_abs = np.abs(con) + con_draw_idx = np.where(con_abs >= con_thresh)[0] + + con = con[con_draw_idx] + con_abs = con_abs[con_draw_idx] + indices = [ind[con_draw_idx] for ind in indices] + + # now sort them + sort_idx = np.argsort(con_abs) + con_abs = con_abs[sort_idx] + con = con[sort_idx] + indices = [ind[sort_idx] for ind in indices] + + # Get vmin vmax for color scaling + if vmin is None: + vmin = np.min(con[np.abs(con) >= con_thresh]) + if vmax is None: + vmax = np.max(con) + vrange = vmax - vmin + + # We want to add some "noise" to the start and end position of the + # edges: We modulate the noise with the number of connections of the + # node and the connection strength, such that the strongest connections + # are closer to the node center + nodes_n_con = np.zeros((n_nodes), dtype=np.int) + for i, j in zip(indices[0], indices[1]): + nodes_n_con[i] += 1 + nodes_n_con[j] += 1 + + # initalize random number generator so plot is reproducible + rng = np.random.mtrand.RandomState(seed=0) + + n_con = len(indices[0]) + noise_max = 0.25 * node_width + start_noise = rng.uniform(-noise_max, noise_max, n_con) + end_noise = rng.uniform(-noise_max, noise_max, n_con) + + nodes_n_con_seen = np.zeros_like(nodes_n_con) + for i, (start, end) in enumerate(zip(indices[0], indices[1])): + nodes_n_con_seen[start] += 1 + nodes_n_con_seen[end] += 1 + + start_noise[i] *= ((nodes_n_con[start] - nodes_n_con_seen[start]) / + float(nodes_n_con[start])) + end_noise[i] *= ((nodes_n_con[end] - nodes_n_con_seen[end]) / + float(nodes_n_con[end])) + + # scale connectivity for colormap (vmin<=>0, vmax<=>1) + con_val_scaled = (con - vmin) / vrange + + # Finally, we draw the connections + for pos, (i, j) in enumerate(zip(indices[0], indices[1])): + # Start point + t0, r0 = node_angles[i], 10 + + # End point + t1, r1 = node_angles[j], 10 + + # Some noise in start and end point + t0 += start_noise[pos] + t1 += end_noise[pos] + + verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)] + codes = [m_path.Path.MOVETO, m_path.Path.CURVE4, m_path.Path.CURVE4, + m_path.Path.LINETO] + path = m_path.Path(verts, codes) + + color = colormap(con_val_scaled[pos]) + + # Actual line + patch = m_patches.PathPatch(path, fill=False, edgecolor=color, + linewidth=linewidth, alpha=1.) + axes.add_patch(patch) + + # Draw ring with colored nodes + height = np.ones(n_nodes) * 1.0 + bars = axes.bar(node_angles, height, width=node_width, bottom=9, + edgecolor=node_edgecolor, lw=node_linewidth, + facecolor='.9', align='center') + + for bar, color in zip(bars, node_colors): + bar.set_facecolor(color) + + # Draw node labels + if plot_names: + angles_deg = 180 * node_angles / np.pi + for name, angle_rad, angle_deg in zip(node_names, node_angles, + angles_deg): + if angle_deg >= 270: + ha = 'left' + else: + # Flip the label, so text is always upright + angle_deg += 180 + ha = 'right' + + axes.text(angle_rad, 10.4, name, size=fontsize_names, + rotation=angle_deg, rotation_mode='anchor', + horizontalalignment=ha, verticalalignment='center', + color=textcolor) + + if title is not None: + plt.title(title, color=textcolor, fontsize=fontsize_title, + axes=axes) + + if colorbar: + norm = normalize_colors(vmin=vmin, vmax=vmax) + sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm) + sm.set_array(np.linspace(vmin, vmax)) + cb = plt.colorbar(sm, ax=axes, use_gridspec=False, + shrink=colorbar_size, + anchor=colorbar_pos) + cb_yticks = plt.getp(cb.ax.axes, 'yticklabels') + cb.ax.tick_params(labelsize=fontsize_colorbar) + plt.setp(cb_yticks, color=textcolor) + + # Add callback for interaction + if interactive: + callback = partial(_plot_connectivity_circle_onpick, fig=fig, + axes=axes, indices=indices, n_nodes=n_nodes, + node_angles=node_angles) + + fig.canvas.mpl_connect('button_press_event', callback) + + plt_show(show) + return fig, axes + + +def _plot_connectivity_matrix_nodename(x, y, con, node_names): + x = int(round(x) - 2) + y = int(round(y) - 2) + if x < 0 or y < 0 or x >= len(node_names) or y >= len(node_names): + return '' + return '%s --> %s: %.2f' % (node_names[x], node_names[y], + con[y + 2, x + 2]) + + +def plot_connectivity_matrix(con, node_names, indices=None, + node_colors=None, facecolor='black', + textcolor='white', colormap='hot', vmin=None, + vmax=None, colorbar=True, title=None, + colorbar_size=0.2, colorbar_pos=(-0.3, 0.1), + fontsize_title=12, fontsize_names=8, + fontsize_colorbar=8, fig=None, subplot=111, + show_names=True): + """Visualize connectivity as a matrix. + + Parameters + ---------- + con : array + Connectivity scores. Can be a square matrix, or a 1D array. If a 1D + array is provided, "indices" has to be used to define the connection + indices. + node_names : list of str + Node names. The order corresponds to the order in con. + indices : tuple of arrays | None + Two arrays with indices of connections for which the connections + strenghts are defined in con. Only needed if con is a 1D array. + node_colors : list of tuples | list of str + List with the color to use for each node. If fewer colors than nodes + are provided, the colors will be repeated. Any color supported by + matplotlib can be used, e.g., RGBA tuples, named colors. + facecolor : str + Color to use for background. See matplotlib.colors. + textcolor : str + Color to use for text. See matplotlib.colors. + colormap : str + Colormap to use for coloring the connections. + vmin : float | None + Minimum value for colormap. If None, it is determined automatically. + vmax : float | None + Maximum value for colormap. If None, it is determined automatically. + colorbar : bool + Display a colorbar or not. + title : str + The figure title. + colorbar_size : float + Size of the colorbar. + colorbar_pos : 2-tuple + Position of the colorbar. + fontsize_title : int + Font size to use for title. + fontsize_names : int + Font size to use for node names. + fontsize_colorbar : int + Font size to use for colorbar. + padding : float + Space to add around figure to accommodate long labels. + fig : None | instance of matplotlib.pyplot.Figure + The figure to use. If None, a new figure with the specified background + color will be created. + subplot : int | 3-tuple + Location of the subplot when creating figures with multiple plots. E.g. + 121 or (1, 2, 1) for 1 row, 2 columns, plot 1. See + matplotlib.pyplot.subplot. + show_names : bool + Enable or disable display of node names in the plot. The names are + always displayed in the status bar when hovering over them. + + Returns + ------- + fig : instance of matplotlib.pyplot.Figure + The figure handle. + axes : instance of matplotlib.axes.PolarAxesSubplot + The subplot handle. + """ + import matplotlib.pyplot as plt + + n_nodes = len(node_names) + + if node_colors is not None: + if len(node_colors) < n_nodes: + node_colors = cycle(node_colors) + else: + # assign colors using colormap + node_colors = [plt.cm.spectral(i / float(n_nodes)) + for i in range(n_nodes)] + + # handle 1D and 2D connectivity information + if con.ndim == 1: + if indices is None: + raise ValueError('indices must be provided if con.ndim == 1') + tmp = np.zeros((n_nodes, n_nodes)) * np.nan + for ci in zip(con, *indices): + tmp[ci[1:]] = ci[0] + con = tmp + elif con.ndim == 2: + if con.shape[0] != n_nodes or con.shape[1] != n_nodes: + raise ValueError('con has to be 1D or a square matrix') + else: + raise ValueError('con has to be 1D or a square matrix') + + # remove diagonal (do not show node's self-connectivity) + con = con.copy() + np.fill_diagonal(con, np.nan) + + # get the colormap + if isinstance(colormap, string_types): + colormap = plt.get_cmap(colormap) + + # Make figure background the same colors as axes + if fig is None: + fig = plt.figure(figsize=(8, 8), facecolor=facecolor) + + if not isinstance(subplot, tuple): + subplot = (subplot,) + axes = plt.subplot(*subplot, axisbg=facecolor) + + axes.spines['bottom'].set_visible(False) + axes.spines['right'].set_visible(False) + axes.spines['left'].set_visible(False) + axes.spines['top'].set_visible(False) + + tmp = np.empty((n_nodes + 4, n_nodes + 4)) * np.nan + tmp[2:-2, 2:-2] = con + con = tmp + + h = axes.imshow(con, cmap=colormap, interpolation='nearest', vmin=vmin, + vmax=vmax) + + nodes = np.empty((n_nodes + 4, n_nodes + 4, 4)) * np.nan + for i in range(n_nodes): + nodes[i + 2, 0, :] = node_colors[i] + nodes[i + 2, -1, :] = node_colors[i] + nodes[0, i + 2, :] = node_colors[i] + nodes[-1, i + 2, :] = node_colors[i] + axes.imshow(nodes, interpolation='nearest') + + if colorbar: + cb = plt.colorbar(h, ax=axes, use_gridspec=False, + shrink=colorbar_size, + anchor=colorbar_pos) + cb_yticks = plt.getp(cb.ax.axes, 'yticklabels') + cb.ax.tick_params(labelsize=fontsize_colorbar) + plt.setp(cb_yticks, color=textcolor) + + if title is not None: + plt.title(title, color=textcolor, fontsize=fontsize_title, + axes=axes) + + # Draw node labels + if show_names: + for i, name in enumerate(node_names): + axes.text(-1, i + 2, name, size=fontsize_names, + rotation=0, rotation_mode='anchor', + horizontalalignment='right', verticalalignment='center', + color=textcolor) + axes.text(i + 2, len(node_names) + 4, name, size=fontsize_names, + rotation=90, rotation_mode='anchor', + horizontalalignment='right', verticalalignment='center', + color=textcolor) + + axes.format_coord = partial(_plot_connectivity_matrix_nodename, con=con, + node_names=node_names) + + return fig, axes + + +def plot_connectivity_inoutcircles(con, seed, node_names, facecolor='black', + textcolor='white', colormap='hot', + title=None, fontsize_suptitle=14, fig=None, + subplot=(121, 122), **kwargs): + """Visualize effective connectivity with two circular graphs, one for + incoming, and one for outgoing connections. + + Note: This code is based on the circle graph example by Nicolas P. Rougier + http://www.loria.fr/~rougier/coding/recipes.html + + Parameters + ---------- + con : array + Connectivity scores. Can be a square matrix, or a 1D array. If a 1D + array is provided, "indices" has to be used to define the connection + indices. + seed : int | str + Index or name of the seed node. Connections towards and from that node + are displayed. The seed can be changed by clicking on a node in + interactive mode. + node_names : list of str + Node names. The order corresponds to the order in con. + facecolor : str + Color to use for background. See matplotlib.colors. + textcolor : str + Color to use for text. See matplotlib.colors. + colormap : str | (str, str) + Colormap to use for coloring the connections. Can be a tuple of two + strings, in which case the first colormap is used for incoming, and the + second colormap for outgoing connections. + title : str + The figure title. + fontsize_suptitle : int + Font size to use for title. + fig : None | instance of matplotlib.pyplot.Figure + The figure to use. If None, a new figure with the specified background + color will be created. + subplot : (int, int) | (3-tuple, 3-tuple) + Location of the two subplots for incoming and outgoing connections. + E.g. 121 or (1, 2, 1) for 1 row, 2 columns, plot 1. See + matplotlib.pyplot.subplot. + **kwargs : + The remaining keyword-arguments will be passed directly to + plot_connectivity_circle. + + Returns + ------- + fig : instance of matplotlib.pyplot.Figure + The figure handle. + axes_in : instance of matplotlib.axes.PolarAxesSubplot + The subplot handle. + axes_out : instance of matplotlib.axes.PolarAxesSubplot + The subplot handle. + """ + import matplotlib.pyplot as plt + + n_nodes = len(node_names) + + if any(isinstance(seed, t) for t in string_types): + try: + seed = node_names.index(seed) + except ValueError: + from difflib import get_close_matches + close = get_close_matches(seed, node_names) + raise ValueError('{} is not in the list of node names. Did you ' + 'mean {}?'.format(seed, close)) + + if seed < 0 or seed >= n_nodes: + raise ValueError('seed={} is not in range [0, {}].' + .format(seed, n_nodes - 1)) + + if type(colormap) not in (tuple, list): + colormap = (colormap, colormap) + + # Default figure size accomodates two horizontally arranged circles + if fig is None: + fig = plt.figure(figsize=(8, 4), facecolor=facecolor) + + index_in = (np.array([seed] * n_nodes), + np.array([i for i in range(n_nodes)])) + index_out = index_in[::-1] + + fig, axes_in = plot_connectivity_circle(con[seed, :].ravel(), node_names, + indices=index_in, + colormap=colormap[0], fig=fig, + subplot=subplot[0], + title='incoming', **kwargs) + + fig, axes_out = plot_connectivity_circle(con[:, seed].ravel(), node_names, + indices=index_out, + colormap=colormap[1], fig=fig, + subplot=subplot[1], + title='outgoing', **kwargs) + + if title is not None: + plt.suptitle(title, color=textcolor, fontsize=fontsize_suptitle) + + return fig, axes_in, axes_out diff --git a/mne_sandbox/viz/tests/test_connecitvity.py b/mne_sandbox/viz/tests/test_connecitvity.py new file mode 100644 index 0000000..6449fc3 --- /dev/null +++ b/mne_sandbox/viz/tests/test_connecitvity.py @@ -0,0 +1,81 @@ +# Authors: Martin Billinger +# +# License: Simplified BSD + +import numpy as np +from numpy.testing import assert_array_equal +from nose.tools import assert_raises, assert_equal + +from mne_sandbox.viz import (plot_connectivity_circle, + plot_connectivity_matrix, + plot_connectivity_inoutcircles) +from mne_sandbox.viz.connectivity import _plot_connectivity_matrix_nodename + +# Set our plotters to test mode +import matplotlib +matplotlib.use('Agg') # for testing don't use X server + + +def test_plot_connectivity_circle(): + """Test plotting connecitvity circle + """ + label_names = ['bankssts-lh', 'bankssts-rh', 'caudalanteriorcingulate-lh', + 'caudalanteriorcingulate-rh', 'caudalmiddlefrontal-lh'] + + con = np.random.RandomState(42).rand(5, 5) + + plot_connectivity_circle(con, label_names, plot_names=True) + plot_connectivity_circle(con, label_names, plot_names=False) + + +def test_plot_connectivity_matrix(): + """Test plotting connecitvity matrix + """ + label_names = ['bankssts-lh', 'bankssts-rh', 'caudalanteriorcingulate-lh', + 'caudalanteriorcingulate-rh', 'caudalmiddlefrontal-lh'] + + con = np.random.RandomState(42).rand(5, 5) + con0 = con.copy() + + assert_raises(ValueError, plot_connectivity_matrix, + con=np.empty((2, 2, 2)), node_names=label_names) + assert_raises(ValueError, plot_connectivity_matrix, + con=np.empty((1, 2)), node_names=label_names) + + plot_connectivity_matrix(con, label_names, colormap='jet', title='Test') + + # check that function does not change arguments + assert_array_equal(con, con0) + + # test status bar text + labels = ['a', 'b', 'c', 'd'] + con = np.empty((8, 8)) + con[:] = np.nan + con[2:-2, 2:-2] = np.arange(16).reshape(4, 4) + + str1 = _plot_connectivity_matrix_nodename(1, 1, con, labels) + str2 = _plot_connectivity_matrix_nodename(2, 3, con, labels) + + assert_equal(str1, '') + assert_equal(str2, 'a --> b: 4.00') + + +def test_plot_connectivity_inoutcircles(): + """Test plotting directional connecitvity circles + """ + label_names = ['bankssts-lh', 'bankssts-rh', 'caudalanteriorcingulate-lh', + 'caudalanteriorcingulate-rh', 'caudalmiddlefrontal-lh'] + + con = np.random.RandomState(42).rand(5, 5) + con0 = con.copy() + + assert_raises(ValueError, plot_connectivity_inoutcircles, con, 'n/a', + label_names) + assert_raises(ValueError, plot_connectivity_inoutcircles, con, 99, + label_names) + + plot_connectivity_inoutcircles(con, 'bankssts-rh', label_names, + title='Test') + + # check that function does not change arguments + assert_array_equal(con, con0) diff --git a/setup.py b/setup.py index 5594c88..3198edb 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,8 @@ platforms='any', packages=[ 'mne_sandbox', + 'mne_sandbox.connectivity', 'mne_sandbox.preprocessing', + 'mne_sandbox.viz' ], - ) + )