14
14
ALL_STATE_AUX_DIM ,
15
15
ALL_STATE_DIM ,
16
16
AR_PARAM_DIM ,
17
- EXOGENOUS_DIM ,
17
+ EXOG_STATE_DIM ,
18
18
MA_PARAM_DIM ,
19
19
OBS_STATE_AUX_DIM ,
20
20
OBS_STATE_DIM ,
@@ -342,15 +342,15 @@ def data_info(self) -> dict[str, dict[str, Any]]:
342
342
if isinstance (self .exog_state_names , list ):
343
343
info = {
344
344
"exogenous_data" : {
345
- "dims" : (TIME_DIM , EXOGENOUS_DIM ),
345
+ "dims" : (TIME_DIM , EXOG_STATE_DIM ),
346
346
"shape" : (None , self .k_exog ),
347
347
}
348
348
}
349
349
350
350
elif isinstance (self .exog_state_names , dict ):
351
351
info = {
352
352
f"{ endog_state } _exogenous_data" : {
353
- "dims" : (TIME_DIM , f"{ EXOGENOUS_DIM } _{ endog_state } " ),
353
+ "dims" : (TIME_DIM , f"{ EXOG_STATE_DIM } _{ endog_state } " ),
354
354
"shape" : (None , len (exog_names )),
355
355
}
356
356
for endog_state , exog_names in self .exog_state_names .items ()
@@ -399,10 +399,10 @@ def coords(self) -> dict[str, Sequence]:
399
399
coords .update ({MA_PARAM_DIM : list (range (1 , self .q + 1 ))})
400
400
401
401
if isinstance (self .exog_state_names , list ):
402
- coords [EXOGENOUS_DIM ] = self .exog_state_names
402
+ coords [EXOG_STATE_DIM ] = self .exog_state_names
403
403
elif isinstance (self .exog_state_names , dict ):
404
404
for name , exog_names in self .exog_state_names .items ():
405
- coords [f"{ EXOGENOUS_DIM } _{ name } " ] = exog_names
405
+ coords [f"{ EXOG_STATE_DIM } _{ name } " ] = exog_names
406
406
407
407
return coords
408
408
@@ -428,12 +428,12 @@ def param_dims(self):
428
428
del coord_map ["x0" ]
429
429
430
430
if isinstance (self .exog_state_names , list ):
431
- coord_map ["beta_exog" ] = (OBS_STATE_DIM , EXOGENOUS_DIM )
431
+ coord_map ["beta_exog" ] = (OBS_STATE_DIM , EXOG_STATE_DIM )
432
432
elif isinstance (self .exog_state_names , dict ):
433
433
# If each state has its own exogenous variables, each parameter needs it own dim, since we expect the
434
434
# dim labels to all be different (otherwise we'd be in the list case).
435
435
for name in self .exog_state_names .keys ():
436
- coord_map [f"beta_{ name } " ] = (f"{ EXOGENOUS_DIM } _{ name } " ,)
436
+ coord_map [f"beta_{ name } " ] = (f"{ EXOG_STATE_DIM } _{ name } " ,)
437
437
438
438
return coord_map
439
439
0 commit comments