@@ -111,6 +111,7 @@ def __new__(
111
111
steps = None ,
112
112
mode = None ,
113
113
sequence_names = None ,
114
+ append_x0 = True ,
114
115
** kwargs ,
115
116
):
116
117
# Ignore dims in support shape because they are just passed along to the "observed" and "latent" distributions
@@ -138,12 +139,27 @@ def __new__(
138
139
steps = steps ,
139
140
mode = mode ,
140
141
sequence_names = sequence_names ,
142
+ append_x0 = append_x0 ,
141
143
** kwargs ,
142
144
)
143
145
144
146
@classmethod
145
147
def dist (
146
- cls , a0 , P0 , c , d , T , Z , R , H , Q , steps = None , mode = None , sequence_names = None , ** kwargs
148
+ cls ,
149
+ a0 ,
150
+ P0 ,
151
+ c ,
152
+ d ,
153
+ T ,
154
+ Z ,
155
+ R ,
156
+ H ,
157
+ Q ,
158
+ steps = None ,
159
+ mode = None ,
160
+ sequence_names = None ,
161
+ append_x0 = True ,
162
+ ** kwargs ,
147
163
):
148
164
steps = get_support_shape_1d (
149
165
support_shape = steps , shape = kwargs .get ("shape" , None ), support_shape_offset = 0
@@ -155,11 +171,31 @@ def dist(
155
171
steps = pt .as_tensor_variable (intX (steps ), ndim = 0 )
156
172
157
173
return super ().dist (
158
- [a0 , P0 , c , d , T , Z , R , H , Q , steps ], mode = mode , sequence_names = sequence_names , ** kwargs
174
+ [a0 , P0 , c , d , T , Z , R , H , Q , steps ],
175
+ mode = mode ,
176
+ sequence_names = sequence_names ,
177
+ append_x0 = append_x0 ,
178
+ ** kwargs ,
159
179
)
160
180
161
181
@classmethod
162
- def rv_op (cls , a0 , P0 , c , d , T , Z , R , H , Q , steps , size = None , mode = None , sequence_names = None ):
182
+ def rv_op (
183
+ cls ,
184
+ a0 ,
185
+ P0 ,
186
+ c ,
187
+ d ,
188
+ T ,
189
+ Z ,
190
+ R ,
191
+ H ,
192
+ Q ,
193
+ steps ,
194
+ size = None ,
195
+ mode = None ,
196
+ sequence_names = None ,
197
+ append_x0 = True ,
198
+ ):
163
199
if sequence_names is None :
164
200
sequence_names = []
165
201
@@ -239,8 +275,12 @@ def step_fn(*args):
239
275
strict = True ,
240
276
)
241
277
242
- statespace_ = pt .concatenate ([init_dist_ [None ], statespace ], axis = 0 )
243
- statespace_ = pt .specify_shape (statespace_ , (steps + 1 , None ))
278
+ if append_x0 :
279
+ statespace_ = pt .concatenate ([init_dist_ [None ], statespace ], axis = 0 )
280
+ statespace_ = pt .specify_shape (statespace_ , (steps + 1 , None ))
281
+ else :
282
+ statespace_ = statespace
283
+ statespace_ = pt .specify_shape (statespace_ , (steps , None ))
244
284
245
285
(ss_rng ,) = tuple (updates .values ())
246
286
linear_gaussian_ss_op = LinearGaussianStateSpaceRV (
@@ -276,6 +316,7 @@ def __new__(
276
316
k_endog = None ,
277
317
sequence_names = None ,
278
318
mode = None ,
319
+ append_x0 = True ,
279
320
** kwargs ,
280
321
):
281
322
dims = kwargs .pop ("dims" , None )
@@ -304,9 +345,10 @@ def __new__(
304
345
steps = steps ,
305
346
mode = mode ,
306
347
sequence_names = sequence_names ,
348
+ append_x0 = append_x0 ,
307
349
** kwargs ,
308
350
)
309
- latent_obs_combined = pt .specify_shape (latent_obs_combined , (steps + 1 , None ))
351
+ latent_obs_combined = pt .specify_shape (latent_obs_combined , (steps + int ( append_x0 ) , None ))
310
352
if k_endog is None :
311
353
k_endog = cls ._get_k_endog (H )
312
354
latent_slice = slice (None , - k_endog )
0 commit comments