Skip to content

Commit dcd517d

Browse files
authored
MAINT: Use multivariate_normal via random_state (#581)
* MAINT: Use multivariate_normal via random_state * MAINT: PEP8 compliance
1 parent 581631c commit dcd517d

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

quantecon/lss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from textwrap import dedent
1212
import numpy as np
13-
from numpy.random import multivariate_normal
1413
from scipy.linalg import solve
1514
from .matrix_eqn import solve_discrete_lyapunov
1615
from numba import jit
@@ -186,7 +185,8 @@ def simulate(self, ts_length=100, random_state=None):
186185
"""
187186
random_state = check_random_state(random_state)
188187

189-
x0 = multivariate_normal(self.mu_0.flatten(), self.Sigma_0)
188+
x0 = random_state.multivariate_normal(self.mu_0.flatten(),
189+
self.Sigma_0)
190190
w = random_state.randn(self.m, ts_length-1)
191191
v = self.C.dot(w) # Multiply each w_t by C to get v_t = C w_t
192192
# == simulate time series == #
@@ -447,8 +447,8 @@ def __partition(self):
447447
A_diag = np.diag(A)
448448
num_const = 0
449449
for idx in range(n):
450-
if (A_diag[idx] == 1) and (C[idx, :] == 0).all() \
451-
and np.linalg.norm(A[idx, :]) == 1:
450+
if (A_diag[idx] == 1) and (C[idx, :] == 0).all() and \
451+
np.linalg.norm(A[idx, :]) == 1:
452452
sorted_idx.insert(0, idx)
453453
num_const += 1
454454
else:

quantecon/tests/test_lss.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def test_stationarity(self):
5252

5353
assert_allclose(ssmux.flatten(), np.array([2.5, 2.5, 1]))
5454
assert_allclose(ssmuy.flatten(), np.array([2.5]))
55-
assert_allclose(sssigx, self.ss2.A @ sssigx @ self.ss2.A.T + self.ss2.C @ self.ss2.C.T)
55+
assert_allclose(
56+
sssigx,
57+
self.ss2.A @ sssigx @ self.ss2.A.T + self.ss2.C @ self.ss2.C.T
58+
)
5659
assert_allclose(sssigy, self.ss2.G @ sssigx @ self.ss2.G.T)
5760
assert_allclose(sssigyx, self.ss2.G @ sssigx)
5861

@@ -61,14 +64,14 @@ def test_simulate(self):
6164

6265
sim = ss.simulate(ts_length=250)
6366
for arr in sim:
64-
self.assertTrue(len(arr[0])==250)
67+
self.assertTrue(len(arr[0]) == 250)
6568

6669
def test_simulate_with_seed(self):
6770
ss = self.ss1
6871

6972
xval, yval = ss.simulate(ts_length=5, random_state=5)
70-
expected_output = np.array([0.75 , 0.73456137, 0.6812898, 0.76876387,
71-
.71772107])
73+
expected_output = np.array([0.75, 0.69595649, 0.78269723, 0.73095776,
74+
0.69989036])
7275

7376
assert_allclose(xval[0], expected_output)
7477
assert_allclose(yval[0], expected_output)
@@ -82,8 +85,8 @@ def test_replicate(self):
8285

8386
def test_replicate_with_seed(self):
8487
xval, yval = self.ss1.replicate(T=100, num_reps=5, random_state=5)
85-
expected_output = np.array([0.06871204, 0.06937119, -0.1478022,
86-
0.23841252, -0.06823762])
88+
expected_output = np.array([0.10498898, 0.02892168, 0.04915998,
89+
0.18568489, 0.04541764])
8790

8891
assert_allclose(xval[0], expected_output)
8992
assert_allclose(yval[0], expected_output)
@@ -101,4 +104,3 @@ def test_non_square_A():
101104
if __name__ == '__main__':
102105
suite = unittest.TestLoader().loadTestsFromTestCase(TestLinearStateSpace)
103106
unittest.TextTestRunner(verbosity=2, stream=sys.stderr).run(suite)
104-

0 commit comments

Comments
 (0)