Skip to content

Commit 997e86b

Browse files
ran pre-commit on all files
1 parent b5aa9bf commit 997e86b

File tree

11 files changed

+55
-58
lines changed

11 files changed

+55
-58
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ ci:
33

44
repos:
55
- repo: https://github.com/pre-commit/pre-commit-hooks
6-
rev: v4.6.0
6+
rev: v6.0.0
77
hooks:
88
- id: check-merge-conflict
99
- id: check-toml
@@ -15,7 +15,7 @@ repos:
1515
- id: trailing-whitespace
1616

1717
- repo: https://github.com/astral-sh/ruff-pre-commit
18-
rev: v0.5.5
18+
rev: v0.12.8
1919
hooks:
2020
- id: ruff
2121
args: [ --fix, --unsafe-fixes, --exit-non-zero-on-fix ]

notebooks/Making a Custom Statespace Model.ipynb

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
}
1616
],
1717
"source": [
18-
"import numpy as np\n",
19-
"import matplotlib.pyplot as plt\n",
2018
"import arviz as az\n",
21-
"\n",
22-
"from pymc_extras.statespace.core.statespace import PyMCStateSpace\n",
19+
"import matplotlib.pyplot as plt\n",
20+
"import numpy as np\n",
21+
"import pymc as pm\n",
2322
"import pytensor.tensor as pt\n",
24-
"import pymc as pm"
23+
"\n",
24+
"from pymc_extras.statespace.core.statespace import PyMCStateSpace"
2525
]
2626
},
2727
{
@@ -45,7 +45,7 @@
4545
"\n",
4646
"\n",
4747
"def print_model_ssm(mod, how=\"eval\"):\n",
48-
" nice_heading = f'{\"name\":<20}{\"__repr__\":<50}{\"shape\":<10}{\"value\":<20}'\n",
48+
" nice_heading = f\"{'name':<20}{'__repr__':<50}{'shape':<10}{'value':<20}\"\n",
4949
" print(nice_heading)\n",
5050
" print(\"=\" * len(nice_heading))\n",
5151
" if how == \"eval\":\n",
@@ -1270,7 +1270,7 @@
12701270
],
12711271
"source": [
12721272
"az.plot_posterior(\n",
1273-
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=true_ar.tolist() + [true_sigma_x]\n",
1273+
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=[*true_ar.tolist(), true_sigma_x]\n",
12741274
");"
12751275
]
12761276
},
@@ -1333,13 +1333,12 @@
13331333
"metadata": {},
13341334
"outputs": [],
13351335
"source": [
1336+
"from pymc_extras.statespace.models.utilities import make_default_coords\n",
13361337
"from pymc_extras.statespace.utils.constants import (\n",
1337-
" ALL_STATE_DIM,\n",
13381338
" ALL_STATE_AUX_DIM,\n",
1339-
" OBS_STATE_DIM,\n",
1339+
" ALL_STATE_DIM,\n",
13401340
" SHOCK_DIM,\n",
13411341
")\n",
1342-
"from pymc_extras.statespace.models.utilities import make_default_coords\n",
13431342
"\n",
13441343
"\n",
13451344
"class AutoRegressiveThree(PyMCStateSpace):\n",

notebooks/discrete_markov_chain.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
"outputs": [],
2828
"source": [
2929
"import arviz as az\n",
30+
"import matplotlib.pyplot as plt\n",
3031
"import numpy as np\n",
32+
"import pandas as pd\n",
3133
"import pymc as pm\n",
3234
"import pytensor\n",
3335
"import pytensor.tensor as pt\n",
34-
"import pandas as pd\n",
3536
"import statsmodels.api as sm\n",
36-
"import matplotlib.pyplot as plt\n",
3737
"\n",
3838
"from matplotlib import ticker as mtick\n",
3939
"\n",
@@ -525,7 +525,7 @@
525525
" \"dates\": dta_hamilton.index,\n",
526526
" \"obs_dates\": dta_hamilton.index[order:],\n",
527527
" \"states\": [\"State_1\", \"State_2\"],\n",
528-
" \"ar_params\": [f\"L{i+1}.phi\" for i in range(order)],\n",
528+
" \"ar_params\": [f\"L{i + 1}.phi\" for i in range(order)],\n",
529529
"}\n",
530530
"\n",
531531
"with pm.Model(coords=coords) as hmm:\n",

pymc_extras/printing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def model_table(
166166

167167
for var in group:
168168
var_name = var.name
169-
sep = f'[b]{" ~" if (var in model.basic_RVs) else " ="}[/b]'
169+
sep = f"[b]{' ~' if (var in model.basic_RVs) else ' ='}[/b]"
170170
var_expr = variable_expression(model, var, truncate_deterministic)
171171
dims_expr = dims_expression(model, var)
172172
if dims_expr == "[]":

pymc_extras/statespace/core/statespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
def _validate_filter_arg(filter_arg):
6161
if filter_arg.lower() not in FILTER_OUTPUT_TYPES:
6262
raise ValueError(
63-
f'filter_output should be one of {", ".join(FILTER_OUTPUT_TYPES)}, received {filter_arg}'
63+
f"filter_output should be one of {', '.join(FILTER_OUTPUT_TYPES)}, received {filter_arg}"
6464
)
6565

6666

pymc_extras/statespace/models/SARIMAX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def __init__(
224224
if state_structure not in SARIMAX_STATE_STRUCTURES:
225225
raise ValueError(
226226
f"Got invalid argument {state_structure} for state structure, expected one of "
227-
f'{", ".join(SARIMAX_STATE_STRUCTURES)}'
227+
f"{', '.join(SARIMAX_STATE_STRUCTURES)}"
228228
)
229229

230230
if state_structure == "interpretable" and (self.d + self.D) > 0:

pymc_extras/statespace/utils/data_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=Fal
5353
if len(missing_cols) > 0:
5454
raise ValueError(
5555
"Columns of DataFrame provided as data do not match state names. The following states were"
56-
f'not found: {", ".join(missing_cols)}. This may result in unexpected results in complex'
56+
f"not found: {', '.join(missing_cols)}. This may result in unexpected results in complex"
5757
f"statespace models"
5858
)
5959

tests/statespace/core/test_statespace.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -884,8 +884,7 @@ def test_invalid_scenarios():
884884
# Giving a list, tuple, or Series when a matrix of data is expected should always raise
885885
with pytest.raises(
886886
ValueError,
887-
match="Scenario data for variable 'a' has the wrong number of columns. "
888-
"Expected 2, got 1",
887+
match="Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1",
889888
):
890889
for data_type in [list, tuple, pd.Series]:
891890
ss_mod._validate_scenario_data(data_type(np.zeros(10)))
@@ -894,15 +893,14 @@ def test_invalid_scenarios():
894893
# Providing irrevelant data raises
895894
with pytest.raises(
896895
ValueError,
897-
match="Scenario data provided for variable 'jk lol', which is not an exogenous " "variable",
896+
match="Scenario data provided for variable 'jk lol', which is not an exogenous variable",
898897
):
899898
ss_mod._validate_scenario_data({"jk lol": np.zeros(10)})
900899

901900
# Incorrect 2nd dimension of a non-dataframe
902901
with pytest.raises(
903902
ValueError,
904-
match="Scenario data for variable 'a' has the wrong number of columns. Expected "
905-
"2, got 1",
903+
match="Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1",
906904
):
907905
scenario = np.zeros(10).tolist()
908906
ss_mod._validate_scenario_data(scenario)
@@ -1243,9 +1241,9 @@ def test_param_dims_coords(ss_mod_multi_component):
12431241
assert dims is None
12441242
continue
12451243
for i, s in zip(shape, dims):
1246-
assert i == len(
1247-
ss_mod_multi_component.coords[s]
1248-
), f"Mismatch between shape {i} and dimension {s}"
1244+
assert i == len(ss_mod_multi_component.coords[s]), (
1245+
f"Mismatch between shape {i} and dimension {s}"
1246+
)
12491247

12501248

12511249
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")

tests/statespace/filters/test_kalman_filter.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ def test_output_shapes_one_state_one_observed(filter_func, rng):
7373

7474
for output_idx, name in enumerate(output_names):
7575
expected_output = get_expected_shape(name, p, m, r, n)
76-
assert (
77-
outputs[output_idx].shape == expected_output
78-
), f"Shape of {name} does not match expected"
76+
assert outputs[output_idx].shape == expected_output, (
77+
f"Shape of {name} does not match expected"
78+
)
7979

8080

8181
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -86,9 +86,9 @@ def test_output_shapes_when_all_states_are_stochastic(filter_func, rng):
8686
outputs = filter_func(*inputs)
8787
for output_idx, name in enumerate(output_names):
8888
expected_output = get_expected_shape(name, p, m, r, n)
89-
assert (
90-
outputs[output_idx].shape == expected_output
91-
), f"Shape of {name} does not match expected"
89+
assert outputs[output_idx].shape == expected_output, (
90+
f"Shape of {name} does not match expected"
91+
)
9292

9393

9494
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -99,9 +99,9 @@ def test_output_shapes_when_some_states_are_deterministic(filter_func, rng):
9999
outputs = filter_func(*inputs)
100100
for output_idx, name in enumerate(output_names):
101101
expected_output = get_expected_shape(name, p, m, r, n)
102-
assert (
103-
outputs[output_idx].shape == expected_output
104-
), f"Shape of {name} does not match expected"
102+
assert outputs[output_idx].shape == expected_output, (
103+
f"Shape of {name} does not match expected"
104+
)
105105

106106

107107
@pytest.fixture
@@ -161,9 +161,9 @@ def test_output_shapes_with_time_varying_matrices(f_standard_nd, rng):
161161

162162
for output_idx, name in enumerate(output_names):
163163
expected_output = get_expected_shape(name, p, m, r, n)
164-
assert (
165-
outputs[output_idx].shape == expected_output
166-
), f"Shape of {name} does not match expected"
164+
assert outputs[output_idx].shape == expected_output, (
165+
f"Shape of {name} does not match expected"
166+
)
167167

168168

169169
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -175,9 +175,9 @@ def test_output_with_deterministic_observation_equation(filter_func, rng):
175175

176176
for output_idx, name in enumerate(output_names):
177177
expected_output = get_expected_shape(name, p, m, r, n)
178-
assert (
179-
outputs[output_idx].shape == expected_output
180-
), f"Shape of {name} does not match expected"
178+
assert outputs[output_idx].shape == expected_output, (
179+
f"Shape of {name} does not match expected"
180+
)
181181

182182

183183
@pytest.mark.parametrize(
@@ -190,9 +190,9 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng):
190190
outputs = filter_func(*inputs)
191191
for output_idx, name in enumerate(output_names):
192192
expected_output = get_expected_shape(name, p, m, r, n)
193-
assert (
194-
outputs[output_idx].shape == expected_output
195-
), f"Shape of {name} does not match expected"
193+
assert outputs[output_idx].shape == expected_output, (
194+
f"Shape of {name} does not match expected"
195+
)
196196

197197

198198
@pytest.mark.parametrize(
@@ -206,9 +206,9 @@ def test_missing_data(filter_func, filter_name, p, rng):
206206
outputs = filter_func(*inputs)
207207
for output_idx, name in enumerate(output_names):
208208
expected_output = get_expected_shape(name, p, m, r, n)
209-
assert (
210-
outputs[output_idx].shape == expected_output
211-
), f"Shape of {name} does not match expected"
209+
assert outputs[output_idx].shape == expected_output, (
210+
f"Shape of {name} does not match expected"
211+
)
212212

213213

214214
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,16 @@ def _assert_coord_shapes_match_matrices(mod, params):
6767
n_shocks = max(1, len(mod.coords[SHOCK_DIM]))
6868
n_obs = len(mod.coords[OBS_STATE_DIM])
6969

70-
assert x0.shape[-1:] == (
71-
n_states,
72-
), f"x0 expected to have shape (n_states, ), found {x0.shape[-1:]}"
70+
assert x0.shape[-1:] == (n_states,), (
71+
f"x0 expected to have shape (n_states, ), found {x0.shape[-1:]}"
72+
)
7373
assert P0.shape[-2:] == (
7474
n_states,
7575
n_states,
7676
), f"P0 expected to have shape (n_states, n_states), found {P0.shape[-2:]}"
77-
assert c.shape[-1:] == (
78-
n_states,
79-
), f"c expected to have shape (n_states, ), found {c.shape[-1:]}"
77+
assert c.shape[-1:] == (n_states,), (
78+
f"c expected to have shape (n_states, ), found {c.shape[-1:]}"
79+
)
8080
assert d.shape[-1:] == (n_obs,), f"d expected to have shape (n_obs, ), found {d.shape[-1:]}"
8181
assert T.shape[-2:] == (
8282
n_states,
@@ -107,9 +107,9 @@ def _assert_keys_match(test_dict, expected_dict):
107107
assert len(key_diff) == 0, f"{', '.join(key_diff)} were not found in the test_dict keys."
108108

109109
key_diff = set(param_keys) - set(expected_keys)
110-
assert (
111-
len(key_diff) == 0
112-
), f"{', '.join(key_diff)} were keys of the tests_dict not in expected_dict."
110+
assert len(key_diff) == 0, (
111+
f"{', '.join(key_diff)} were keys of the tests_dict not in expected_dict."
112+
)
113113

114114

115115
def _assert_param_dims_correct(param_dims, expected_dims):

0 commit comments

Comments
 (0)