Skip to content

Commit 3518ba9

Browse files
michaelosthegeColCarroll
authored andcommitted
ODE code style and details (#3687)
* don't use built-in as variable name * add test for DtypeError * move make_sens_ic to utils + use ones instead of zeros for test values + use make_sens_ic for setting the dydp test_value + n_states and n_theta args for augment_system function * also pass n_theta instead of n_p * enable all test on float32 * state the correct return shape in the docstring
1 parent a886abc commit 3518ba9

File tree

4 files changed

+73
-61
lines changed

4 files changed

+73
-61
lines changed

pymc3/ode/ode.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import scipy
44
import theano
55
import theano.tensor as tt
6-
from ..ode.utils import augment_system
6+
from ..ode import utils
77
from ..exceptions import ShapeError, DtypeError
88

99
_log = logging.getLogger('pymc3')
@@ -71,38 +71,13 @@ def __init__(self, func, times, *, n_states, n_theta, t0=0):
7171

7272
# Private
7373
self._augmented_times = np.insert(times, 0, t0).astype(floatX)
74-
self._augmented_func = augment_system(func, self.n_states, self.n_p)
75-
self._sens_ic = self._make_sens_ic()
74+
self._augmented_func = utils.augment_system(func, self.n_states, self.n_theta)
75+
self._sens_ic = utils.make_sens_ic(self.n_states, self.n_theta, floatX)
7676

7777
# Cache symbolic sensitivities by the hash of inputs
7878
self._apply_nodes = {}
7979
self._output_sensitivities = {}
8080

81-
def _make_sens_ic(self):
82-
"""
83-
The sensitivity matrix will always have consistent form. (n_states, n_states + n_theta)
84-
85-
If the first n_states entries of the parameters vector in the simulate call
86-
correspond to initial conditions of the system,
87-
then the first n_states columns of the sensitivity matrix should form
88-
an identity matrix.
89-
90-
If the last n_theta entries of the parameters vector in the simulate call
91-
correspond to ode paramaters, then the last n_theta columns in
92-
the sensitivity matrix will be 0.
93-
"""
94-
95-
# Initialize the sensitivity matrix to be 0 everywhere
96-
sens_matrix = np.zeros((self.n_states, self.n_states + self.n_theta), dtype=floatX)
97-
98-
# Slip in the identity matrix in the appropirate place
99-
sens_matrix[:,:self.n_states] = np.eye(self.n_states, dtype=floatX)
100-
101-
# We need the sensitivity matrix to be a vector (see augmented_function)
102-
# Ravel and return
103-
dydp = sens_matrix.ravel()
104-
return dydp
105-
10681
def _system(self, Y, t, p):
10782
"""This is the function that will be passed to odeint. Solves both ODE and sensitivities.
10883
@@ -151,10 +126,10 @@ def __call__(self, y0, theta, return_sens=False, **kwargs):
151126
y0 = tt.cast(tt.unbroadcast(tt.as_tensor_variable(y0), 0), floatX)
152127
theta = tt.cast(tt.unbroadcast(tt.as_tensor_variable(theta), 0), floatX)
153128
inputs = [y0, theta]
154-
for i, (input, itype) in enumerate(zip(inputs, self._itypes)):
155-
if not input.type == itype:
156-
raise ValueError('Input {} of type {} does not have the expected type of {}'.format(i, input.type, itype))
157-
129+
for i, (input_val, itype) in enumerate(zip(inputs, self._itypes)):
130+
if not input_val.type == itype:
131+
raise ValueError('Input {} of type {} does not have the expected type of {}'.format(i, input_val.type, itype))
132+
158133
# use default implementation to prepare symbolic outputs (via make_node)
159134
states, sens = super(theano.Op, self).__call__(y0, theta, **kwargs)
160135

pymc3/ode/utils.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,47 @@
33
import theano.tensor as tt
44

55

6-
def augment_system(ode_func, n, m):
6+
def make_sens_ic(n_states, n_theta, floatX):
7+
"""
8+
The sensitivity matrix will always have consistent form. (n_states, n_states + n_theta)
9+
10+
If the first n_states entries of the parameters vector in the simulate call
11+
correspond to initial conditions of the system,
12+
then the first n_states columns of the sensitivity matrix should form
13+
an identity matrix.
14+
15+
If the last n_theta entries of the parameters vector in the simulate call
16+
correspond to ode paramaters, then the last n_theta columns in
17+
the sensitivity matrix will be 0.
18+
19+
Parameters
20+
----------
21+
n_states : int
22+
Number of state variables in the ODE
23+
n_theta : int
24+
Number of ODE parameters
25+
floatX : str
26+
dtype to be used for the array
27+
28+
Returns
29+
-------
30+
dydp : array
31+
1D-array of shape (n_states * (n_states + n_theta),), representing the initial condition of the sensitivities
32+
"""
33+
34+
# Initialize the sensitivity matrix to be 0 everywhere
35+
sens_matrix = np.zeros((n_states, n_states + n_theta), dtype=floatX)
36+
37+
# Slip in the identity matrix in the appropirate place
38+
sens_matrix[:,:n_states] = np.eye(n_states, dtype=floatX)
39+
40+
# We need the sensitivity matrix to be a vector (see augmented_function)
41+
# Ravel and return
42+
dydp = sens_matrix.ravel()
43+
return dydp
44+
45+
46+
def augment_system(ode_func, n_states, n_theta):
747
"""
848
Function to create augmented system.
949
@@ -17,10 +57,10 @@ def augment_system(ode_func, n, m):
1757
----------
1858
ode_func : function
1959
Differential equation. Returns array-like.
20-
n : int
60+
n_states : int
2161
Number of rows of the sensitivity matrix. (n_states)
22-
m : int
23-
Number of columns of the sensitivity matrix. (n_states + n_theta)
62+
n_theta : int
63+
Number of ODE parameters
2464
2565
Returns
2666
-------
@@ -30,11 +70,11 @@ def augment_system(ode_func, n, m):
3070

3171
# Present state of the system
3272
t_y = tt.vector("y", dtype='float64')
33-
t_y.tag.test_value = np.zeros((n,), dtype='float64')
73+
t_y.tag.test_value = np.ones((n_states,), dtype='float64')
3474
# Parameter(s). Should be vector to allow for generaliztion to multiparameter
3575
# systems of ODEs. Is m dimensional because it includes all initial conditions as well as ode parameters
3676
t_p = tt.vector("p", dtype='float64')
37-
t_p.tag.test_value = np.zeros((m,), dtype='float64')
77+
t_p.tag.test_value = np.ones((n_states + n_theta,), dtype='float64')
3878
# Time. Allow for non-automonous systems of ODEs to be analyzed
3979
t_t = tt.scalar("t", dtype='float64')
4080
t_t.tag.test_value = 2.459
@@ -43,12 +83,12 @@ def augment_system(ode_func, n, m):
4383
# Will always be 0 unless the parameter is the inital condition
4484
# Entry i,j is partial of y[i] wrt to p[j]
4585
dydp_vec = tt.vector("dydp", dtype='float64')
46-
dydp_vec.tag.test_value = np.zeros(n * m, dtype='float64')
86+
dydp_vec.tag.test_value = make_sens_ic(n_states, n_theta, 'float64')
4787

48-
dydp = dydp_vec.reshape((n, m))
88+
dydp = dydp_vec.reshape((n_states, n_states + n_theta))
4989

5090
# Get symbolic representation of the ODEs by passing tensors for y, t and theta
51-
yhat = ode_func(t_y, t_t, t_p[n:])
91+
yhat = ode_func(t_y, t_t, t_p[n_states:])
5292
# Stack the results of the ode_func into a single tensor variable
5393
if not isinstance(yhat, (list, tuple)):
5494
yhat = (yhat,)

pymc3/tests/test_ode.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
import pytest
99

1010

11-
@pytest.mark.xfail(
12-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
13-
)
1411
def test_gradients():
1512
"""Tests the computation of the sensitivities from the theano computation graph"""
1613

@@ -53,9 +50,6 @@ def augmented_system(Y, t, p):
5350
np.testing.assert_allclose(sensitivity, simulated_sensitivity, rtol=1e-5)
5451

5552

56-
@pytest.mark.xfail(
57-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
58-
)
5953
def test_simulate():
6054
"""Tests the integration in DifferentialEquation"""
6155

@@ -81,9 +75,6 @@ def ode_func(y, t, p):
8175
np.testing.assert_allclose(y, simulated_y, rtol=1e-5)
8276

8377

84-
@pytest.mark.xfail(
85-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
86-
)
8778
class TestSensitivityInitialCondition(object):
8879

8980
t = np.arange(0, 12, 0.25).reshape(-1, 1)
@@ -104,7 +95,7 @@ def ode_func_1(y, t, p):
10495
# Sensitivity initial condition for this model should be 1 by 2
10596
model1_sens_ic = np.array([1, 0])
10697

107-
np.testing.assert_array_equal(model1_sens_ic, model1._make_sens_ic())
98+
np.testing.assert_array_equal(model1_sens_ic, model1._sens_ic)
10899

109100
def test_sens_ic_scalar_2_param(self):
110101
# Scalar ODE 2 Param
@@ -118,7 +109,7 @@ def ode_func_2(y, t, p):
118109

119110
model2_sens_ic = np.array([1, 0, 0])
120111

121-
np.testing.assert_array_equal(model2_sens_ic, model2._make_sens_ic())
112+
np.testing.assert_array_equal(model2_sens_ic, model2._sens_ic)
122113

123114
def test_sens_ic_vector_1_param(self):
124115
# Vector ODE 1 Param
@@ -138,7 +129,7 @@ def ode_func_3(y, t, p):
138129
0, 1, 0
139130
])
140131

141-
np.testing.assert_array_equal(model3_sens_ic, model3._make_sens_ic())
132+
np.testing.assert_array_equal(model3_sens_ic, model3._sens_ic)
142133

143134
def test_sens_ic_vector_2_param(self):
144135
# Vector ODE 2 Param
@@ -158,7 +149,7 @@ def ode_func_4(y, t, p):
158149
0, 1, 0, 0
159150
])
160151

161-
np.testing.assert_array_equal(model4_sens_ic, model4._make_sens_ic())
152+
np.testing.assert_array_equal(model4_sens_ic, model4._sens_ic)
162153

163154
def test_sens_ic_vector_3_params(self):
164155
# Big System with Many Parameters
@@ -183,12 +174,9 @@ def ode_func_5(y, t, p):
183174
[0, 0, 1, 0, 0, 0]
184175
])
185176

186-
np.testing.assert_array_equal(np.ravel(model5_sens_ic), model5._make_sens_ic())
177+
np.testing.assert_array_equal(np.ravel(model5_sens_ic), model5._sens_ic)
187178

188179

189-
@pytest.mark.xfail(
190-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
191-
)
192180
def test_logp_scalar_ode():
193181
"""Test the computation of the log probability for these models"""
194182

@@ -279,9 +267,6 @@ def test_number_of_params(self):
279267
)
280268

281269

282-
@pytest.mark.xfail(
283-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
284-
)
285270
class TestDiffEqModel(object):
286271
def test_op_equality(self):
287272
"""Tests that the equality of mathematically identical Ops evaluates True"""

pymc3/tests/test_util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,15 @@ def test_shape_error(self):
9696
with pytest.raises(pm.exceptions.ShapeError):
9797
raise err
9898
pass
99+
100+
def test_dtype_error(self):
101+
err = pm.exceptions.DtypeError('Without dtypes.')
102+
with pytest.raises(pm.exceptions.DtypeError):
103+
raise err
104+
105+
err = pm.exceptions.DtypeError('With shapes.', np.float64, np.float32)
106+
assert 'float64' in err.args[0]
107+
assert 'float32' in err.args[0]
108+
with pytest.raises(pm.exceptions.DtypeError):
109+
raise err
110+
pass

0 commit comments

Comments
 (0)