Skip to content

Commit 66e0252

Browse files
Use nwe name order in autoregressive component
1 parent 083e786 commit 66e0252

File tree

6 files changed

+37
-37
lines changed

6 files changed

+37
-37
lines changed

pymc_extras/statespace/models/structural/components/autoregressive.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,31 +101,31 @@ def populate_component_properties(self):
101101
]
102102

103103
self.shock_names = [f"{self.name}[{obs_name}]" for obs_name in self.observed_state_names]
104-
self.param_names = [f"{self.name}_params", f"{self.name}_sigma"]
105-
self.param_dims = {f"{self.name}_params": (f"{self.name}_lag",)}
106-
self.coords = {f"{self.name}_lag": self.ar_lags.tolist()}
104+
self.param_names = [f"params_{self.name}", f"sigma_{self.name}"]
105+
self.param_dims = {f"params_{self.name}": (f"lag_{self.name}",)}
106+
self.coords = {f"lag_{self.name}": self.ar_lags.tolist()}
107107

108108
if self.k_endog > 1:
109-
self.param_dims[f"{self.name}_params"] = (
110-
f"{self.name}_endog",
109+
self.param_dims[f"params_{self.name}"] = (
110+
f"endog_{self.name}",
111111
AR_PARAM_DIM,
112112
)
113-
self.param_dims[f"{self.name}_sigma"] = (f"{self.name}_endog",)
113+
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
114114

115-
self.coords[f"{self.name}_endog"] = self.observed_state_names
115+
self.coords[f"endog_{self.name}"] = self.observed_state_names
116116

117117
self.param_info = {
118-
f"{self.name}_params": {
118+
f"params_{self.name}": {
119119
"shape": (self.k_states,) if self.k_endog == 1 else (self.k_endog, self.k_states),
120120
"constraints": None,
121121
"dims": (AR_PARAM_DIM,)
122122
if self.k_endog == 1
123123
else (
124-
f"{self.name}_endog",
124+
f"endog_{self.name}",
125125
AR_PARAM_DIM,
126126
),
127127
},
128-
f"{self.name}_sigma": {
128+
f"sigma_{self.name}": {
129129
"shape": () if self.k_endog == 1 else (self.k_endog,),
130130
"constraints": "Positive",
131131
"dims": None if self.k_endog == 1 else (f"{self.name}_endog",),
@@ -139,10 +139,10 @@ def make_symbolic_graph(self) -> None:
139139

140140
k_nonzero = int(sum(self.order))
141141
ar_params = self.make_and_register_variable(
142-
f"{self.name}_params", shape=(k_nonzero,) if k_endog == 1 else (k_endog, k_nonzero)
142+
f"params_{self.name}", shape=(k_nonzero,) if k_endog == 1 else (k_endog, k_nonzero)
143143
)
144144
sigma_ar = self.make_and_register_variable(
145-
f"{self.name}_sigma", shape=() if k_endog == 1 else (k_endog,)
145+
f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
146146
)
147147

148148
if k_endog == 1:

pymc_extras/statespace/utils/constants.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
SHOCK_DIM = "shock"
88
SHOCK_AUX_DIM = "shock_aux"
99
TIME_DIM = "time"
10-
AR_PARAM_DIM = "ar_lag"
11-
MA_PARAM_DIM = "ma_lag"
12-
SEASONAL_AR_PARAM_DIM = "seasonal_ar_lag"
13-
SEASONAL_MA_PARAM_DIM = "seasonal_ma_lag"
10+
AR_PARAM_DIM = "lag_ar"
11+
MA_PARAM_DIM = "lag_ma"
12+
SEASONAL_AR_PARAM_DIM = "seasonal_lag_ar"
13+
SEASONAL_MA_PARAM_DIM = "seasonal_lag_ma"
1414
ETS_SEASONAL_DIM = "seasonal_lag"
1515

1616
NEVER_TIME_VARYING = ["initial_state", "initial_state_cov", "a0", "P0"]

tests/statespace/models/structural/components/test_autoregressive.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_autoregressive_model(order, rng):
2121
lags = np.arange(len(order) if isinstance(order, list) else order, dtype="int") + 1
2222
if isinstance(order, list):
2323
lags = lags[np.flatnonzero(order)]
24-
assert_allclose(ar.coords["auto_regressive_lag"], lags)
24+
assert_allclose(ar.coords["lag_auto_regressive"], lags)
2525

2626

2727
def test_autoregressive_multiple_observed_build(rng):
@@ -44,15 +44,15 @@ def test_autoregressive_multiple_observed_build(rng):
4444
assert mod.shock_names == ["auto_regressive[data_1]", "auto_regressive[data_2]"]
4545

4646
params = {
47-
"auto_regressive_params": np.full(
47+
"params_auto_regressive": np.full(
4848
(
4949
2,
5050
sum(ar.order),
5151
),
5252
0.5,
5353
dtype=config.floatX,
5454
),
55-
"auto_regressive_sigma": np.array([0.05, 0.12]),
55+
"sigma_auto_regressive": np.array([0.05, 0.12]),
5656
}
5757
_, _, _, _, T, Z, R, _, Q = mod._unpack_statespace_with_placeholders()
5858
input_vars = explicit_graph_inputs([T, Z, R, Q])
@@ -94,16 +94,16 @@ def test_autoregressive_multiple_observed_data(rng):
9494
mod = ar.build(verbose=False)
9595

9696
params = {
97-
"auto_regressive_params": np.array([0.9, 0.8, 0.5]).reshape((3, 1)),
98-
"auto_regressive_sigma": np.array([0.05, 0.12, 0.22]),
97+
"params_auto_regressive": np.array([0.9, 0.8, 0.5]).reshape((3, 1)),
98+
"sigma_auto_regressive": np.array([0.05, 0.12, 0.22]),
9999
"initial_state_cov": np.eye(3),
100100
}
101101

102102
# Recover the AR(1) coefficients from the simulated data via OLS
103103
x, y = simulate_from_numpy_model(mod, rng, params, steps=2000)
104104
for i in range(3):
105105
ols_coefs = np.polyfit(y[:-1, i], y[1:, i], 1)
106-
np.testing.assert_allclose(ols_coefs[0], params["auto_regressive_params"][i, 0], atol=1e-1)
106+
np.testing.assert_allclose(ols_coefs[0], params["params_auto_regressive"][i, 0], atol=1e-1)
107107

108108

109109
def test_add_autoregressive_different_observed():
@@ -128,5 +128,5 @@ def test_add_autoregressive_different_observed():
128128
]
129129

130130
assert mod.shock_names == ["ar1[data_1]", "ar6[data_2]"]
131-
assert mod.coords["ar1_lag"] == [1]
132-
assert mod.coords["ar6_lag"] == [1, 2, 3, 4, 5, 6]
131+
assert mod.coords["lag_ar1"] == [1]
132+
assert mod.coords["lag_ar6"] == [1, 2, 3, 4, 5, 6]

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -405,22 +405,22 @@ def create_structural_model_and_equivalent_statsmodel(
405405

406406
if autoregressive is not None:
407407
ar_names = [f"L{i+1}" for i in range(autoregressive)]
408-
ar_params = rng.normal(size=(autoregressive,)).astype(floatX)
408+
params_ar = rng.normal(size=(autoregressive,)).astype(floatX)
409409
if autoregressive == 1:
410-
ar_params = ar_params.item()
410+
params_ar = params_ar.item()
411411
sigma2 = np.abs(rng.normal()).astype(floatX)
412412

413-
params["ar_params"] = ar_params
414-
params["ar_sigma"] = np.sqrt(sigma2)
415-
expected_param_dims["ar_params"] += (AR_PARAM_DIM,)
413+
params["params_ar"] = params_ar
414+
params["sigma_ar"] = np.sqrt(sigma2)
415+
expected_param_dims["params_ar"] += (AR_PARAM_DIM,)
416416
expected_coords[AR_PARAM_DIM] += tuple(list(range(1, autoregressive + 1)))
417417
expected_coords[ALL_STATE_DIM] += ar_names
418418
expected_coords[ALL_STATE_AUX_DIM] += ar_names
419419
expected_coords[SHOCK_DIM] += ["ar"]
420420
expected_coords[SHOCK_AUX_DIM] += ["ar"]
421421

422422
sm_params["sigma2.ar"] = sigma2
423-
for i, rho in enumerate(ar_params):
423+
for i, rho in enumerate(params_ar):
424424
sm_init[f"ar.L{i+1}"] = 0
425425
sm_params[f"ar.L{i+1}"] = rho
426426

tests/statespace/models/test_SARIMAX.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ def pymc_mod(arima_mod):
178178
# x0 = pm.Normal('x0', dims=['state'])
179179
# P0_diag = pm.Gamma('P0_diag', alpha=2, beta=1, dims=['state'])
180180
# P0 = pm.Deterministic('P0', pt.diag(P0_diag), dims=['state', 'state_aux'])
181-
ar_params = pm.Normal("ar_params", sigma=0.1, dims=["ar_lag"])
182-
ma_params = pm.Normal("ma_params", sigma=1, dims=["ma_lag"])
181+
ar_params = pm.Normal("ar_params", sigma=0.1, dims=["lag_ar"])
182+
ma_params = pm.Normal("ma_params", sigma=1, dims=["lag_ma"])
183183
sigma_state = pm.Exponential("sigma_state", 0.5)
184184
arima_mod.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True)
185185

@@ -207,8 +207,8 @@ def pymc_mod_interp(arima_mod_interp):
207207
P0 = pm.Deterministic(
208208
"P0", pt.eye(arima_mod_interp.k_states) * P0_sigma, dims=["state", "state_aux"]
209209
)
210-
ar_params = pm.Normal("ar_params", sigma=0.1, dims=["ar_lag"])
211-
ma_params = pm.Normal("ma_params", sigma=1, dims=["ma_lag"])
210+
ar_params = pm.Normal("ar_params", sigma=0.1, dims=["lag_ar"])
211+
ma_params = pm.Normal("ma_params", sigma=1, dims=["lag_ma"])
212212
sigma_state = pm.Exponential("sigma_state", 0.5)
213213
sigma_obs = pm.Exponential("sigma_obs", 0.1)
214214

@@ -344,8 +344,8 @@ def test_interpretable_states_are_interpretable(arima_mod_interp, pymc_mod_inter
344344
prior = pm.sample_prior_predictive(draws=10)
345345

346346
prior_outputs = arima_mod_interp.sample_unconditional_prior(prior)
347-
ar_lags = prior.prior.coords["ar_lag"].values - 1
348-
ma_lags = prior.prior.coords["ma_lag"].values - 1
347+
ar_lags = prior.prior.coords["lag_ar"].values - 1
348+
ma_lags = prior.prior.coords["lag_ma"].values - 1
349349

350350
# Check the first p states are lags of the previous state
351351
for t, tm1 in zip(ar_lags[1:], ar_lags[:-1]):

tests/statespace/models/test_VARMAX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def pymc_mod(varma_mod, data):
5757
"state_chol", n=varma_mod.k_posdef, eta=1, sd_dist=pm.Exponential.dist(1)
5858
)
5959
ar_params = pm.Normal(
60-
"ar_params", mu=0, sigma=0.1, dims=["observed_state", "ar_lag", "observed_state_aux"]
60+
"ar_params", mu=0, sigma=0.1, dims=["observed_state", "lag_ar", "observed_state_aux"]
6161
)
6262
state_cov = pm.Deterministic(
6363
"state_cov", state_chol @ state_chol.T, dims=["shock", "shock_aux"]

0 commit comments

Comments
 (0)