Skip to content

Commit 12039e0

Browse files
ricardoV94twiecki
authored andcommitted
Rename GaussianRandomWalk init argument to init_dist and fix some type hints
1 parent 42e3745 commit 12039e0

File tree

2 files changed

+62
-43
lines changed

2 files changed

+62
-43
lines changed

pymc/distributions/timeseries.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,24 @@ class GaussianRandomWalkRV(RandomVariable):
125125
dtype = "floatX"
126126
_print_name = ("GaussianRandomWalk", "\\operatorname{GaussianRandomWalk}")
127127

128-
def make_node(self, rng, size, dtype, mu, sigma, init, steps):
128+
def make_node(self, rng, size, dtype, mu, sigma, init_dist, steps):
129129
steps = at.as_tensor_variable(steps)
130130
if not steps.ndim == 0 or not steps.dtype.startswith("int"):
131131
raise ValueError("steps must be an integer scalar (ndim=0).")
132132

133133
mu = at.as_tensor_variable(mu)
134134
sigma = at.as_tensor_variable(sigma)
135-
init = at.as_tensor_variable(init)
135+
init_dist = at.as_tensor_variable(init_dist)
136136

137137
# Resize init distribution
138138
size = normalize_size_param(size)
139139
# If not explicit, size is determined by the shapes of mu, sigma, and init
140-
init_size = size if not rv_size_is_none(size) else at.broadcast_shape(mu, sigma, init)
141-
init = change_rv_size(init, init_size)
140+
init_dist_size = (
141+
size if not rv_size_is_none(size) else at.broadcast_shape(mu, sigma, init_dist)
142+
)
143+
init_dist = change_rv_size(init_dist, init_dist_size)
142144

143-
return super().make_node(rng, size, dtype, mu, sigma, init, steps)
145+
return super().make_node(rng, size, dtype, mu, sigma, init_dist, steps)
144146

145147
def _supp_shape_from_params(self, dist_params, reop_param_idx=0, param_shapes=None):
146148
steps = dist_params[3]
@@ -153,7 +155,7 @@ def rng_fn(
153155
rng: np.random.RandomState,
154156
mu: Union[np.ndarray, float],
155157
sigma: Union[np.ndarray, float],
156-
init: float,
158+
init_dist: Union[np.ndarray, float],
157159
steps: int,
158160
size: Tuple[int],
159161
) -> np.ndarray:
@@ -170,16 +172,16 @@ def rng_fn(
170172
----------
171173
rng: np.random.RandomState
172174
Numpy random number generator
173-
mu: array_like
175+
mu: array_like of float
174176
Random walk mean
175-
sigma: np.ndarray
177+
sigma: array_like of float
176178
Standard deviation of innovation (sigma > 0)
177-
init: float
179+
init_dist: array_like of float
178180
Initialization value for GaussianRandomWalk
179181
steps: int
180182
Length of random walk, must be greater than 1. Returned array will be of size+1 to
181183
account as first value is initial value
182-
size: int
184+
size: tuple of int
183185
The number of Random Walk time series generated
184186
185187
Returns
@@ -196,7 +198,7 @@ def rng_fn(
196198
bcast_shape = np.broadcast_shapes(
197199
np.asarray(mu).shape,
198200
np.asarray(sigma).shape,
199-
np.asarray(init).shape,
201+
np.asarray(init_dist).shape,
200202
)
201203
dist_shape = (*bcast_shape, int(steps))
202204

@@ -207,7 +209,7 @@ def rng_fn(
207209
# Add one dimension to the right, so that mu and sigma broadcast safely along
208210
# the steps dimension
209211
innovations = rng.normal(loc=mu[..., None], scale=sigma[..., None], size=dist_shape)
210-
grw = np.concatenate([init[..., None], innovations], axis=-1)
212+
grw = np.concatenate([init_dist[..., None], innovations], axis=-1)
211213
return np.cumsum(grw, axis=-1)
212214

213215

@@ -223,7 +225,7 @@ class GaussianRandomWalk(distribution.Continuous):
223225
innovation drift, defaults to 0.0
224226
sigma : tensor_like of float, optional
225227
sigma > 0, innovation standard deviation, defaults to 1.0
226-
init : unnamed distribution
228+
init_dist : unnamed distribution
227229
Univariate distribution of the initial value, created with the `.dist()` API.
228230
Defaults to a unit Normal.
229231
@@ -248,7 +250,7 @@ def __new__(cls, *args, steps=None, **kwargs):
248250

249251
@classmethod
250252
def dist(
251-
cls, mu=0.0, sigma=1.0, *, init=None, steps=None, size=None, **kwargs
253+
cls, mu=0.0, sigma=1.0, *, init_dist=None, steps=None, size=None, **kwargs
252254
) -> at.TensorVariable:
253255

254256
mu = at.as_tensor_variable(floatX(mu))
@@ -263,27 +265,34 @@ def dist(
263265
raise ValueError("Must specify steps or shape parameter")
264266
steps = at.as_tensor_variable(intX(steps))
265267

268+
if "init" in kwargs:
269+
warnings.warn(
270+
"init parameter is now called init_dist. Using init will raise an error in a future release.",
271+
FutureWarning,
272+
)
273+
init_dist = kwargs.pop("init")
274+
266275
# If no scalar distribution is passed then initialize with a Normal of same mu and sigma
267-
if init is None:
268-
init = Normal.dist(0, 1)
276+
if init_dist is None:
277+
init_dist = Normal.dist(0, 1)
269278
else:
270279
if not (
271-
isinstance(init, at.TensorVariable)
272-
and init.owner is not None
273-
and isinstance(init.owner.op, RandomVariable)
274-
and init.owner.op.ndim_supp == 0
280+
isinstance(init_dist, at.TensorVariable)
281+
and init_dist.owner is not None
282+
and isinstance(init_dist.owner.op, RandomVariable)
283+
and init_dist.owner.op.ndim_supp == 0
275284
):
276285
raise TypeError("init must be a univariate distribution variable")
277-
check_dist_not_registered(init)
286+
check_dist_not_registered(init_dist)
278287

279288
# Ignores logprob of init var because that's accounted for in the logp method
280-
init = ignore_logprob(init)
289+
init_dist = ignore_logprob(init_dist)
281290

282-
return super().dist([mu, sigma, init, steps], size=size, **kwargs)
291+
return super().dist([mu, sigma, init_dist, steps], size=size, **kwargs)
283292

284-
def moment(rv, size, mu, sigma, init, steps):
293+
def moment(rv, size, mu, sigma, init_dist, steps):
285294
grw_moment = at.zeros_like(rv)
286-
grw_moment = at.set_subtensor(grw_moment[..., 0], moment(init))
295+
grw_moment = at.set_subtensor(grw_moment[..., 0], moment(init_dist))
287296
# Add one dimension to the right, so that mu broadcasts safely along the steps
288297
# dimension
289298
grw_moment = at.set_subtensor(grw_moment[..., 1:], mu[..., None])
@@ -293,13 +302,13 @@ def logp(
293302
value: at.Variable,
294303
mu: at.Variable,
295304
sigma: at.Variable,
296-
init: at.Variable,
305+
init_dist: at.Variable,
297306
steps: at.Variable,
298307
) -> at.TensorVariable:
299308
"""Calculate log-probability of Gaussian Random Walk distribution at specified value."""
300309

301310
# Calculate initialization logp
302-
init_logp = logp(init, value[..., 0])
311+
init_logp = logp(init_dist, value[..., 0])
303312

304313
# Make time series stationary around the mean value
305314
stationary_series = value[..., 1:] - value[..., :-1]
@@ -429,7 +438,7 @@ def dist(
429438
"init parameter is now called init_dist. Using init will raise an error in a future release.",
430439
FutureWarning,
431440
)
432-
init_dist = kwargs["init"]
441+
init_dist = kwargs.pop("init")
433442

434443
ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order)
435444
steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=ar_order)

pymc/tests/test_distributions_timeseries.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ class TestGaussianRandomWalkRandom(BaseTestDistributionRandom):
100100
size = None
101101

102102
pymc_dist = pm.GaussianRandomWalk
103-
pymc_dist_params = {"mu": 1.0, "sigma": 2, "init": pm.Constant.dist(0), "steps": 4}
104-
expected_rv_op_params = {"mu": 1.0, "sigma": 2, "init": pm.Constant.dist(0), "steps": 4}
103+
pymc_dist_params = {"mu": 1.0, "sigma": 2, "init_dist": pm.Constant.dist(0), "steps": 4}
104+
expected_rv_op_params = {"mu": 1.0, "sigma": 2, "init_dist": pm.Constant.dist(0), "steps": 4}
105105

106106
checks_to_run = [
107107
"check_pymc_params_match_rv_op",
@@ -142,36 +142,38 @@ def test_gaussianrandomwalk_inference(self):
142142
@pytest.mark.parametrize("init", [None, pm.Normal.dist()])
143143
def test_gaussian_random_walk_init_dist_shape(self, init):
144144
"""Test that init_dist is properly resized"""
145-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init)
145+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init)
146146
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
147147

148-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, size=(5,))
148+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init, size=(5,))
149149
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
150150

151-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=2)
151+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init, shape=2)
152152
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
153153

154-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=(5, 2))
154+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init, shape=(5, 2))
155155
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
156156

157-
grw = pm.GaussianRandomWalk.dist(mu=[0, 0], sigma=1, steps=1, init=init)
157+
grw = pm.GaussianRandomWalk.dist(mu=[0, 0], sigma=1, steps=1, init_dist=init)
158158
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
159159

160-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=[1, 1], steps=1, init=init)
160+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=[1, 1], steps=1, init_dist=init)
161161
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
162162

163-
grw = pm.GaussianRandomWalk.dist(mu=np.zeros((3, 1)), sigma=[1, 1], steps=1, init=init)
163+
grw = pm.GaussianRandomWalk.dist(mu=np.zeros((3, 1)), sigma=[1, 1], steps=1, init_dist=init)
164164
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3, 2)
165165

166166
def test_shape_ellipsis(self):
167167
grw = pm.GaussianRandomWalk.dist(
168-
mu=0, sigma=1, steps=5, init=pm.Normal.dist(), shape=(3, ...)
168+
mu=0, sigma=1, steps=5, init_dist=pm.Normal.dist(), shape=(3, ...)
169169
)
170170
assert tuple(grw.shape.eval()) == (3, 6)
171171
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3,)
172172

173173
def test_gaussianrandomwalk_broadcasted_by_init_dist(self):
174-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=4, init=pm.Normal.dist(size=(2, 3)))
174+
grw = pm.GaussianRandomWalk.dist(
175+
mu=0, sigma=1, steps=4, init_dist=pm.Normal.dist(size=(2, 3))
176+
)
175177
assert tuple(grw.shape.eval()) == (2, 3, 5)
176178
assert grw.eval().shape == (2, 3, 5)
177179

@@ -210,14 +212,14 @@ def test_inferred_steps_from_observed(self):
210212
],
211213
)
212214
def test_gaussian_random_walk_init_dist_logp(self, init):
213-
grw = pm.GaussianRandomWalk.dist(init=init, steps=1)
215+
grw = pm.GaussianRandomWalk.dist(init_dist=init, steps=1)
214216
assert np.isclose(
215217
pm.logp(grw, [0, 0]).eval(),
216218
pm.logp(init, 0).eval() + scipy.stats.norm.logpdf(0),
217219
)
218220

219221
@pytest.mark.parametrize(
220-
"mu, sigma, init, steps, size, expected",
222+
"mu, sigma, init_dist, steps, size, expected",
221223
[
222224
(0, 1, Normal.dist(1), 10, None, np.ones((11,))),
223225
(1, 1, Normal.dist(0), 10, (2,), np.full((2, 11), np.arange(11))),
@@ -233,11 +235,15 @@ def test_gaussian_random_walk_init_dist_logp(self, init):
233235
),
234236
],
235237
)
236-
def test_moment(self, mu, sigma, init, steps, size, expected):
238+
def test_moment(self, mu, sigma, init_dist, steps, size, expected):
237239
with Model() as model:
238-
GaussianRandomWalk("x", mu=mu, sigma=sigma, init=init, steps=steps, size=size)
240+
GaussianRandomWalk("x", mu=mu, sigma=sigma, init_dist=init_dist, steps=steps, size=size)
239241
assert_moment_is_expected(model, expected)
240242

243+
def test_init_deprecated_arg(self):
244+
with pytest.warns(FutureWarning, match="init parameter is now called init_dist"):
245+
pm.GaussianRandomWalk.dist(init=Normal.dist(), shape=(10,))
246+
241247

242248
class TestAR:
243249
def test_order1_logp(self):
@@ -434,6 +440,10 @@ def test_moment(self, size, expected):
434440
AR("x", rho=[0, 0], init_dist=init_dist, steps=5, size=size)
435441
assert_moment_is_expected(model, expected, check_finite_logp=False)
436442

443+
def test_init_deprecated_arg(self):
444+
with pytest.warns(FutureWarning, match="init parameter is now called init_dist"):
445+
pm.AR.dist(rho=[1, 2, 3], init=Normal.dist(), shape=(10,))
446+
437447

438448
@pytest.mark.xfail(reason="Timeseries not refactored", raises=NotImplementedError)
439449
def test_GARCH11():

0 commit comments

Comments
 (0)