Skip to content

Commit bc0b5c8

Browse files
committed
Adding measurement error support
1 parent 5026c48 commit bc0b5c8

File tree

1 file changed

+24
-5
lines changed
  • pymc_extras/statespace/models

1 file changed

+24
-5
lines changed

pymc_extras/statespace/models/DFM.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class BayesianDynamicFactor(PyMCStateSpace):
6060
The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
6161
and "cholesky". See the docs for kalman filters for more details.
6262
63+
measurement_error: bool, default True
64+
If true, a measurement error term is added to the model.
65+
6366
verbose: bool, default True
6467
If true, a message will be logged to the terminal explaining the variable names, dimensions, and supports.
6568
@@ -100,7 +103,7 @@ class BayesianDynamicFactor(PyMCStateSpace):
100103
factors. Careful prior specification is typically required for good estimation.
101104
102105
Currently, the implementation assumes same factor order for all the factors,
103-
does not yet support measurement error, exogenous variables and joint (VAR) error modeling.
106+
does not yet support exogenous variables and joint (VAR) error modeling.
104107
105108
Examples
106109
--------
@@ -161,6 +164,7 @@ def __init__(
161164
error_var: bool = False,
162165
error_cov_type: str = "diagonal",
163166
filter_type: str = "standard",
167+
measurement_error: bool = False,
164168
verbose: bool = True,
165169
):
166170
if k_endog is None and endog_names is None:
@@ -174,6 +178,7 @@ def __init__(
174178
raise NotImplementedError(
175179
"Joint error modeling (error_var=True) is not yet implemented."
176180
)
181+
177182
if exog is not None:
178183
raise NotImplementedError("Exogenous variables (exog) are not yet implemented.")
179184

@@ -185,7 +190,6 @@ def __init__(
185190
self.error_var = error_var
186191
self.error_cov_type = error_cov_type
187192
self.exog = exog
188-
# TODO add measurement error support
189193
# TODO add exogenous variables support?
190194

191195
# Determine the dimension for the latent factor states.
@@ -213,7 +217,7 @@ def __init__(
213217
k_posdef=k_posdef,
214218
filter_type=filter_type,
215219
verbose=verbose,
216-
measurement_error=False,
220+
measurement_error=measurement_error,
217221
)
218222

219223
@property
@@ -226,6 +230,7 @@ def param_names(self):
226230
"factor_sigma",
227231
"error_ar",
228232
"error_sigma",
233+
"sigma_obs",
229234
]
230235

231236
# Handle cases where parameters should be excluded based on model settings
@@ -236,6 +241,8 @@ def param_names(self):
236241
if self.error_cov_type == "unstructured":
237242
names.remove("error_sigma")
238243
names.append("error_cov")
244+
if not self.measurement_error:
245+
names.remove("sigma_obs")
239246

240247
return names
241248

@@ -274,6 +281,10 @@ def param_info(self) -> dict[str, dict[str, Any]]:
274281
"shape": (self.k_endog, self.k_endog),
275282
"constraints": "Positive Semi-definite",
276283
},
284+
"sigma_obs": {
285+
"shape": (self.k_endog,),
286+
"constraints": "Positive Semi-definite",
287+
},
277288
}
278289

279290
for name in self.param_names:
@@ -361,7 +372,8 @@ def param_dims(self):
361372
coord_map["error_sigma"] = (OBS_STATE_DIM,)
362373
if self.error_cov_type == "unstructured":
363374
coord_map["error_sigma"] = (OBS_STATE_DIM, OBS_STATE_AUX_DIM)
364-
375+
if self.measurement_error:
376+
coord_map["sigma_obs"] = (OBS_STATE_DIM,)
365377
return coord_map
366378

367379
def make_symbolic_graph(self):
@@ -456,4 +468,11 @@ def build_ar_block_matrix(ar_coeffs):
456468
)
457469

458470
# Observation covariance matrix
459-
self.ssm["obs_cov", :, :] = 0.0
471+
if self.measurement_error:
472+
sigma_obs = self.make_and_register_variable(
473+
"sigma_obs", shape=(self.k_endog,), dtype=floatX
474+
)
475+
self.ssm["obs_cov", :, :] = pt.diag(sigma_obs)
476+
else:
477+
# If measurement error is not used, set obs_cov to zero
478+
self.ssm["obs_cov", :, :] = 0.0

0 commit comments

Comments
 (0)