Skip to content

Commit 63a8e08

Browse files
committed
Fixed Yule-Walker fitting and added tests
1 parent 2b426dd commit 63a8e08

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

mne_sandbox/connectivity/mvar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _acm(x, l):
2525
a = x[:, l:]
2626
b = x[:, 0:-l]
2727

28-
return np.dot(a[:, :], b[:, :].T) / a.shape[1]
28+
return np.dot(a[:, :], b[:, :].T).T / a.shape[1]
2929

3030

3131
def _epoch_autocorrelations(epoch, max_lag):

mne_sandbox/connectivity/tests/test_mvar.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
# License: BSD (3-clause)
44

55
import numpy as np
6-
from numpy.testing import assert_array_almost_equal, assert_array_less
6+
from numpy.testing import (assert_array_equal, assert_array_almost_equal,
7+
assert_array_less)
78
from nose.tools import assert_raises, assert_equal
89

910
from mne_sandbox.connectivity import mvar_connectivity
11+
from mne_sandbox.connectivity.mvar import _fit_mvar_lsq, _fit_mvar_yw
1012

1113

1214
def _make_data(var_coef, n_samples, n_epochs):
@@ -136,3 +138,27 @@ def test_mvar_connectivity():
136138
assert_array_less(p_vals[0][i, j, 0], 0.05)
137139
else:
138140
assert_array_less(0.05, p_vals[0][i, j, 0])
141+
142+
143+
def test_fit_mvar():
144+
"""Test MVAR model fitting"""
145+
np.random.seed(0)
146+
147+
n_sigs = 3
148+
n_epochs = 50
149+
n_samples = 200
150+
151+
var_coef = np.zeros((1, n_sigs, n_sigs))
152+
var_coef[0, :, :] = [[0.9, 0, 0],
153+
[1, 0.5, 0],
154+
[2, 0, -0.5]]
155+
data = _make_data(var_coef, n_samples, n_epochs)
156+
data0 = data.copy()
157+
158+
var = _fit_mvar_lsq(data, pmin=1, pmax=1, delta=0, n_jobs=1, verbose=0)
159+
assert_array_equal(data, data0)
160+
assert_array_almost_equal(var_coef[0], var.coef, decimal=2)
161+
162+
var = _fit_mvar_yw(data, pmin=1, pmax=1, n_jobs=1, verbose=0)
163+
assert_array_equal(data, data0)
164+
assert_array_almost_equal(var_coef[0], var.coef, decimal=2)

0 commit comments

Comments
 (0)