@@ -352,7 +352,7 @@ def get_priors_from_df(parameter_df: pd.DataFrame,
352352
353353
354354def scale (parameter : numbers .Number , scale_str : 'str' ) -> numbers .Number :
355- """Scale parameter according to scale_str
355+ """Scale parameter according to ` scale_str`.
356356
357357 Arguments:
358358 parameter:
@@ -375,7 +375,7 @@ def scale(parameter: numbers.Number, scale_str: 'str') -> numbers.Number:
375375
376376
377377def unscale (parameter : numbers .Number , scale_str : 'str' ) -> numbers .Number :
378- """Unscale parameter according to scale_str
378+ """Unscale parameter according to ` scale_str`.
379379
380380 Arguments:
381381 parameter:
@@ -397,12 +397,49 @@ def unscale(parameter: numbers.Number, scale_str: 'str') -> numbers.Number:
397397 raise ValueError ("Invalid parameter scaling: " + scale_str )
398398
399399
400- def map_scale (parameters : Iterable [numbers .Number ],
401- scale_strs : Iterable [str ]) -> Iterable [numbers .Number ]:
402- """As scale(), but for Iterables"""
400+ def map_scale (
401+ parameters : Iterable [numbers .Number ],
402+ scale_strs : Union [Iterable [str ], str ]
403+ ) -> Iterable [numbers .Number ]:
404+ """Scale the parameters, i.e. as `scale()`, but for Iterables.
405+
406+ Arguments:
407+ parameters:
408+ Parameters to be scaled.
409+ scale_strs:
410+ Scales to apply. Broadcast if a single string.
411+
412+ Returns:
413+ parameters:
414+ The scaled parameters.
415+ """
416+ if isinstance (scale_strs , str ):
417+ scale_strs = [scale_strs ] * len (parameters )
403418 return map (lambda x : scale (x [0 ], x [1 ]), zip (parameters , scale_strs ))
404419
405420
421+ def map_unscale (
422+ parameters : Iterable [numbers .Number ],
423+ scale_strs : Union [Iterable [str ], str ]
424+ ) -> Iterable [numbers .Number ]:
425+ """Unscale the parameters, i.e. as `unscale()`, but for Iterables.
426+
427+ Arguments:
428+ parameters:
429+ Parameters to be unscaled.
430+ scale_strs:
431+ Scales that the parameters are currently on.
432+ Broadcast if a single string.
433+
434+ Returns:
435+ parameters:
436+ The unscaled parameters.
437+ """
438+ if isinstance (scale_strs , str ):
439+ scale_strs = [scale_strs ] * len (parameters )
440+ return map (lambda x : unscale (x [0 ], x [1 ]), zip (parameters , scale_strs ))
441+
442+
406443def normalize_parameter_df (parameter_df : pd .DataFrame ) -> pd .DataFrame :
407444 """Add missing columns and fill in default values."""
408445 df = parameter_df .copy (deep = True )
0 commit comments