Skip to content

Commit 73eac98

Browse files
Bump aesara to 2.7.8. (#5995)
* Bump aesara to 2.7.8. * Bump pre-commit aesara. * Ignore mypy typing errors for pymc/util.py * Add safety check for `xarray.Dataset` key types Co-authored-by: Michael Osthege <[email protected]>
1 parent 1101818 commit 73eac98

9 files changed

+37
-12
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ repos:
2626
- types-filelock
2727
- types-setuptools
2828
- arviz
29-
- aesara==2.7.7
29+
- aesara==2.7.8
3030
- aeppl==0.0.32
3131
always_run: true
3232
require_serial: true

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ channels:
66
dependencies:
77
# Base dependencies
88
- aeppl=0.0.32
9-
- aesara=2.7.7
9+
- aesara=2.7.8
1010
- arviz>=0.12.0
1111
- blas
1212
- cachetools>=4.2.1

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ channels:
66
dependencies:
77
# Base dependencies
88
- aeppl=0.0.32
9-
- aesara=2.7.7
9+
- aesara=2.7.8
1010
- arviz>=0.12.0
1111
- blas
1212
- cachetools>=4.2.1

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ channels:
66
dependencies:
77
# Base dependencies (see install guide for Windows)
88
- aeppl=0.0.32
9-
- aesara=2.7.7
9+
- aesara=2.7.8
1010
- arviz>=0.12.0
1111
- blas
1212
- cachetools>=4.2.1

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ channels:
66
dependencies:
77
# Base dependencies (see install guide for Windows)
88
- aeppl=0.0.32
9-
- aesara=2.7.7
9+
- aesara=2.7.8
1010
- arviz>=0.12.0
1111
- blas
1212
- cachetools>=4.2.1

pymc/tests/test_util.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,20 @@
1414

1515
import numpy as np
1616
import pytest
17+
import xarray
1718

1819
from cachetools import cached
1920

2021
import pymc as pm
2122

2223
from pymc.distributions.transforms import RVTransform
23-
from pymc.util import UNSET, hash_key, hashable, locally_cachedmethod
24+
from pymc.util import (
25+
UNSET,
26+
dataset_to_point_list,
27+
hash_key,
28+
hashable,
29+
locally_cachedmethod,
30+
)
2431

2532

2633
class TestTransformName:
@@ -142,3 +149,18 @@ def fn(a=UNSET):
142149
help(fn)
143150
captured = capsys.readouterr()
144151
assert "a=UNSET" in captured.out
152+
153+
154+
def test_dataset_to_point_list():
155+
ds = xarray.Dataset()
156+
ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))
157+
pl = dataset_to_point_list(ds)
158+
assert isinstance(pl, list)
159+
assert len(pl) == 6
160+
assert isinstance(pl[0], dict)
161+
assert isinstance(pl[0]["A"], np.ndarray)
162+
163+
# Check that non-str keys are caught
164+
ds[3] = xarray.DataArray([1, 2, 3])
165+
with pytest.raises(ValueError, match="must be str"):
166+
dataset_to_point_list(ds)

pymc/util.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import functools
1616

17-
from typing import Dict, List, Tuple, Union, cast
17+
from typing import Dict, Hashable, List, Tuple, Union, cast
1818

1919
import arviz
2020
import cloudpickle
@@ -232,15 +232,18 @@ def enhanced(*args, **kwargs):
232232

233233

234234
def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
235+
# All keys of the dataset must be a str
236+
for vn in ds.keys():
237+
if not isinstance(vn, str):
238+
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
235239
# make dicts
236-
points: List[Dict[str, np.ndarray]] = []
237-
vn: str
240+
points: List[Dict[Hashable, np.ndarray]] = []
238241
da: "xarray.DataArray"
239242
for c in ds.chain:
240243
for d in ds.draw:
241244
points.append({vn: da.sel(chain=c, draw=d).values for vn, da in ds.items()})
242245
# use the list of points
243-
return points
246+
return cast(List[Dict[str, np.ndarray]], points)
244247

245248

246249
def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tuple[int, int]:

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# See that file for comments about the need/usage of each dependency.
33

44
aeppl==0.0.32
5-
aesara==2.7.7
5+
aesara==2.7.8
66
arviz>=0.12.0
77
cachetools>=4.2.1
88
cloudpickle

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
aeppl==0.0.32
2-
aesara==2.7.7
2+
aesara==2.7.8
33
arviz>=0.12.0
44
cachetools>=4.2.1
55
cloudpickle

0 commit comments

Comments
 (0)