|
3 | 3 | # License: BSD (3-clause) |
4 | 4 |
|
5 | 5 | import numpy as np |
6 | | -from numpy.testing import assert_array_almost_equal, assert_array_less |
| 6 | +from numpy.testing import (assert_array_equal, assert_array_almost_equal, |
| 7 | + assert_array_less) |
7 | 8 | from nose.tools import assert_raises, assert_equal |
8 | 9 |
|
9 | 10 | from mne_sandbox.connectivity import mvar_connectivity |
| 11 | +from mne_sandbox.connectivity.mvar import _fit_mvar_lsq, _fit_mvar_yw |
10 | 12 |
|
11 | 13 |
|
12 | 14 | def _make_data(var_coef, n_samples, n_epochs): |
@@ -136,3 +138,27 @@ def test_mvar_connectivity(): |
136 | 138 | assert_array_less(p_vals[0][i, j, 0], 0.05) |
137 | 139 | else: |
138 | 140 | assert_array_less(0.05, p_vals[0][i, j, 0]) |
| 141 | + |
| 142 | + |
| 143 | +def test_fit_mvar(): |
| 144 | + """Test MVAR model fitting""" |
| 145 | + np.random.seed(0) |
| 146 | + |
| 147 | + n_sigs = 3 |
| 148 | + n_epochs = 50 |
| 149 | + n_samples = 200 |
| 150 | + |
| 151 | + var_coef = np.zeros((1, n_sigs, n_sigs)) |
| 152 | + var_coef[0, :, :] = [[0.9, 0, 0], |
| 153 | + [1, 0.5, 0], |
| 154 | + [2, 0, -0.5]] |
| 155 | + data = _make_data(var_coef, n_samples, n_epochs) |
| 156 | + data0 = data.copy() |
| 157 | + |
| 158 | + var = _fit_mvar_lsq(data, pmin=1, pmax=1, delta=0, n_jobs=1, verbose=0) |
| 159 | + assert_array_equal(data, data0) |
| 160 | + assert_array_almost_equal(var_coef[0], var.coef, decimal=2) |
| 161 | + |
| 162 | + var = _fit_mvar_yw(data, pmin=1, pmax=1, n_jobs=1, verbose=0) |
| 163 | + assert_array_equal(data, data0) |
| 164 | + assert_array_almost_equal(var_coef[0], var.coef, decimal=2) |
0 commit comments