22
33import warnings
44from itertools import count
5- from typing import Dict , List , Optional , Union
5+ from typing import Dict , Iterable , List , Optional , Sequence , Union
66
77import numpy as np
88
1515
1616__all__ = ["EnsembleSampler" , "walkers_independent" ]
1717
18- try :
19- from collections .abc import Iterable
20- except ImportError :
21- # for py2.7, will be an Exception in 3.8
22- from collections import Iterable
18+ ParameterNamesT = Union [
19+ Sequence [str ], Dict [str , Union [slice , int , Sequence [int ]]]
20+ ]
21+
22+
23+ def infer_dict_mapping (state ):
24+ i0 = 0
25+ param_slice_shape = {}
26+ for key , val in state .items ():
27+ val = np .asarray (val )
28+ i1 = i0 + val .size
29+ slc = slice (i0 , i1 ) if val .size > 1 else i0
30+ param_slice_shape [key ] = slc , val .shape
31+ i0 = i1
32+
33+ return param_slice_shape
34+
35+
36+ def array_to_dict (ary , param_slice_shape ):
37+ return {
38+ key : ary [:, slc ].reshape ((- 1 ,)+ shape )
39+ for key , (slc , shape ) in param_slice_shape .items ()
40+ }
41+
42+
43+ def array_to_list_of_dicts (ary , param_slice_shape ):
44+ # reshape adds a small amount of overhead; don't do it unless necessary
45+ return [{
46+ key : ary_i [slc ].reshape (shape ) if len (shape ) > 1 else ary_i [slc ]
47+ for key , (slc , shape ) in param_slice_shape .items ()
48+ } for ary_i in ary ]
49+
50+
51+ def collapse_and_hstack (values , nwalkers = None ):
52+ shape = (nwalkers , - 1 ) if nwalkers is not None else - 1
53+ return np .hstack ([np .asarray (val ).reshape (shape ) for val in values ])
2354
2455
2556class EnsembleSampler (object ):
@@ -62,7 +93,8 @@ class EnsembleSampler(object):
6293 to accept a list of position vectors instead of just one. Note
6394 that ``pool`` will be ignored if this is ``True``.
6495 (default: ``False``)
65- parameter_names (Optional[Union[List[str], Dict[str, List[int]]]]):
96+ parameter_names (Union[Sequence[str],
97+ Dict[str, Union[slice, int, Sequence[int]]]):
6698 names of individual parameters or groups of parameters. If
6799 specified, the ``log_prob_fn`` will recieve a dictionary of
68100 parameters, rather than a ``np.ndarray``.
@@ -81,7 +113,7 @@ def __init__(
81113 backend = None ,
82114 vectorize = False ,
83115 blobs_dtype = None ,
84- parameter_names : Optional [Union [ Dict [ str , int ], List [ str ]] ] = None ,
116+ parameter_names : Optional [ParameterNamesT ] = None ,
85117 # Deprecated...
86118 a = None ,
87119 postargs = None ,
@@ -163,48 +195,39 @@ def __init__(
163195 # ``args`` and ``kwargs`` pickleable.
164196 self .log_prob_fn = _FunctionWrapper (log_prob_fn , args , kwargs )
165197
166- # Save the parameter names
167- self .params_are_named : bool = parameter_names is not None
168- if self .params_are_named :
169- assert isinstance (parameter_names , (list , dict ))
170-
171- # Don't support vectorizing yet
172- msg = "named parameters with vectorization unsupported for now"
173- assert not self .vectorize , msg
174-
175- # Check for duplicate names
176- dupes = set ()
177- uniq = []
178- for name in parameter_names :
179- if name not in dupes :
180- uniq .append (name )
181- dupes .add (name )
182- msg = f"duplicate paramters: { dupes } "
183- assert len (uniq ) == len (parameter_names ), msg
184-
185- if isinstance (parameter_names , list ):
186- # Check for all named
187- msg = "name all parameters or set `parameter_names` to `None`"
188- assert len (parameter_names ) == ndim , msg
189- # Convert a list to a dict
190- parameter_names : Dict [str , int ] = {
191- name : i for i , name in enumerate (parameter_names )
198+ if parameter_names is not None :
199+ if isinstance (parameter_names , Sequence ):
200+ if len (parameter_names ) != ndim :
201+ raise ValueError (
202+ f"`parameter_names` does not specify { ndim } names" )
203+ parameter_names = dict (zip (parameter_names , range (ndim )))
204+
205+ indices = np .arange (ndim )
206+
207+ try :
208+ index_map = {
209+ key : indices [slc ]
210+ for key , slc in parameter_names .items ()
192211 }
212+ indexed = collapse_and_hstack (index_map .values ())
213+ except IndexError as err :
214+ msg = "`parameter_names` specifies out-of-bounds element(s)"
215+ raise ValueError (msg ) from err
193216
194- # Check not too many names
195- msg = "too many names"
196- assert len ( parameter_names ) <= ndim , msg
197-
198- # Check all indices appear
199- values = [
200- v if isinstance ( v , list ) else [ v ]
201- for v in parameter_names . values ()
202- ]
203- values = [ item for sublist in values for item in sublist ]
204- values = set ( values )
205- msg = f"not all values appear -- set should be 0 to { ndim - 1 } "
206- assert values == set ( np . arange ( ndim )), msg
207- self .parameter_names = parameter_names
217+ if len ( indexed ) != ndim :
218+ raise ValueError (
219+ "` parameter_names` does not specify indices for"
220+ f" { ndim } parameters"
221+ )
222+ if set ( indexed ) != set ( indices ):
223+ raise ValueError (
224+ "` parameter_names` does not specify indices"
225+ f" 0 through { ndim - 1 } "
226+ )
227+
228+ self . param_slice_shape = infer_dict_mapping ( index_map )
229+ else :
230+ self .param_slice_shape = None
208231
209232 @property
210233 def random_state (self ):
@@ -266,7 +289,8 @@ def sample(
266289 """Advance the chain as a generator
267290
268291 Args:
269- initial_state (State or ndarray[nwalkers, ndim]): The initial
292+ initial_state (State or ndarray[nwalkers, ndim] or
293+ dict[str, float | np.ndarray[nwalkers. ...]]): The initial
270294 :class:`State` or positions of the walkers in the
271295 parameter space.
272296 iterations (Optional[int or NoneType]): The number of steps to generate.
@@ -302,6 +326,12 @@ def sample(
302326 if iterations is None and store :
303327 raise ValueError ("'store' must be False when 'iterations' is None" )
304328 # Interpret the input as a walker state and check the dimensions.
329+ if isinstance (initial_state , dict ):
330+ _state = {key : val [0 ] for key , val in initial_state .items ()}
331+ self .param_slice_shape = infer_dict_mapping (_state )
332+ initial_state = collapse_and_hstack (
333+ initial_state .values (), self .nwalkers )
334+
305335 state = State (initial_state , copy = True )
306336 state_shape = np .shape (state .coords )
307337 if state_shape != (self .nwalkers , self .ndim ):
@@ -472,8 +502,11 @@ def compute_log_prob(self, coords):
472502 raise ValueError ("At least one parameter value was NaN" )
473503
474504 # If the parmaeters are named, then switch to dictionaries
475- if self .params_are_named :
476- p = ndarray_to_list_of_dicts (p , self .parameter_names )
505+ if self .param_slice_shape :
506+ if self .vectorize :
507+ p = array_to_dict (p , self .param_slice_shape )
508+ else :
509+ p = array_to_list_of_dicts (p , self .param_slice_shape )
477510
478511 # Run the log-probability calculations (optionally in parallel).
479512 if self .vectorize :
@@ -664,21 +697,3 @@ def _scaled_cond(a):
664697 return np .inf
665698 c = b / bsum
666699 return np .linalg .cond (c .astype (float ))
667-
668-
669- def ndarray_to_list_of_dicts (
670- x : np .ndarray , key_map : Dict [str , Union [int , List [int ]]]
671- ) -> List [Dict [str , Union [np .number , np .ndarray ]]]:
672- """
673- A helper function to convert a ``np.ndarray`` into a list
674- of dictionaries of parameters. Used when parameters are named.
675-
676- Args:
677- x (np.ndarray): parameter array of shape ``(N, n_dim)``, where
678- ``N`` is an integer
679- key_map (Dict[str, Union[int, List[int]]):
680-
681- Returns:
682- list of dictionaries of parameters
683- """
684- return [{key : xi [val ] for key , val in key_map .items ()} for xi in x ]
0 commit comments