Skip to content

Commit cf5c9d4

Browse files
committed
infer named parameter layout from dict-type initial_state
1 parent e7d15d8 commit cf5c9d4

File tree

1 file changed

+84
-69
lines changed

1 file changed

+84
-69
lines changed

src/emcee/ensemble.py

Lines changed: 84 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from itertools import count
5-
from typing import Dict, List, Optional, Union
5+
from typing import Dict, Iterable, List, Optional, Sequence, Union
66

77
import numpy as np
88

@@ -15,11 +15,42 @@
1515

1616
__all__ = ["EnsembleSampler", "walkers_independent"]
1717

18-
try:
19-
from collections.abc import Iterable
20-
except ImportError:
21-
# for py2.7, will be an Exception in 3.8
22-
from collections import Iterable
18+
ParameterNamesT = Union[
19+
Sequence[str], Dict[str, Union[slice, int, Sequence[int]]]
20+
]
21+
22+
23+
def infer_dict_mapping(state):
24+
i0 = 0
25+
param_slice_shape = {}
26+
for key, val in state.items():
27+
val = np.asarray(val)
28+
i1 = i0 + val.size
29+
slc = slice(i0, i1) if val.size > 1 else i0
30+
param_slice_shape[key] = slc, val.shape
31+
i0 = i1
32+
33+
return param_slice_shape
34+
35+
36+
def array_to_dict(ary, param_slice_shape):
37+
return {
38+
key: ary[:, slc].reshape((-1,)+shape)
39+
for key, (slc, shape) in param_slice_shape.items()
40+
}
41+
42+
43+
def array_to_list_of_dicts(ary, param_slice_shape):
44+
# reshape adds a small amount of overhead; don't do it unless necessary
45+
return [{
46+
key: ary_i[slc].reshape(shape) if len(shape) > 1 else ary_i[slc]
47+
for key, (slc, shape) in param_slice_shape.items()
48+
} for ary_i in ary]
49+
50+
51+
def collapse_and_hstack(values, nwalkers=None):
52+
shape = (nwalkers, -1) if nwalkers is not None else -1
53+
return np.hstack([np.asarray(val).reshape(shape) for val in values])
2354

2455

2556
class EnsembleSampler(object):
@@ -62,7 +93,8 @@ class EnsembleSampler(object):
6293
to accept a list of position vectors instead of just one. Note
6394
that ``pool`` will be ignored if this is ``True``.
6495
(default: ``False``)
65-
parameter_names (Optional[Union[List[str], Dict[str, List[int]]]]):
96+
parameter_names (Union[Sequence[str],
97+
Dict[str, Union[slice, int, Sequence[int]]]):
6698
names of individual parameters or groups of parameters. If
6799
specified, the ``log_prob_fn`` will recieve a dictionary of
68100
parameters, rather than a ``np.ndarray``.
@@ -81,7 +113,7 @@ def __init__(
81113
backend=None,
82114
vectorize=False,
83115
blobs_dtype=None,
84-
parameter_names: Optional[Union[Dict[str, int], List[str]]] = None,
116+
parameter_names: Optional[ParameterNamesT] = None,
85117
# Deprecated...
86118
a=None,
87119
postargs=None,
@@ -163,48 +195,39 @@ def __init__(
163195
# ``args`` and ``kwargs`` pickleable.
164196
self.log_prob_fn = _FunctionWrapper(log_prob_fn, args, kwargs)
165197

166-
# Save the parameter names
167-
self.params_are_named: bool = parameter_names is not None
168-
if self.params_are_named:
169-
assert isinstance(parameter_names, (list, dict))
170-
171-
# Don't support vectorizing yet
172-
msg = "named parameters with vectorization unsupported for now"
173-
assert not self.vectorize, msg
174-
175-
# Check for duplicate names
176-
dupes = set()
177-
uniq = []
178-
for name in parameter_names:
179-
if name not in dupes:
180-
uniq.append(name)
181-
dupes.add(name)
182-
msg = f"duplicate paramters: {dupes}"
183-
assert len(uniq) == len(parameter_names), msg
184-
185-
if isinstance(parameter_names, list):
186-
# Check for all named
187-
msg = "name all parameters or set `parameter_names` to `None`"
188-
assert len(parameter_names) == ndim, msg
189-
# Convert a list to a dict
190-
parameter_names: Dict[str, int] = {
191-
name: i for i, name in enumerate(parameter_names)
198+
if parameter_names is not None:
199+
if isinstance(parameter_names, Sequence):
200+
if len(parameter_names) != ndim:
201+
raise ValueError(
202+
f"`parameter_names` does not specify {ndim} names")
203+
parameter_names = dict(zip(parameter_names, range(ndim)))
204+
205+
indices = np.arange(ndim)
206+
207+
try:
208+
index_map = {
209+
key: indices[slc]
210+
for key, slc in parameter_names.items()
192211
}
212+
indexed = collapse_and_hstack(index_map.values())
213+
except IndexError as err:
214+
msg = "`parameter_names` specifies out-of-bounds element(s)"
215+
raise ValueError(msg) from err
193216

194-
# Check not too many names
195-
msg = "too many names"
196-
assert len(parameter_names) <= ndim, msg
197-
198-
# Check all indices appear
199-
values = [
200-
v if isinstance(v, list) else [v]
201-
for v in parameter_names.values()
202-
]
203-
values = [item for sublist in values for item in sublist]
204-
values = set(values)
205-
msg = f"not all values appear -- set should be 0 to {ndim-1}"
206-
assert values == set(np.arange(ndim)), msg
207-
self.parameter_names = parameter_names
217+
if len(indexed) != ndim:
218+
raise ValueError(
219+
"`parameter_names` does not specify indices for"
220+
f" {ndim} parameters"
221+
)
222+
if set(indexed) != set(indices):
223+
raise ValueError(
224+
"`parameter_names` does not specify indices"
225+
f" 0 through {ndim-1}"
226+
)
227+
228+
self.param_slice_shape = infer_dict_mapping(index_map)
229+
else:
230+
self.param_slice_shape = None
208231

209232
@property
210233
def random_state(self):
@@ -266,7 +289,8 @@ def sample(
266289
"""Advance the chain as a generator
267290
268291
Args:
269-
initial_state (State or ndarray[nwalkers, ndim]): The initial
292+
initial_state (State or ndarray[nwalkers, ndim] or
293+
dict[str, float | np.ndarray[nwalkers. ...]]): The initial
270294
:class:`State` or positions of the walkers in the
271295
parameter space.
272296
iterations (Optional[int or NoneType]): The number of steps to generate.
@@ -302,6 +326,12 @@ def sample(
302326
if iterations is None and store:
303327
raise ValueError("'store' must be False when 'iterations' is None")
304328
# Interpret the input as a walker state and check the dimensions.
329+
if isinstance(initial_state, dict):
330+
_state = {key: val[0] for key, val in initial_state.items()}
331+
self.param_slice_shape = infer_dict_mapping(_state)
332+
initial_state = collapse_and_hstack(
333+
initial_state.values(), self.nwalkers)
334+
305335
state = State(initial_state, copy=True)
306336
state_shape = np.shape(state.coords)
307337
if state_shape != (self.nwalkers, self.ndim):
@@ -472,8 +502,11 @@ def compute_log_prob(self, coords):
472502
raise ValueError("At least one parameter value was NaN")
473503

474504
# If the parmaeters are named, then switch to dictionaries
475-
if self.params_are_named:
476-
p = ndarray_to_list_of_dicts(p, self.parameter_names)
505+
if self.param_slice_shape:
506+
if self.vectorize:
507+
p = array_to_dict(p, self.param_slice_shape)
508+
else:
509+
p = array_to_list_of_dicts(p, self.param_slice_shape)
477510

478511
# Run the log-probability calculations (optionally in parallel).
479512
if self.vectorize:
@@ -664,21 +697,3 @@ def _scaled_cond(a):
664697
return np.inf
665698
c = b / bsum
666699
return np.linalg.cond(c.astype(float))
667-
668-
669-
def ndarray_to_list_of_dicts(
670-
x: np.ndarray, key_map: Dict[str, Union[int, List[int]]]
671-
) -> List[Dict[str, Union[np.number, np.ndarray]]]:
672-
"""
673-
A helper function to convert a ``np.ndarray`` into a list
674-
of dictionaries of parameters. Used when parameters are named.
675-
676-
Args:
677-
x (np.ndarray): parameter array of shape ``(N, n_dim)``, where
678-
``N`` is an integer
679-
key_map (Dict[str, Union[int, List[int]]):
680-
681-
Returns:
682-
list of dictionaries of parameters
683-
"""
684-
return [{key: xi[val] for key, val in key_map.items()} for xi in x]

0 commit comments

Comments
 (0)