Skip to content

Commit 4327f29

Browse files
committed
Store model entity IDs and names in ReturnData
This makes `ReturnData` more usable without additionally passing around the corresponding `Model`, and avoids the previous hack for setting those IDs specifically for `RDataReporting::full`. Only `span`s are copied, so barely any additional memory is required.
1 parent 0d6f26a commit 4327f29

File tree

5 files changed

+126
-41
lines changed

5 files changed

+126
-41
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
109109
* The default directory for model import changed, and a base directory
110110
can now be specified via the `AMICI_MODELS_ROOT` environment variable.
111111
See `amici.get_model_dir` for details.
112+
* IDs and names of model entities are now not only accessible via `Model`, but
113+
also via `ReturnData`
114+
(`ReturnData.{free_parameter_ids,observable_ids,...}`,
115+
`ReturnData.{free_parameter_names,observable_names,...}`).
112116

113117
**Fixes**
114118

include/amici/rdata.h

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include "amici/vector.h"
1010

1111
#include <vector>
12+
#include <span>
13+
#include <string_view>
1214

1315
namespace amici {
1416
class ReturnData;
@@ -28,7 +30,7 @@ void serialize(Archive& ar, amici::ReturnData& r, unsigned int version);
2830
namespace amici {
2931

3032
/**
31-
* @brief Stores all data to be returned by amici::runAmiciSimulation.
33+
* @brief Stores all data to be returned by amici::run_simulation.
3234
*
3335
* NOTE: multi-dimensional arrays are stored in row-major order (C-style)
3436
*/
@@ -57,14 +59,38 @@ class ReturnData : public ModelDimensions {
5759
* @param sigma_res_ indicates whether additional residuals are to be added
5860
* for each sigma
5961
* @param sigma_offset_ offset to ensure real-valuedness of sigma residuals
62+
* @param free_parameter_ids_ IDs of the free parameters
63+
* @param free_parameter_names_ Names of the free parameters
64+
* @param fixed_parameter_ids_ IDs of the fixed parameters
65+
* @param fixed_parameter_names_ Names of the fixed parameters
66+
* @param state_ids_ IDs of state variables
67+
* @param state_names_ Names of state variables
68+
* @param state_ids_solver_ IDs of solver state variables
69+
* @param state_names_solver_ Names of solver state variables
70+
* @param observable_ids_ IDs of observables
71+
* @param observable_names_ Names of observables
72+
* @param expression_ids_ IDs of expressions
73+
* @param expression_names_ Names of expressions
6074
*/
6175
ReturnData(
6276
std::vector<realtype> ts_, ModelDimensions const& model_dimensions_,
6377
int nmaxevent_, int newton_maxsteps_, std::vector<int> plist_,
6478
std::vector<ParameterScaling> pscale_, SecondOrderMode o2mode_,
6579
SensitivityOrder sensi_, SensitivityMethod sensi_meth_,
6680
RDataReporting rdrm_, bool quadratic_llh_, bool sigma_res_,
67-
realtype sigma_offset_
81+
realtype sigma_offset_,
82+
std::span<std::string_view const> free_parameter_ids_,
83+
std::span<std::string_view const> free_parameter_names_,
84+
std::span<std::string_view const> fixed_parameter_ids_,
85+
std::span<std::string_view const> fixed_parameter_names_,
86+
std::span<std::string_view const> state_ids_,
87+
std::span<std::string_view const> state_names_,
88+
std::span<std::string_view const> state_ids_solver_,
89+
std::span<std::string_view const> state_names_solver_,
90+
std::span<std::string_view const> observable_ids_,
91+
std::span<std::string_view const> observable_names_,
92+
std::span<std::string_view const> expression_ids_,
93+
std::span<std::string_view const> expression_names_
6894
);
6995

7096
/**
@@ -112,7 +138,7 @@ class ReturnData : public ModelDimensions {
112138
* (shape `nx_solver` x `nx_solver`, row-major) evaluated at `t_last`.
113139
*
114140
* The corresponding state variable IDs can be obtained from
115-
* `Model::getStateIdsSolver()`.
141+
* `state_ids_solver()`.
116142
*/
117143
std::vector<realtype> J;
118144

@@ -124,7 +150,7 @@ class ReturnData : public ModelDimensions {
124150
* at timepoints `ReturnData::ts` (shape `nt` x `nw`, row major).
125151
*
126152
* The corresponding expression IDs can be obtained from
127-
* `Model::getExpressionIds()`.
153+
* `expression_ids`.
128154
*/
129155
std::vector<realtype> w;
130156

@@ -171,7 +197,7 @@ class ReturnData : public ModelDimensions {
171197
* (shape `nt` x `nx_rdata`, row-major).
172198
*
173199
* The corresponding state variable IDs can be obtained from
174-
* `Model::getStateIds()`.
200+
* `state_ids`.
175201
*/
176202
std::vector<realtype> x;
177203

@@ -184,7 +210,7 @@ class ReturnData : public ModelDimensions {
184210
* (shape `nt` x `nplist` x `nx_rdata`, row-major).
185211
*
186212
* The corresponding state variable IDs can be obtained from
187-
* `Model::getStateIds()`.
213+
* `state_ids`.
188214
*/
189215
std::vector<realtype> sx;
190216

@@ -195,7 +221,7 @@ class ReturnData : public ModelDimensions {
195221
* (shape `nt` x `ny`, row-major).
196222
*
197223
* The corresponding observable IDs can be obtained from
198-
* `Model::getObservableIds()`.
224+
* `observable_ids`.
199225
*/
200226
std::vector<realtype> y;
201227

@@ -211,7 +237,7 @@ class ReturnData : public ModelDimensions {
211237
* (shape `nt` x `nplist` x `ny`, row-major).
212238
*
213239
* The corresponding observable IDs can be obtained from
214-
* `Model::getObservableIds()`.
240+
* `observable_ids`.
215241
*/
216242
std::vector<realtype> sy;
217243

@@ -412,7 +438,7 @@ class ReturnData : public ModelDimensions {
412438
* @brief Initial state of the main simulation (shape `nx_rdata`).
413439
*
414440
* The corresponding state variable IDs can be obtained from
415-
* `Model::getStateIds()`.
441+
* `state_ids`.
416442
*/
417443
std::vector<realtype> x0;
418444

@@ -422,7 +448,7 @@ class ReturnData : public ModelDimensions {
422448
* The values of the state variables at the pre-equilibration steady state
423449
* (shape `nx_rdata`).
424450
* The corresponding state variable IDs can be obtained from
425-
* `Model::getStateIds()`.
451+
* `state_ids`.
426452
*/
427453
std::vector<realtype> x_ss;
428454

@@ -526,10 +552,46 @@ class ReturnData : public ModelDimensions {
526552
* @brief Indices of the parameters w.r.t. which sensitivities were
527553
* computed.
528554
*
529-
* The indices refer to parameter IDs in Model::getParameterIds().
555+
* The indices refer to parameter IDs in `free_parameter_ids`.
530556
*/
531557
std::vector<int> plist;
532558

559+
/** IDs of the free parameters */
560+
const std::span<std::string_view const> free_parameter_ids;
561+
562+
/** Names of the free parameters */
563+
const std::span<std::string_view const> free_parameter_names;
564+
565+
/** IDs of the fixed parameters */
566+
const std::span<std::string_view const> fixed_parameter_ids;
567+
568+
/** Names of the fixed parameters */
569+
const std::span<std::string_view const> fixed_parameter_names;
570+
571+
/** IDs of state variables */
572+
const std::span<std::string_view const> state_ids;
573+
574+
/** Names of state variables */
575+
const std::span<std::string_view const> state_names;
576+
577+
/** IDs of solver state variables */
578+
const std::span<std::string_view const> state_ids_solver;
579+
580+
/** Names of solver state variables */
581+
const std::span<std::string_view const> state_names_solver;
582+
583+
/** IDs of observables */
584+
const std::span<std::string_view const> observable_ids;
585+
586+
/** Names of observables */
587+
const std::span<std::string_view const> observable_names;
588+
589+
/** IDs of expressions */
590+
const std::span<std::string_view const> expression_ids;
591+
592+
/** Names of expressions */
593+
const std::span<std::string_view const> expression_names;
594+
533595
protected:
534596
/** offset for sigma_residuals */
535597
realtype sigma_offset{0.0};

python/sdist/amici/sim/sundials/_swig_wrappers.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
AmiciExpDataVector,
2222
AmiciModel,
2323
AmiciSolver,
24-
RDataReporting,
2524
SensitivityMethod,
2625
SensitivityOrder,
2726
Solver,
@@ -82,8 +81,6 @@ def run_simulation(
8281
_get_ptr(solver), _get_ptr(edata), _get_ptr(model)
8382
)
8483
_log_simulation(rdata)
85-
if solver.get_return_data_reporting_mode() == RDataReporting.full:
86-
_ids_and_names_to_rdata(rdata, model)
8784
return ReturnDataView(rdata)
8885

8986

@@ -129,8 +126,6 @@ def run_simulations(
129126
)
130127
for rdata in rdata_ptr_list:
131128
_log_simulation(rdata)
132-
if solver.get_return_data_reporting_mode() == RDataReporting.full:
133-
_ids_and_names_to_rdata(rdata, model)
134129

135130
return [ReturnDataView(r) for r in rdata_ptr_list]
136131

@@ -269,29 +264,6 @@ def _log_simulation(rdata: amici_swig.ReturnData):
269264
)
270265

271266

272-
def _ids_and_names_to_rdata(
273-
rdata: amici_swig.ReturnData, model: amici_swig.Model
274-
):
275-
"""Copy entity IDs and names from a Model to ReturnData."""
276-
for entity_type in (
277-
"state",
278-
"observable",
279-
"expression",
280-
"free_parameter",
281-
"fixed_parameter",
282-
):
283-
for name_or_id in ("ids", "names"):
284-
names_or_ids = getattr(model, f"get_{entity_type}_{name_or_id}")()
285-
setattr(
286-
rdata,
287-
f"{entity_type.lower()}_{name_or_id.lower()}",
288-
names_or_ids,
289-
)
290-
291-
rdata.state_ids_solver = model.get_state_ids_solver()
292-
rdata.state_names_solver = model.get_state_names_solver()
293-
294-
295267
@contextlib.contextmanager
296268
def _solver_settings(solver, sensi_method=None, sensi_order=None):
297269
"""Context manager to temporarily apply solver settings."""

python/tests/test_swig_interface.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,23 @@ def test_rdataview(sbml_example_presimulation_module):
604604
print(str(e))
605605

606606

607+
def test_rdata_ids(sbml_example_presimulation_module):
608+
"""Test that rdata IDs are correctly set."""
609+
model_module = sbml_example_presimulation_module
610+
model = model_module.get_model()
611+
612+
model.set_timepoints([0, 1, 2])
613+
rdata = model.simulate()
614+
615+
assert isinstance(rdata.free_parameter_ids, tuple)
616+
assert rdata.free_parameter_ids == model.get_free_parameter_ids()
617+
assert rdata.fixed_parameter_ids == model.get_fixed_parameter_ids()
618+
assert rdata.state_ids == model.get_state_ids()
619+
assert rdata.state_ids_solver == model.get_state_ids_solver()
620+
assert rdata.observable_ids == model.get_observable_ids()
621+
assert rdata.expression_ids == model.get_expression_ids()
622+
623+
607624
def test_python_exceptions(sbml_example_presimulation_module):
608625
"""Test that C++ exceptions are correctly caught and re-raised in Python."""
609626
from amici.sim.sundials import run_simulation

src/rdata.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@ ReturnData::ReturnData(Solver const& solver, Model const& model)
2525
model.get_second_order_mode(), solver.get_sensitivity_order(),
2626
solver.get_sensitivity_method(),
2727
solver.get_return_data_reporting_mode(), model.has_quadratic_llh(),
28-
model.get_add_sigma_residuals(), model.get_minimum_sigma_residuals()
28+
model.get_add_sigma_residuals(), model.get_minimum_sigma_residuals(),
29+
model.get_free_parameter_ids(), model.get_free_parameter_names(),
30+
model.get_fixed_parameter_ids(), model.get_fixed_parameter_names(),
31+
model.get_state_ids(), model.get_state_names(),
32+
model.get_state_ids_solver(), model.get_state_names_solver(),
33+
model.get_observable_ids(), model.get_observable_names(),
34+
model.get_expression_ids(), model.get_expression_names()
2935
) {}
3036

3137
ReturnData::ReturnData(
@@ -34,7 +40,19 @@ ReturnData::ReturnData(
3440
std::vector<ParameterScaling> pscale_, SecondOrderMode o2mode_,
3541
SensitivityOrder sensi_, SensitivityMethod sensi_meth_,
3642
RDataReporting rdrm_, bool quadratic_llh_, bool sigma_res_,
37-
realtype sigma_offset_
43+
realtype sigma_offset_,
44+
std::span<std::string_view const> free_parameter_ids_,
45+
std::span<std::string_view const> free_parameter_names_,
46+
std::span<std::string_view const> fixed_parameter_ids_,
47+
std::span<std::string_view const> fixed_parameter_names_,
48+
std::span<std::string_view const> state_ids_,
49+
std::span<std::string_view const> state_names_,
50+
std::span<std::string_view const> state_ids_solver_,
51+
std::span<std::string_view const> state_names_solver_,
52+
std::span<std::string_view const> observable_ids_,
53+
std::span<std::string_view const> observable_names_,
54+
std::span<std::string_view const> expression_ids_,
55+
std::span<std::string_view const> expression_names_
3856
)
3957
: ModelDimensions(model_dimensions_)
4058
, ts(std::move(ts_))
@@ -49,6 +67,18 @@ ReturnData::ReturnData(
4967
, rdata_reporting(rdrm_)
5068
, sigma_res(sigma_res_)
5169
, plist(plist_)
70+
, free_parameter_ids(free_parameter_ids_)
71+
, free_parameter_names(free_parameter_names_)
72+
, fixed_parameter_ids(fixed_parameter_ids_)
73+
, fixed_parameter_names(fixed_parameter_names_)
74+
, state_ids(state_ids_)
75+
, state_names(state_names_)
76+
, state_ids_solver(state_ids_solver_)
77+
, state_names_solver(state_names_solver_)
78+
, observable_ids(observable_ids_)
79+
, observable_names(observable_names_)
80+
, expression_ids(expression_ids_)
81+
, expression_names(expression_names_)
5282
, sigma_offset(sigma_offset_)
5383
, nroots_(ne) {
5484
model_dimensions_.validate();

0 commit comments

Comments
 (0)