Skip to content

Commit 92eae7a

Browse files
committed
merged resolved conflicts
2 parents 46e4a39 + f584e79 commit 92eae7a

File tree

17 files changed

+896
-214
lines changed

17 files changed

+896
-214
lines changed

pymc_extras/statespace/models/structural/components/autoregressive.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class AutoregressiveComponent(Component):
6565
def __init__(
6666
self,
6767
order: int = 1,
68-
name: str = "AutoRegressive",
68+
name: str = "auto_regressive",
6969
observed_state_names: list[str] | None = None,
7070
):
7171
if observed_state_names is None:
@@ -92,27 +92,30 @@ def __init__(
9292
)
9393

9494
def populate_component_properties(self):
95+
k_states = self.k_states // self.k_endog
96+
9597
self.state_names = [
96-
f"L{i + 1}.{state_name}"
97-
for i in range(self.k_states)
98+
f"L{i + 1}[{state_name}]"
9899
for state_name in self.observed_state_names
100+
for i in range(k_states)
99101
]
100-
self.shock_names = [f"{name}_{self.name}_innovation" for name in self.observed_state_names]
101-
self.param_names = ["ar_params", "sigma_ar"]
102-
self.param_dims = {"ar_params": (AR_PARAM_DIM,)}
103-
self.coords = {AR_PARAM_DIM: self.ar_lags.tolist()}
102+
103+
self.shock_names = [f"{self.name}[{obs_name}]" for obs_name in self.observed_state_names]
104+
self.param_names = [f"{self.name}_params", f"{self.name}_sigma"]
105+
self.param_dims = {f"{self.name}_params": (f"{self.name}_lag",)}
106+
self.coords = {f"{self.name}_lag": self.ar_lags.tolist()}
104107

105108
if self.k_endog > 1:
106-
self.param_dims["ar_params"] = (
109+
self.param_dims[f"{self.name}_params"] = (
107110
f"{self.name}_endog",
108111
AR_PARAM_DIM,
109112
)
110-
self.param_dims["sigma_ar"] = (f"{self.name}_endog",)
113+
self.param_dims[f"{self.name}_sigma"] = (f"{self.name}_endog",)
111114

112115
self.coords[f"{self.name}_endog"] = self.observed_state_names
113116

114117
self.param_info = {
115-
"ar_params": {
118+
f"{self.name}_params": {
116119
"shape": (self.k_states,) if self.k_endog == 1 else (self.k_endog, self.k_states),
117120
"constraints": None,
118121
"dims": (AR_PARAM_DIM,)
@@ -122,7 +125,7 @@ def populate_component_properties(self):
122125
AR_PARAM_DIM,
123126
),
124127
},
125-
"sigma_ar": {
128+
f"{self.name}_sigma": {
126129
"shape": () if self.k_endog == 1 else (self.k_endog,),
127130
"constraints": "Positive",
128131
"dims": None if self.k_endog == 1 else (f"{self.name}_endog",),
@@ -136,10 +139,10 @@ def make_symbolic_graph(self) -> None:
136139

137140
k_nonzero = int(sum(self.order))
138141
ar_params = self.make_and_register_variable(
139-
"ar_params", shape=(k_nonzero,) if k_endog == 1 else (k_endog, k_nonzero)
142+
f"{self.name}_params", shape=(k_nonzero,) if k_endog == 1 else (k_endog, k_nonzero)
140143
)
141144
sigma_ar = self.make_and_register_variable(
142-
"sigma_ar", shape=() if k_endog == 1 else (k_endog,)
145+
f"{self.name}_sigma", shape=() if k_endog == 1 else (k_endog,)
143146
)
144147

145148
if k_endog == 1:

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from pytensor import tensor as pt
4+
from pytensor.tensor.slinalg import block_diag
45
from scipy import linalg
56

67
from pymc_extras.statespace.models.structural.core import Component
@@ -96,7 +97,6 @@ class CycleComponent(Component):
9697
9798
cycle_strength = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"])
9899
cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12)
99-
100100
sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1)
101101
102102
ss_mod.build_statespace_graph(data)
@@ -124,13 +124,15 @@ class CycleComponent(Component):
124124
with pm.Model(coords=ss_mod.coords) as model:
125125
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states), dims=ss_mod.param_dims["P0"])
126126
# Initial states: shape (3, 2) for 3 variables, 2 states each
127-
cycle_init = pm.Normal('business_cycle', dims=('business_cycle_endog', 'business_cycle_state'))
127+
cycle_init = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"])
128128
129129
# Dampening factor: scalar (shared across variables)
130-
dampening = pm.Uniform('business_cycle_dampening_factor', lower=0.8, upper=1.0)
130+
dampening = pm.Beta("business_cycle_dampening_factor", 2, 2)
131131
132132
# Innovation variances: shape (3,) for variable-specific variances
133-
sigma_cycle = pm.HalfNormal('sigma_business_cycle', dims=('business_cycle_endog',))
133+
sigma_cycle = pm.HalfNormal(
134+
"sigma_business_cycle", dims=ss_mod.param_dims["sigma_business_cycle"]
135+
)
134136
135137
ss_mod.build_statespace_graph(data)
136138
idata = pm.sample()
@@ -220,12 +222,8 @@ def make_symbolic_graph(self) -> None:
220222
if self.k_endog == 1:
221223
self.ssm["transition", :, :] = T
222224
else:
223-
# can't make the linalg.block_diag logic work here
224-
# doing it manually for now
225-
for i in range(self.k_endog):
226-
start_idx = i * 2
227-
end_idx = (i + 1) * 2
228-
self.ssm["transition", start_idx:end_idx, start_idx:end_idx] = T
225+
transition = block_diag(*[T for _ in range(self.k_endog)])
226+
self.ssm["transition"] = pt.specify_shape(transition, (self.k_states, self.k_states))
229227

230228
if self.innovations:
231229
if self.k_endog == 1:
@@ -235,16 +233,20 @@ def make_symbolic_graph(self) -> None:
235233
sigma_cycle = self.make_and_register_variable(
236234
f"sigma_{self.name}", shape=(self.k_endog,)
237235
)
238-
# can't make the linalg.block_diag logic work here
239-
# doing it manually for now
240-
for i in range(self.k_endog):
241-
start_idx = i * 2
242-
end_idx = (i + 1) * 2
243-
Q_block = pt.eye(2) * sigma_cycle[i] ** 2
244-
self.ssm["state_cov", start_idx:end_idx, start_idx:end_idx] = Q_block
236+
state_cov = block_diag(
237+
*[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(self.k_endog)]
238+
)
239+
self.ssm["state_cov"] = pt.specify_shape(state_cov, (self.k_states, self.k_states))
245240

246241
def populate_component_properties(self):
247-
self.state_names = [f"{self.name}_{f}" for f in ["Cos", "Sin"]]
242+
if self.k_endog == 1:
243+
self.state_names = [f"{self.name}_{f}" for f in ["Cos", "Sin"]]
244+
else:
245+
# For multivariate cycles, create state names for each observed state
246+
self.state_names = []
247+
for var_name in self.observed_state_names:
248+
self.state_names.extend([f"{self.name}_{var_name}_{f}" for f in ["Cos", "Sin"]])
249+
248250
self.param_names = [f"{self.name}"]
249251

250252
if self.k_endog == 1:
@@ -260,7 +262,7 @@ def populate_component_properties(self):
260262
else:
261263
self.param_dims = {self.name: (f"{self.name}_endog", f"{self.name}_state")}
262264
self.coords = {
263-
f"{self.name}_state": self.state_names,
265+
f"{self.name}_state": [f"{self.name}_Cos", f"{self.name}_Sin"],
264266
f"{self.name}_endog": self.observed_state_names,
265267
}
266268
self.param_info = {

pymc_extras/statespace/models/structural/components/level_trend.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
2-
3-
from scipy import linalg
2+
import pytensor.tensor as pt
43

54
from pymc_extras.statespace.models.structural.core import Component
65
from pymc_extras.statespace.models.structural.utils import order_to_mask
@@ -13,7 +12,6 @@ class LevelTrendComponent(Component):
1312
1413
Parameters
1514
----------
16-
__________
1715
order : int
1816
1917
Number of time derivatives of the trend to include in the model. For example, when order=3, the trend will
@@ -114,7 +112,7 @@ def __init__(
114112
self,
115113
order: int | list[int] = 2,
116114
innovations_order: int | list[int] | None = None,
117-
name: str = "LevelTrend",
115+
name: str = "level_trend",
118116
observed_state_names: list[str] | None = None,
119117
):
120118
if innovations_order is None:
@@ -166,37 +164,44 @@ def populate_component_properties(self):
166164
k_posdef = self.k_posdef // k_endog
167165

168166
name_slice = POSITION_DERIVATIVE_NAMES[:k_states]
169-
self.param_names = [f"initial_{self.name}"]
167+
self.param_names = [f"{self.name}_initial"]
170168
base_names = [name for name, mask in zip(name_slice, self._order_mask) if mask]
171169
self.state_names = [
172170
f"{name}[{obs_name}]" for obs_name in self.observed_state_names for name in base_names
173171
]
174-
self.param_dims = {f"initial_{self.name}": (f"{self.name}_state",)}
172+
self.param_dims = {f"{self.name}_initial": (f"{self.name}_state",)}
175173
self.coords = {f"{self.name}_state": base_names}
176174

177175
if k_endog > 1:
178176
self.param_dims[f"{self.name}_state"] = (
179177
f"{self.name}_endog",
180178
f"{self.name}_state",
181179
)
182-
self.param_dims = {f"initial_{self.name}": (f"{self.name}_endog", f"{self.name}_state")}
180+
self.param_dims = {f"{self.name}_initial": (f"{self.name}_endog", f"{self.name}_state")}
183181
self.coords[f"{self.name}_endog"] = self.observed_state_names
184182

185183
shape = (k_endog, k_states) if k_endog > 1 else (k_states,)
186-
self.param_info = {f"initial_{self.name}": {"shape": shape, "constraints": None}}
184+
self.param_info = {f"{self.name}_initial": {"shape": shape, "constraints": None}}
187185

188186
if self.k_posdef > 0:
189-
self.param_names += [f"sigma_{self.name}"]
190-
self.shock_names = [
187+
self.param_names += [f"{self.name}_sigma"]
188+
189+
shock_base_names = [
191190
name for name, mask in zip(name_slice, self.innovations_order) if mask
192191
]
193-
self.param_dims[f"sigma_{self.name}"] = (
192+
self.shock_names = [
193+
f"{name}[{obs_name}]"
194+
for obs_name in self.observed_state_names
195+
for name in shock_base_names
196+
]
197+
198+
self.param_dims[f"{self.name}_sigma"] = (
194199
(f"{self.name}_shock",)
195200
if k_endog == 1
196201
else (f"{self.name}_endog", f"{self.name}_shock")
197202
)
198203
self.coords[f"{self.name}_shock"] = self.shock_names
199-
self.param_info[f"sigma_{self.name}"] = {
204+
self.param_info[f"{self.name}_sigma"] = {
200205
"shape": (k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
201206
"constraints": "Positive",
202207
}
@@ -210,28 +215,34 @@ def make_symbolic_graph(self) -> None:
210215
k_posdef = self.k_posdef // k_endog
211216

212217
initial_trend = self.make_and_register_variable(
213-
f"initial_{self.name}",
218+
f"{self.name}_initial",
214219
shape=(k_states,) if k_endog == 1 else (k_endog, k_states),
215220
)
216221
self.ssm["initial_state", :] = initial_trend.ravel()
217222

218-
triu_idx = np.triu_indices(k_states)
219-
T = np.zeros((k_states, k_states))
220-
T[triu_idx[0], triu_idx[1]] = 1
223+
triu_idx = pt.triu_indices(k_states)
224+
T = pt.zeros((k_states, k_states))[triu_idx[0], triu_idx[1]].set(1)
221225

222-
self.ssm["transition"] = linalg.block_diag(*[T for _ in range(k_endog)])
226+
self.ssm["transition", :, :] = pt.specify_shape(
227+
pt.linalg.block_diag(*[T for _ in range(k_endog)]), (self.k_states, self.k_states)
228+
)
223229

224230
R = np.eye(k_states)
225231
R = R[:, self.innovations_order]
226232

227-
self.ssm["selection", :, :] = linalg.block_diag(*[R for _ in range(k_endog)])
233+
self.ssm["selection", :, :] = pt.specify_shape(
234+
pt.linalg.block_diag(*[R for _ in range(k_endog)]), (self.k_states, self.k_posdef)
235+
)
228236

229237
Z = np.array([1.0] + [0.0] * (k_states - 1)).reshape((1, -1))
230-
self.ssm["design"] = linalg.block_diag(*[Z for _ in range(k_endog)])
238+
239+
self.ssm["design", :, :] = pt.specify_shape(
240+
pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
241+
)
231242

232243
if k_posdef > 0:
233244
sigma_trend = self.make_and_register_variable(
234-
f"sigma_{self.name}",
245+
f"{self.name}_sigma",
235246
shape=(k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
236247
)
237248
diag_idx = np.diag_indices(k_posdef * k_endog)

0 commit comments

Comments
 (0)