@@ -35,17 +35,20 @@ def infer_dict_mapping(state):
3535
3636def array_to_dict (ary , param_slice_shape ):
3737 return {
38- key : ary [:, slc ].reshape ((- 1 ,)+ shape )
38+ key : ary [:, slc ].reshape ((- 1 ,) + shape )
3939 for key , (slc , shape ) in param_slice_shape .items ()
4040 }
4141
4242
4343def array_to_list_of_dicts (ary , param_slice_shape ):
4444 # reshape adds a small amount of overhead; don't do it unless necessary
45- return [{
46- key : ary_i [slc ].reshape (shape ) if len (shape ) > 1 else ary_i [slc ]
47- for key , (slc , shape ) in param_slice_shape .items ()
48- } for ary_i in ary ]
45+ return [
46+ {
47+ key : ary_i [slc ].reshape (shape ) if len (shape ) > 1 else ary_i [slc ]
48+ for key , (slc , shape ) in param_slice_shape .items ()
49+ }
50+ for ary_i in ary
51+ ]
4952
5053
5154def collapse_and_hstack (values , nwalkers = None ):
@@ -199,15 +202,15 @@ def __init__(
199202 if isinstance (parameter_names , Sequence ):
200203 if len (parameter_names ) != ndim :
201204 raise ValueError (
202- f"`parameter_names` does not specify { ndim } names" )
205+ f"`parameter_names` does not specify { ndim } names"
206+ )
203207 parameter_names = dict (zip (parameter_names , range (ndim )))
204208
205209 indices = np .arange (ndim )
206210
207211 try :
208212 index_map = {
209- key : indices [slc ]
210- for key , slc in parameter_names .items ()
213+ key : indices [slc ] for key , slc in parameter_names .items ()
211214 }
212215 indexed = collapse_and_hstack (index_map .values ())
213216 except IndexError as err :
@@ -330,7 +333,8 @@ def sample(
330333 _state = {key : val [0 ] for key , val in initial_state .items ()}
331334 self .param_slice_shape = infer_dict_mapping (_state )
332335 initial_state = collapse_and_hstack (
333- initial_state .values (), self .nwalkers )
336+ initial_state .values (), self .nwalkers
337+ )
334338
335339 state = State (initial_state , copy = True )
336340 state_shape = np .shape (state .coords )
0 commit comments