Skip to content

Commit 513a9d0

Browse files
Remove tests referencing old code
1 parent 08995be commit 513a9d0

File tree

3 files changed

+5
-13
lines changed

3 files changed

+5
-13
lines changed

pymc_experimental/statespace/filters/kalman_filter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -730,14 +730,14 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
730730
loglike_obs,
731731
) = results
732732

733-
def square_sequnece(L):
733+
def square_sequnece(L, k):
734734
X = pt.einsum("...ij,...kj->...ik", L, L.copy())
735-
X = pt.specify_shape(X, (n, self.n_states, self.n_states))
735+
X = pt.specify_shape(X, (n, k, k))
736736
return X
737737

738-
filtered_covariances = square_sequnece(filtered_covariances_cholesky)
739-
predicted_covariances = square_sequnece(predicted_covariances_cholesky)
740-
observed_covariances = square_sequnece(observed_covariances_cholesky)
738+
filtered_covariances = square_sequnece(filtered_covariances_cholesky, k=self.n_states)
739+
predicted_covariances = square_sequnece(predicted_covariances_cholesky, k=self.n_states)
740+
observed_covariances = square_sequnece(observed_covariances_cholesky, k=self.n_endog)
741741

742742
return [
743743
filtered_states,

tests/statespace/test_distributions.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@
3838
"standard",
3939
"cholesky",
4040
"univariate",
41-
"single",
42-
"steady_state",
4341
]
4442

4543

tests/statespace/test_statespace.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,6 @@ def test_invalid_filter_name_raises():
234234
mod = make_statespace_mod(k_endog=1, k_states=5, k_posdef=1, filter_type="invalid_filter")
235235

236236

237-
def test_singleseriesfilter_raises_if_k_endog_gt_one():
238-
msg = 'Cannot use filter_type = "single" with multiple observed time series'
239-
with pytest.raises(ValueError, match=msg):
240-
mod = make_statespace_mod(k_endog=10, k_states=5, k_posdef=1, filter_type="single")
241-
242-
243237
def test_unpack_before_insert_raises(rng):
244238
p, m, r, n = 2, 5, 1, 10
245239
data, *inputs = make_test_inputs(p, m, r, n, rng, missing_data=0)

0 commit comments

Comments
 (0)