Shape error in UnivariateFilter for time-varying matrices #612

@sebcroft

Hi!

I seem to be running into an error when using the univariate filter with time-varying state-space matrices. However, the same model set-up works fine when using the standard filter with time-varying matrices, and it also works for the univariate filter if all the matrices are time-invariant.

Below is a minimal-ish example that should hopefully reproduce it:

  • pymc : 5.26.1
  • pytensor : 2.35.1
  • pymc_extras : 0.5.0
import numpy as np  # needed for np.diag_indices and np.zeros below
import pytensor.tensor as pt
import pymc as pm
from pymc_extras.statespace.core import PyMCStateSpace

class StateSpace(PyMCStateSpace):
    def __init__(self, k_endog, k_states, k_posdef, n, filter_type):
        self.n = n
        super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, filter_type=filter_type)

    @property
    def param_names(self):
        return ['x0', 'P0', 'sigma_epses', 'sigma_etas']


    def make_symbolic_graph(self):
        x0 = self.make_and_register_variable('x0', shape=(self.k_states,))
        P0 = self.make_and_register_variable('P0', shape=(self.k_states,self.k_states))
        sigma_epses = self.make_and_register_variable('sigma_epses', shape=(self.k_endog,))
        sigma_etas = self.make_and_register_variable('sigma_etas', shape=(self.k_posdef,))        

        self.ssm['initial_state', :] = x0
        self.ssm['initial_state_cov', :, :] = P0
        self.ssm['obs_cov', *np.diag_indices(self.k_endog)] = sigma_epses**2
        self.ssm['state_cov', *np.diag_indices(self.k_posdef)] = sigma_etas**2

            
        self.ssm['transition', :, :] = pt.zeros((self.k_states, self.k_states))
        # The leading time dimension (n) should be enough for PyMCStateSpace to treat selection as time-varying
        self.ssm['selection'] = pt.zeros((self.n, self.k_states, self.k_posdef))
        self.ssm['design', :, :] = pt.zeros((self.k_endog, self.k_states))


def model_builder(ss):

    y = np.zeros((ss.n, ss.k_endog))
    
    with pm.Model() as model:
        x0 = pm.Normal('x0', shape=(ss.k_states,))
        P0_diag = pm.Exponential('P0_diag', 1, shape=(ss.k_states,))
        P0 = pm.Deterministic('P0', pt.diag(P0_diag))
        sigma_eps = pm.Exponential('sigma_epses', 1, shape=(ss.k_endog,))
        sigma_eta = pm.Exponential('sigma_etas', 1, shape=(ss.k_posdef,))
        ss.build_statespace_graph(data = y, mode='JAX')
    return model

So when running:

n = 15
p=2
m=3
r=4
standard = StateSpace(p, m, r, n, filter_type='standard')
model_st = model_builder(standard)

This works as expected, with no error.

However, running:

univariate = StateSpace(p, m, r, n, filter_type='univariate')
model_uv = model_builder(univariate)

should give the error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[17], line 2
      1 univariate = StateSpace(p, m, r, n, filter_type='univariate')
----> 2 model_uv = model_builder(univariate)

Cell In[11], line 11, in model_builder(ss)
      9     sigma_eps = pm.Exponential('sigma_epses', 1, shape=(ss.k_endog,))
     10     sigma_eta = pm.Exponential('sigma_etas', 1, shape=(ss.k_posdef,))
---> 11     ss.build_statespace_graph(data = y, mode='JAX')
     12 return model
...
...pymc-extras/~/anaconda3/envs/ss_issue/Lib/site-packages/pymc_extras/statespace/filters/kalman_filter.py#line=783), in UnivariateFilter.kalman_step(self, y, a, P, c, d, T, Z, R, H, Q)
    781 nan_mask = pt.isnan(y)
    783 W = pt.set_subtensor(pt.eye(y.shape[0])[nan_mask, nan_mask], 0.0)
--> 784 Z_masked = W.dot(Z)
    785 H_masked = W.dot(H)
    786 y_masked = pt.set_subtensor(y[nan_mask], 0.0)
... 
ValueError: Incompatible shared dimension for dot product: (2, 2), (3, 3)
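
For reference, the shapes in that message can be reproduced with plain NumPy. This is only a sketch of the masking lines shown in the traceback, not the library's code: with k_endog = 2 and k_states = 3, W is (2, 2) and the design matrix Z should be (2, 3), so W.dot(Z) is valid. A (3, 3) operand matches k_states x k_states, so my guess (an assumption, not something I have confirmed) is that a different matrix ends up in the Z slot when one of the matrices is time-varying. The name `wrong_operand` is purely illustrative.

```python
import numpy as np

k_endog, k_states = 2, 3

# One observation vector with a missing entry, as the filter step would see it
y = np.array([1.0, np.nan])
nan_mask = np.isnan(y)

# Same masking idea as in the traceback: zero the diagonal entries of the
# identity that correspond to missing observations
W = np.eye(k_endog)
W[nan_mask, nan_mask] = 0.0

# Expected case: Z is (k_endog, k_states), so the product is well defined
Z = np.zeros((k_endog, k_states))
Z_masked = W @ Z

# Hypothetical failure case: a (k_states, k_states) matrix in place of Z
wrong_operand = np.zeros((k_states, k_states))
try:
    W @ wrong_operand
    mismatch = None
except ValueError as exc:
    mismatch = str(exc)

print(Z_masked.shape)
print(mismatch)
```

If that guess is right, the fix would be about which matrices are passed to kalman_step in which order, rather than about the masking itself.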

Note: when all matrices are time-invariant, i.e., in this case, setting:

self.ssm['selection'] = pt.zeros((self.k_states, self.k_posdef))

the line model_uv = model_builder(univariate) works perfectly fine.

Is this a bug? Or maybe just a problem on my end... If it's a bug, I'd be more than happy to have a go at fixing it!

Thanks!
