Skip to content

Commit 1156799

Browse files
Broadcast and mapping of scale and unscale functions (#505)
* allow broadcasting in `map_scale` * fixup * fixup * add docstring
1 parent 7e80cf1 commit 1156799

File tree

2 files changed

+70
-6
lines changed

2 files changed

+70
-6
lines changed

petab/parameters.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def get_priors_from_df(parameter_df: pd.DataFrame,
352352

353353

354354
def scale(parameter: numbers.Number, scale_str: 'str') -> numbers.Number:
355-
"""Scale parameter according to scale_str
355+
"""Scale parameter according to `scale_str`.
356356
357357
Arguments:
358358
parameter:
@@ -375,7 +375,7 @@ def scale(parameter: numbers.Number, scale_str: 'str') -> numbers.Number:
375375

376376

377377
def unscale(parameter: numbers.Number, scale_str: 'str') -> numbers.Number:
378-
"""Unscale parameter according to scale_str
378+
"""Unscale parameter according to `scale_str`.
379379
380380
Arguments:
381381
parameter:
@@ -397,12 +397,49 @@ def unscale(parameter: numbers.Number, scale_str: 'str') -> numbers.Number:
397397
raise ValueError("Invalid parameter scaling: " + scale_str)
398398

399399

400-
def map_scale(parameters: Iterable[numbers.Number],
401-
scale_strs: Iterable[str]) -> Iterable[numbers.Number]:
402-
"""As scale(), but for Iterables"""
400+
def map_scale(
401+
parameters: Iterable[numbers.Number],
402+
scale_strs: Union[Iterable[str], str]
403+
) -> Iterable[numbers.Number]:
404+
"""Scale the parameters, i.e. as `scale()`, but for Iterables.
405+
406+
Arguments:
407+
parameters:
408+
Parameters to be scaled.
409+
scale_strs:
410+
Scales to apply. Broadcast if a single string.
411+
412+
Returns:
413+
parameters:
414+
The scaled parameters.
415+
"""
416+
if isinstance(scale_strs, str):
417+
scale_strs = [scale_strs] * len(parameters)
403418
return map(lambda x: scale(x[0], x[1]), zip(parameters, scale_strs))
404419

405420

421+
def map_unscale(
422+
parameters: Iterable[numbers.Number],
423+
scale_strs: Union[Iterable[str], str]
424+
) -> Iterable[numbers.Number]:
425+
"""Unscale the parameters, i.e. as `unscale()`, but for Iterables.
426+
427+
Arguments:
428+
parameters:
429+
Parameters to be unscaled.
430+
scale_strs:
431+
Scales that the parameters are currently on.
432+
Broadcast if a single string.
433+
434+
Returns:
435+
parameters:
436+
The unscaled parameters.
437+
"""
438+
if isinstance(scale_strs, str):
439+
scale_strs = [scale_strs] * len(parameters)
440+
return map(lambda x: unscale(x[0], x[1]), zip(parameters, scale_strs))
441+
442+
406443
def normalize_parameter_df(parameter_df: pd.DataFrame) -> pd.DataFrame:
407444
"""Add missing columns and fill in default values."""
408445
df = parameter_df.copy(deep=True)

tests/test_parameters.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Tests for petab/parameters.py"""
22
import tempfile
33
import pytest
4-
4+
import numpy as np
55
import pandas as pd
6+
67
import petab
78
from petab.C import *
89

@@ -170,3 +171,29 @@ def test_normalize_parameter_df():
170171

171172
# check is valid petab
172173
petab.check_parameter_df(actual)
174+
175+
176+
def test_scale_unscale():
177+
"""Test the parameter scaling functions."""
178+
par = 2.5
179+
# scale
180+
assert petab.scale(par, LIN) == par
181+
assert petab.scale(par, LOG) == np.log(par)
182+
assert petab.scale(par, LOG10) == np.log10(par)
183+
# unscale
184+
assert petab.unscale(par, LIN) == par
185+
assert petab.unscale(par, LOG) == np.exp(par)
186+
assert petab.unscale(par, LOG10) == 10**par
187+
188+
# map scale
189+
assert list(petab.map_scale([par]*3, [LIN, LOG, LOG10])) == \
190+
[par, np.log(par), np.log10(par)]
191+
# map unscale
192+
assert list(petab.map_unscale([par]*3, [LIN, LOG, LOG10])) == \
193+
[par, np.exp(par), 10**par]
194+
195+
# map broadcast
196+
assert list(petab.map_scale([par, 2*par], LOG)) == \
197+
list(np.log([par, 2*par]))
198+
assert list(petab.map_unscale([par, 2*par], LOG)) == \
199+
list(np.exp([par, 2*par]))

0 commit comments

Comments
 (0)