1
1
import numpy as np
2
-
3
- from scipy import linalg
2
+ import pytensor .tensor as pt
4
3
5
4
from pymc_extras .statespace .models .structural .core import Component
6
5
from pymc_extras .statespace .models .structural .utils import order_to_mask
@@ -114,7 +113,7 @@ def __init__(
114
113
self ,
115
114
order : int | list [int ] = 2 ,
116
115
innovations_order : int | list [int ] | None = None ,
117
- name : str = "LevelTrend " ,
116
+ name : str = "level_trend " ,
118
117
observed_state_names : list [str ] | None = None ,
119
118
):
120
119
if innovations_order is None :
@@ -166,35 +165,46 @@ def populate_component_properties(self):
166
165
k_posdef = self .k_posdef // k_endog
167
166
168
167
name_slice = POSITION_DERIVATIVE_NAMES [:k_states ]
169
- self .param_names = ["initial_trend " ]
168
+ self .param_names = [f" { self . name } _initial " ]
170
169
base_names = [name for name , mask in zip (name_slice , self ._order_mask ) if mask ]
171
170
self .state_names = [
172
- f"{ name } [{ obs_name } ]" for obs_name in self .observed_state_names for name in base_names
171
+ f"{ name } [{ obs_name } ]" if k_endog > 1 else name
172
+ for obs_name in self .observed_state_names
173
+ for name in base_names
173
174
]
174
- self .param_dims = {"initial_trend " : ("trend_state " ,)}
175
- self .coords = {"trend_state " : base_names }
175
+ self .param_dims = {f" { self . name } _initial " : (f" { self . name } _state " ,)}
176
+ self .coords = {f" { self . name } _state " : base_names }
176
177
177
178
if k_endog > 1 :
178
- self .param_dims ["trend_state " ] = (
179
- "trend_endog " ,
180
- "trend_state " ,
179
+ self .param_dims [f" { self . name } _state " ] = (
180
+ f" { self . name } _endog " ,
181
+ f" { self . name } _state " ,
181
182
)
182
- self .param_dims = {"initial_trend " : ("trend_endog " , "trend_state " )}
183
- self .coords ["trend_endog " ] = self .observed_state_names
183
+ self .param_dims = {f" { self . name } _initial " : (f" { self . name } _endog " , f" { self . name } _state " )}
184
+ self .coords [f" { self . name } _endog " ] = self .observed_state_names
184
185
185
186
shape = (k_endog , k_states ) if k_endog > 1 else (k_states ,)
186
- self .param_info = {"initial_trend " : {"shape" : shape , "constraints" : None }}
187
+ self .param_info = {f" { self . name } _initial " : {"shape" : shape , "constraints" : None }}
187
188
188
189
if self .k_posdef > 0 :
189
- self .param_names += ["sigma_trend" ]
190
- self .shock_names = [
190
+ self .param_names += [f"{ self .name } _sigma" ]
191
+
192
+ shock_base_names = [
191
193
name for name , mask in zip (name_slice , self .innovations_order ) if mask
192
194
]
193
- self .param_dims ["sigma_trend" ] = (
194
- ("trend_shock" ,) if k_endog == 1 else ("trend_endog" , "trend_shock" )
195
+ self .shock_names = [
196
+ f"{ name } [{ obs_name } ]" if k_endog > 1 else name
197
+ for obs_name in self .observed_state_names
198
+ for name in shock_base_names
199
+ ]
200
+
201
+ self .param_dims [f"{ self .name } _sigma" ] = (
202
+ (f"{ self .name } _shock" ,)
203
+ if k_endog == 1
204
+ else (f"{ self .name } _endog" , f"{ self .name } _shock" )
195
205
)
196
- self .coords ["trend_shock " ] = self .shock_names
197
- self .param_info ["sigma_trend " ] = {
206
+ self .coords [f" { self . name } _shock " ] = self .shock_names
207
+ self .param_info [f" { self . name } _sigma " ] = {
198
208
"shape" : (k_posdef ,) if k_endog == 1 else (k_endog , k_posdef ),
199
209
"constraints" : "Positive" ,
200
210
}
@@ -208,28 +218,34 @@ def make_symbolic_graph(self) -> None:
208
218
k_posdef = self .k_posdef // k_endog
209
219
210
220
initial_trend = self .make_and_register_variable (
211
- "initial_trend " ,
221
+ f" { self . name } _initial " ,
212
222
shape = (k_states ,) if k_endog == 1 else (k_endog , k_states ),
213
223
)
214
224
self .ssm ["initial_state" , :] = initial_trend .ravel ()
215
225
216
- triu_idx = np .triu_indices (k_states )
217
- T = np .zeros ((k_states , k_states ))
218
- T [triu_idx [0 ], triu_idx [1 ]] = 1
226
+ triu_idx = pt .triu_indices (k_states )
227
+ T = pt .zeros ((k_states , k_states ))[triu_idx [0 ], triu_idx [1 ]].set (1 )
219
228
220
- self .ssm ["transition" ] = linalg .block_diag (* [T for _ in range (k_endog )])
229
+ self .ssm ["transition" , :, :] = pt .specify_shape (
230
+ pt .linalg .block_diag (* [T for _ in range (k_endog )]), (self .k_states , self .k_states )
231
+ )
221
232
222
233
R = np .eye (k_states )
223
234
R = R [:, self .innovations_order ]
224
235
225
- self .ssm ["selection" , :, :] = linalg .block_diag (* [R for _ in range (k_endog )])
236
+ self .ssm ["selection" , :, :] = pt .specify_shape (
237
+ pt .linalg .block_diag (* [R for _ in range (k_endog )]), (self .k_states , self .k_posdef )
238
+ )
226
239
227
240
Z = np .array ([1.0 ] + [0.0 ] * (k_states - 1 )).reshape ((1 , - 1 ))
228
- self .ssm ["design" ] = linalg .block_diag (* [Z for _ in range (k_endog )])
241
+
242
+ self .ssm ["design" , :, :] = pt .specify_shape (
243
+ pt .linalg .block_diag (* [Z for _ in range (k_endog )]), (self .k_endog , self .k_states )
244
+ )
229
245
230
246
if k_posdef > 0 :
231
247
sigma_trend = self .make_and_register_variable (
232
- "sigma_trend " ,
248
+ f" { self . name } _sigma " ,
233
249
shape = (k_posdef ,) if k_endog == 1 else (k_endog , k_posdef ),
234
250
)
235
251
diag_idx = np .diag_indices (k_posdef * k_endog )
0 commit comments