Skip to content

Commit c0158ba

Browse files
authored
Startpoint sampling for a subset of parameters (#230)
Allow passing a list of parameter IDs to startpoint sampling for subsetting/reordering parameters.
1 parent e678dd7 commit c0158ba

File tree

4 files changed

+36
-6
lines changed

4 files changed

+36
-6
lines changed

petab/parameters.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,20 +424,32 @@ def append_overrides(overrides):
424424
def get_priors_from_df(
425425
parameter_df: pd.DataFrame,
426426
mode: Literal["initialization", "objective"],
427+
parameter_ids: Sequence[str] = None,
427428
) -> List[Tuple]:
428429
"""Create list with information about the parameter priors
429430
430431
Arguments:
431432
parameter_df: PEtab parameter table
432433
mode: ``'initialization'`` or ``'objective'``
434+
parameter_ids: A sequence of parameter IDs for which to sample starting points.
435+
For subsetting or reordering the parameters.
436+
Defaults to all estimated parameters.
433437
434438
Returns:
435439
List with prior information.
436440
"""
437-
438441
# get types and parameters of priors from dataframe
439442
par_to_estimate = parameter_df.loc[parameter_df[ESTIMATE] == 1]
440443

444+
if parameter_ids:
445+
try:
446+
par_to_estimate = par_to_estimate.loc[parameter_ids, :]
447+
except KeyError as e:
448+
missing_ids = set(parameter_ids) - set(par_to_estimate.index)
449+
raise KeyError(
450+
f"Parameter table does not contain estimated parameter(s) {missing_ids}."
451+
) from e
452+
441453
prior_list = []
442454
for _, row in par_to_estimate.iterrows():
443455
# retrieve info about type

petab/problem.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -930,13 +930,13 @@ def create_parameter_df(self, *args, **kwargs):
930930
**kwargs,
931931
)
932932

933-
def sample_parameter_startpoints(self, n_starts: int = 100):
933+
def sample_parameter_startpoints(self, n_starts: int = 100, **kwargs):
934934
"""Create 2D array with starting points for optimization
935935
936936
See :py:func:`petab.sample_parameter_startpoints`.
937937
"""
938938
return sampling.sample_parameter_startpoints(
939-
self.parameter_df, n_starts=n_starts
939+
self.parameter_df, n_starts=n_starts, **kwargs
940940
)
941941

942942
def sample_parameter_startpoints_dict(

petab/sampling.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Functions related to parameter sampling"""
22

3-
from typing import Tuple
3+
from typing import Sequence, Tuple
44

55
import numpy as np
66
import pandas as pd
@@ -110,24 +110,28 @@ def sample_parameter_startpoints(
110110
parameter_df: pd.DataFrame,
111111
n_starts: int = 100,
112112
seed: int = None,
113+
parameter_ids: Sequence[str] = None,
113114
) -> np.array:
114115
"""Create :class:`numpy.array` with starting points for an optimization
115116
116117
Arguments:
117118
parameter_df: PEtab parameter DataFrame
118119
n_starts: Number of points to be sampled
119120
seed: Random number generator seed (see :func:`numpy.random.seed`)
121+
parameter_ids: A sequence of parameter IDs for which to sample starting points.
122+
For subsetting or reordering the parameters.
123+
Defaults to all estimated parameters.
120124
121125
Returns:
122126
Array of sampled starting points with dimensions
123-
n_startpoints x n_optimization_parameters
127+
`n_startpoints` x `n_optimization_parameters`
124128
"""
125129
if seed is not None:
126130
np.random.seed(seed)
127131

128132
# get types and parameters of priors from dataframe
129133
prior_list = parameters.get_priors_from_df(
130-
parameter_df, mode=INITIALIZATION
134+
parameter_df, mode=INITIALIZATION, parameter_ids=parameter_ids
131135
)
132136

133137
startpoints = [sample_from_prior(prior, n_starts) for prior in prior_list]

tests/test_petab.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def test_get_priors_from_df():
179179
"""Check petab.get_priors_from_df."""
180180
parameter_df = pd.DataFrame(
181181
{
182+
PARAMETER_ID: ["p1", "p2", "p3", "p4", "p5"],
182183
PARAMETER_SCALE: [LOG10, LOG10, LOG10, LOG10, LOG10],
183184
LOWER_BOUND: [1e-8, 1e-9, 1e-10, 1e-11, 1e-5],
184185
UPPER_BOUND: [1e8, 1e9, 1e10, 1e11, 1e5],
@@ -193,6 +194,7 @@ def test_get_priors_from_df():
193194
],
194195
}
195196
)
197+
parameter_df = petab.get_parameter_df(parameter_df)
196198

197199
prior_list = petab.get_priors_from_df(parameter_df, mode=INITIALIZATION)
198200

@@ -225,6 +227,18 @@ def test_get_priors_from_df():
225227
assert prior_pars[1] == (-5, 5)
226228
assert prior_pars[2] == (1e-5, 1e5)
227229

230+
# check subsetting / reordering works
231+
prior_list_subset = petab.get_priors_from_df(
232+
parameter_df, mode=INITIALIZATION, parameter_ids=["p2", "p1"]
233+
)
234+
assert len(prior_list_subset) == 2
235+
assert prior_list_subset == [prior_list[1], prior_list[0]]
236+
237+
with pytest.raises(KeyError, match="Parameter table does not contain"):
238+
petab.get_priors_from_df(
239+
parameter_df, mode=INITIALIZATION, parameter_ids=["non_existent"]
240+
)
241+
228242

229243
def test_startpoint_sampling(fujita_model_scaling):
230244
n_starts = 10

0 commit comments

Comments
 (0)