@@ -111,6 +111,7 @@ def __new__(
111111 steps = None ,
112112 mode = None ,
113113 sequence_names = None ,
114+ append_x0 = True ,
114115 ** kwargs ,
115116 ):
116117 # Ignore dims in support shape because they are just passed along to the "observed" and "latent" distributions
@@ -138,12 +139,27 @@ def __new__(
138139 steps = steps ,
139140 mode = mode ,
140141 sequence_names = sequence_names ,
142+ append_x0 = append_x0 ,
141143 ** kwargs ,
142144 )
143145
144146 @classmethod
145147 def dist (
146- cls , a0 , P0 , c , d , T , Z , R , H , Q , steps = None , mode = None , sequence_names = None , ** kwargs
148+ cls ,
149+ a0 ,
150+ P0 ,
151+ c ,
152+ d ,
153+ T ,
154+ Z ,
155+ R ,
156+ H ,
157+ Q ,
158+ steps = None ,
159+ mode = None ,
160+ sequence_names = None ,
161+ append_x0 = True ,
162+ ** kwargs ,
147163 ):
148164 steps = get_support_shape_1d (
149165 support_shape = steps , shape = kwargs .get ("shape" , None ), support_shape_offset = 0
@@ -155,11 +171,31 @@ def dist(
155171 steps = pt .as_tensor_variable (intX (steps ), ndim = 0 )
156172
157173 return super ().dist (
158- [a0 , P0 , c , d , T , Z , R , H , Q , steps ], mode = mode , sequence_names = sequence_names , ** kwargs
174+ [a0 , P0 , c , d , T , Z , R , H , Q , steps ],
175+ mode = mode ,
176+ sequence_names = sequence_names ,
177+ append_x0 = append_x0 ,
178+ ** kwargs ,
159179 )
160180
161181 @classmethod
162- def rv_op (cls , a0 , P0 , c , d , T , Z , R , H , Q , steps , size = None , mode = None , sequence_names = None ):
182+ def rv_op (
183+ cls ,
184+ a0 ,
185+ P0 ,
186+ c ,
187+ d ,
188+ T ,
189+ Z ,
190+ R ,
191+ H ,
192+ Q ,
193+ steps ,
194+ size = None ,
195+ mode = None ,
196+ sequence_names = None ,
197+ append_x0 = True ,
198+ ):
163199 if sequence_names is None :
164200 sequence_names = []
165201
@@ -239,8 +275,12 @@ def step_fn(*args):
239275 strict = True ,
240276 )
241277
242- statespace_ = pt .concatenate ([init_dist_ [None ], statespace ], axis = 0 )
243- statespace_ = pt .specify_shape (statespace_ , (steps + 1 , None ))
278+ if append_x0 :
279+ statespace_ = pt .concatenate ([init_dist_ [None ], statespace ], axis = 0 )
280+ statespace_ = pt .specify_shape (statespace_ , (steps + 1 , None ))
281+ else :
282+ statespace_ = statespace
283+ statespace_ = pt .specify_shape (statespace_ , (steps , None ))
244284
245285 (ss_rng ,) = tuple (updates .values ())
246286 linear_gaussian_ss_op = LinearGaussianStateSpaceRV (
@@ -276,6 +316,7 @@ def __new__(
276316 k_endog = None ,
277317 sequence_names = None ,
278318 mode = None ,
319+ append_x0 = True ,
279320 ** kwargs ,
280321 ):
281322 dims = kwargs .pop ("dims" , None )
@@ -304,9 +345,10 @@ def __new__(
304345 steps = steps ,
305346 mode = mode ,
306347 sequence_names = sequence_names ,
348+ append_x0 = append_x0 ,
307349 ** kwargs ,
308350 )
309- latent_obs_combined = pt .specify_shape (latent_obs_combined , (steps + 1 , None ))
351+ latent_obs_combined = pt .specify_shape (latent_obs_combined , (steps + int ( append_x0 ) , None ))
310352 if k_endog is None :
311353 k_endog = cls ._get_k_endog (H )
312354 latent_slice = slice (None , - k_endog )