Skip to content

Commit 1a7d961

Browse files
Add append_x0 argument to LinearGuassianStateSpace
1 parent 3925964 commit 1a7d961

File tree

1 file changed

+48
-6
lines changed

1 file changed

+48
-6
lines changed

pymc_experimental/statespace/filters/distributions.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __new__(
111111
steps=None,
112112
mode=None,
113113
sequence_names=None,
114+
append_x0=True,
114115
**kwargs,
115116
):
116117
# Ignore dims in support shape because they are just passed along to the "observed" and "latent" distributions
@@ -138,12 +139,27 @@ def __new__(
138139
steps=steps,
139140
mode=mode,
140141
sequence_names=sequence_names,
142+
append_x0=append_x0,
141143
**kwargs,
142144
)
143145

144146
@classmethod
145147
def dist(
146-
cls, a0, P0, c, d, T, Z, R, H, Q, steps=None, mode=None, sequence_names=None, **kwargs
148+
cls,
149+
a0,
150+
P0,
151+
c,
152+
d,
153+
T,
154+
Z,
155+
R,
156+
H,
157+
Q,
158+
steps=None,
159+
mode=None,
160+
sequence_names=None,
161+
append_x0=True,
162+
**kwargs,
147163
):
148164
steps = get_support_shape_1d(
149165
support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=0
@@ -155,11 +171,31 @@ def dist(
155171
steps = pt.as_tensor_variable(intX(steps), ndim=0)
156172

157173
return super().dist(
158-
[a0, P0, c, d, T, Z, R, H, Q, steps], mode=mode, sequence_names=sequence_names, **kwargs
174+
[a0, P0, c, d, T, Z, R, H, Q, steps],
175+
mode=mode,
176+
sequence_names=sequence_names,
177+
append_x0=append_x0,
178+
**kwargs,
159179
)
160180

161181
@classmethod
162-
def rv_op(cls, a0, P0, c, d, T, Z, R, H, Q, steps, size=None, mode=None, sequence_names=None):
182+
def rv_op(
183+
cls,
184+
a0,
185+
P0,
186+
c,
187+
d,
188+
T,
189+
Z,
190+
R,
191+
H,
192+
Q,
193+
steps,
194+
size=None,
195+
mode=None,
196+
sequence_names=None,
197+
append_x0=True,
198+
):
163199
if sequence_names is None:
164200
sequence_names = []
165201

@@ -239,8 +275,12 @@ def step_fn(*args):
239275
strict=True,
240276
)
241277

242-
statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0)
243-
statespace_ = pt.specify_shape(statespace_, (steps + 1, None))
278+
if append_x0:
279+
statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0)
280+
statespace_ = pt.specify_shape(statespace_, (steps + 1, None))
281+
else:
282+
statespace_ = statespace
283+
statespace_ = pt.specify_shape(statespace_, (steps, None))
244284

245285
(ss_rng,) = tuple(updates.values())
246286
linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
@@ -276,6 +316,7 @@ def __new__(
276316
k_endog=None,
277317
sequence_names=None,
278318
mode=None,
319+
append_x0=True,
279320
**kwargs,
280321
):
281322
dims = kwargs.pop("dims", None)
@@ -304,9 +345,10 @@ def __new__(
304345
steps=steps,
305346
mode=mode,
306347
sequence_names=sequence_names,
348+
append_x0=append_x0,
307349
**kwargs,
308350
)
309-
latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + 1, None))
351+
latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None))
310352
if k_endog is None:
311353
k_endog = cls._get_k_endog(H)
312354
latent_slice = slice(None, -k_endog)

0 commit comments

Comments
 (0)