Skip to content

Commit a4bb6fb

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 794fbf8 commit a4bb6fb

File tree

3 files changed

+33
-33
lines changed

3 files changed

+33
-33
lines changed

tests/statespace/core/test_statespace.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,9 +1241,9 @@ def test_param_dims_coords(ss_mod_multi_component):
12411241
assert dims is None
12421242
continue
12431243
for i, s in zip(shape, dims):
1244-
assert i == len(ss_mod_multi_component.coords[s]), (
1245-
f"Mismatch between shape {i} and dimension {s}"
1246-
)
1244+
assert i == len(
1245+
ss_mod_multi_component.coords[s]
1246+
), f"Mismatch between shape {i} and dimension {s}"
12471247

12481248

12491249
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")

tests/statespace/filters/test_kalman_filter.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ def test_output_shapes_one_state_one_observed(filter_func, rng):
7373

7474
for output_idx, name in enumerate(output_names):
7575
expected_output = get_expected_shape(name, p, m, r, n)
76-
assert outputs[output_idx].shape == expected_output, (
77-
f"Shape of {name} does not match expected"
78-
)
76+
assert (
77+
outputs[output_idx].shape == expected_output
78+
), f"Shape of {name} does not match expected"
7979

8080

8181
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -86,9 +86,9 @@ def test_output_shapes_when_all_states_are_stochastic(filter_func, rng):
8686
outputs = filter_func(*inputs)
8787
for output_idx, name in enumerate(output_names):
8888
expected_output = get_expected_shape(name, p, m, r, n)
89-
assert outputs[output_idx].shape == expected_output, (
90-
f"Shape of {name} does not match expected"
91-
)
89+
assert (
90+
outputs[output_idx].shape == expected_output
91+
), f"Shape of {name} does not match expected"
9292

9393

9494
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -99,9 +99,9 @@ def test_output_shapes_when_some_states_are_deterministic(filter_func, rng):
9999
outputs = filter_func(*inputs)
100100
for output_idx, name in enumerate(output_names):
101101
expected_output = get_expected_shape(name, p, m, r, n)
102-
assert outputs[output_idx].shape == expected_output, (
103-
f"Shape of {name} does not match expected"
104-
)
102+
assert (
103+
outputs[output_idx].shape == expected_output
104+
), f"Shape of {name} does not match expected"
105105

106106

107107
@pytest.fixture
@@ -161,9 +161,9 @@ def test_output_shapes_with_time_varying_matrices(f_standard_nd, rng):
161161

162162
for output_idx, name in enumerate(output_names):
163163
expected_output = get_expected_shape(name, p, m, r, n)
164-
assert outputs[output_idx].shape == expected_output, (
165-
f"Shape of {name} does not match expected"
166-
)
164+
assert (
165+
outputs[output_idx].shape == expected_output
166+
), f"Shape of {name} does not match expected"
167167

168168

169169
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -175,9 +175,9 @@ def test_output_with_deterministic_observation_equation(filter_func, rng):
175175

176176
for output_idx, name in enumerate(output_names):
177177
expected_output = get_expected_shape(name, p, m, r, n)
178-
assert outputs[output_idx].shape == expected_output, (
179-
f"Shape of {name} does not match expected"
180-
)
178+
assert (
179+
outputs[output_idx].shape == expected_output
180+
), f"Shape of {name} does not match expected"
181181

182182

183183
@pytest.mark.parametrize(
@@ -190,9 +190,9 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng):
190190
outputs = filter_func(*inputs)
191191
for output_idx, name in enumerate(output_names):
192192
expected_output = get_expected_shape(name, p, m, r, n)
193-
assert outputs[output_idx].shape == expected_output, (
194-
f"Shape of {name} does not match expected"
195-
)
193+
assert (
194+
outputs[output_idx].shape == expected_output
195+
), f"Shape of {name} does not match expected"
196196

197197

198198
@pytest.mark.parametrize(
@@ -206,9 +206,9 @@ def test_missing_data(filter_func, filter_name, p, rng):
206206
outputs = filter_func(*inputs)
207207
for output_idx, name in enumerate(output_names):
208208
expected_output = get_expected_shape(name, p, m, r, n)
209-
assert outputs[output_idx].shape == expected_output, (
210-
f"Shape of {name} does not match expected"
211-
)
209+
assert (
210+
outputs[output_idx].shape == expected_output
211+
), f"Shape of {name} does not match expected"
212212

213213

214214
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,16 @@ def _assert_coord_shapes_match_matrices(mod, params):
6767
n_shocks = max(1, len(mod.coords[SHOCK_DIM]))
6868
n_obs = len(mod.coords[OBS_STATE_DIM])
6969

70-
assert x0.shape[-1:] == (n_states,), (
71-
f"x0 expected to have shape (n_states, ), found {x0.shape[-1:]}"
72-
)
70+
assert x0.shape[-1:] == (
71+
n_states,
72+
), f"x0 expected to have shape (n_states, ), found {x0.shape[-1:]}"
7373
assert P0.shape[-2:] == (
7474
n_states,
7575
n_states,
7676
), f"P0 expected to have shape (n_states, n_states), found {P0.shape[-2:]}"
77-
assert c.shape[-1:] == (n_states,), (
78-
f"c expected to have shape (n_states, ), found {c.shape[-1:]}"
79-
)
77+
assert c.shape[-1:] == (
78+
n_states,
79+
), f"c expected to have shape (n_states, ), found {c.shape[-1:]}"
8080
assert d.shape[-1:] == (n_obs,), f"d expected to have shape (n_obs, ), found {d.shape[-1:]}"
8181
assert T.shape[-2:] == (
8282
n_states,
@@ -107,9 +107,9 @@ def _assert_keys_match(test_dict, expected_dict):
107107
assert len(key_diff) == 0, f"{', '.join(key_diff)} were not found in the test_dict keys."
108108

109109
key_diff = set(param_keys) - set(expected_keys)
110-
assert len(key_diff) == 0, (
111-
f"{', '.join(key_diff)} were keys of the tests_dict not in expected_dict."
112-
)
110+
assert (
111+
len(key_diff) == 0
112+
), f"{', '.join(key_diff)} were keys of the tests_dict not in expected_dict."
113113

114114

115115
def _assert_param_dims_correct(param_dims, expected_dims):

0 commit comments

Comments
 (0)