Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
* The default directory for model import changed, and a base directory
can now be specified via the `AMICI_MODELS_ROOT` environment variable.
See `amici.get_model_dir` for details.
* IDs and names of model entities are now not only accessible via `Model`, but
also via `ReturnData`
(`ReturnData.{free_parameter_ids,observable_ids,...}`,
`ReturnData.{free_parameter_names,observable_names,...}`).

**Fixes**

Expand Down
84 changes: 73 additions & 11 deletions include/amici/rdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "amici/vector.h"

#include <vector>
#include <span>
#include <string_view>

namespace amici {
class ReturnData;
Expand All @@ -28,7 +30,7 @@ void serialize(Archive& ar, amici::ReturnData& r, unsigned int version);
namespace amici {

/**
* @brief Stores all data to be returned by amici::runAmiciSimulation.
* @brief Stores all data to be returned by amici::run_simulation.
*
* NOTE: multi-dimensional arrays are stored in row-major order (C-style)
*/
Expand Down Expand Up @@ -57,14 +59,38 @@ class ReturnData : public ModelDimensions {
* @param sigma_res_ indicates whether additional residuals are to be added
* for each sigma
* @param sigma_offset_ offset to ensure real-valuedness of sigma residuals
* @param free_parameter_ids_ IDs of the free parameters
* @param free_parameter_names_ Names of the free parameters
* @param fixed_parameter_ids_ IDs of the fixed parameters
* @param fixed_parameter_names_ Names of the fixed parameters
* @param state_ids_ IDs of state variables
* @param state_names_ Names of state variables
* @param state_ids_solver_ IDs of solver state variables
* @param state_names_solver_ Names of solver state variables
* @param observable_ids_ IDs of observables
* @param observable_names_ Names of observables
* @param expression_ids_ IDs of expressions
* @param expression_names_ Names of expressions
*/
ReturnData(
std::vector<realtype> ts_, ModelDimensions const& model_dimensions_,
int nmaxevent_, int newton_maxsteps_, std::vector<int> plist_,
std::vector<ParameterScaling> pscale_, SecondOrderMode o2mode_,
SensitivityOrder sensi_, SensitivityMethod sensi_meth_,
RDataReporting rdrm_, bool quadratic_llh_, bool sigma_res_,
realtype sigma_offset_
realtype sigma_offset_,
std::span<std::string_view const> free_parameter_ids_,
std::span<std::string_view const> free_parameter_names_,
std::span<std::string_view const> fixed_parameter_ids_,
std::span<std::string_view const> fixed_parameter_names_,
std::span<std::string_view const> state_ids_,
std::span<std::string_view const> state_names_,
std::span<std::string_view const> state_ids_solver_,
std::span<std::string_view const> state_names_solver_,
std::span<std::string_view const> observable_ids_,
std::span<std::string_view const> observable_names_,
std::span<std::string_view const> expression_ids_,
std::span<std::string_view const> expression_names_
);

/**
Expand Down Expand Up @@ -112,7 +138,7 @@ class ReturnData : public ModelDimensions {
* (shape `nx_solver` x `nx_solver`, row-major) evaluated at `t_last`.
*
* The corresponding state variable IDs can be obtained from
* `Model::getStateIdsSolver()`.
* `state_ids_solver()`.
*/
std::vector<realtype> J;

Expand All @@ -124,7 +150,7 @@ class ReturnData : public ModelDimensions {
* at timepoints `ReturnData::ts` (shape `nt` x `nw`, row major).
*
* The corresponding expression IDs can be obtained from
* `Model::getExpressionIds()`.
* `expression_ids`.
*/
std::vector<realtype> w;

Expand Down Expand Up @@ -171,7 +197,7 @@ class ReturnData : public ModelDimensions {
* (shape `nt` x `nx_rdata`, row-major).
*
* The corresponding state variable IDs can be obtained from
* `Model::getStateIds()`.
* `state_ids`.
*/
std::vector<realtype> x;

Expand All @@ -184,7 +210,7 @@ class ReturnData : public ModelDimensions {
* (shape `nt` x `nplist` x `nx_rdata`, row-major).
*
* The corresponding state variable IDs can be obtained from
* `Model::getStateIds()`.
* `state_ids`.
*/
std::vector<realtype> sx;

Expand All @@ -195,7 +221,7 @@ class ReturnData : public ModelDimensions {
* (shape `nt` x `ny`, row-major).
*
* The corresponding observable IDs can be obtained from
* `Model::getObservableIds()`.
* `observable_ids`.
*/
std::vector<realtype> y;

Expand All @@ -211,7 +237,7 @@ class ReturnData : public ModelDimensions {
* (shape `nt` x `nplist` x `ny`, row-major).
*
* The corresponding observable IDs can be obtained from
* `Model::getObservableIds()`.
* `observable_ids`.
*/
std::vector<realtype> sy;

Expand Down Expand Up @@ -412,7 +438,7 @@ class ReturnData : public ModelDimensions {
* @brief Initial state of the main simulation (shape `nx_rdata`).
*
* The corresponding state variable IDs can be obtained from
* `Model::getStateIds()`.
* `state_ids`.
*/
std::vector<realtype> x0;

Expand All @@ -422,7 +448,7 @@ class ReturnData : public ModelDimensions {
* The values of the state variables at the pre-equilibration steady state
* (shape `nx_rdata`).
* The corresponding state variable IDs can be obtained from
* `Model::getStateIds()`.
* `state_ids`.
*/
std::vector<realtype> x_ss;

Expand Down Expand Up @@ -526,10 +552,46 @@ class ReturnData : public ModelDimensions {
* @brief Indices of the parameters w.r.t. which sensitivities were
* computed.
*
* The indices refer to parameter IDs in Model::getParameterIds().
* The indices refer to parameter IDs in `free_parameter_ids`.
*/
std::vector<int> plist;

/** IDs of the free parameters */
std::span<std::string_view const> free_parameter_ids;

/** Names of the free parameters */
std::span<std::string_view const> free_parameter_names;

/** IDs of the fixed parameters */
std::span<std::string_view const> fixed_parameter_ids;

/** Names of the fixed parameters */
std::span<std::string_view const> fixed_parameter_names;

/** IDs of state variables */
std::span<std::string_view const> state_ids;

/** Names of state variables */
std::span<std::string_view const> state_names;

/** IDs of solver state variables */
std::span<std::string_view const> state_ids_solver;

/** Names of solver state variables */
std::span<std::string_view const> state_names_solver;

/** IDs of observables */
std::span<std::string_view const> observable_ids;

/** Names of observables */
std::span<std::string_view const> observable_names;

/** IDs of expressions */
std::span<std::string_view const> expression_ids;

/** Names of expressions */
std::span<std::string_view const> expression_names;

protected:
/** offset for sigma_residuals */
realtype sigma_offset{0.0};
Expand Down
28 changes: 0 additions & 28 deletions python/sdist/amici/sim/sundials/_swig_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
AmiciExpDataVector,
AmiciModel,
AmiciSolver,
RDataReporting,
SensitivityMethod,
SensitivityOrder,
Solver,
Expand Down Expand Up @@ -82,8 +81,6 @@ def run_simulation(
_get_ptr(solver), _get_ptr(edata), _get_ptr(model)
)
_log_simulation(rdata)
if solver.get_return_data_reporting_mode() == RDataReporting.full:
_ids_and_names_to_rdata(rdata, model)
return ReturnDataView(rdata)


Expand Down Expand Up @@ -129,8 +126,6 @@ def run_simulations(
)
for rdata in rdata_ptr_list:
_log_simulation(rdata)
if solver.get_return_data_reporting_mode() == RDataReporting.full:
_ids_and_names_to_rdata(rdata, model)

return [ReturnDataView(r) for r in rdata_ptr_list]

Expand Down Expand Up @@ -269,29 +264,6 @@ def _log_simulation(rdata: amici_swig.ReturnData):
)


def _ids_and_names_to_rdata(
rdata: amici_swig.ReturnData, model: amici_swig.Model
):
"""Copy entity IDs and names from a Model to ReturnData."""
for entity_type in (
"state",
"observable",
"expression",
"free_parameter",
"fixed_parameter",
):
for name_or_id in ("ids", "names"):
names_or_ids = getattr(model, f"get_{entity_type}_{name_or_id}")()
setattr(
rdata,
f"{entity_type.lower()}_{name_or_id.lower()}",
names_or_ids,
)

rdata.state_ids_solver = model.get_state_ids_solver()
rdata.state_names_solver = model.get_state_names_solver()


@contextlib.contextmanager
def _solver_settings(solver, sensi_method=None, sensi_order=None):
"""Context manager to temporarily apply solver settings."""
Expand Down
17 changes: 17 additions & 0 deletions python/tests/test_swig_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,23 @@ def test_rdataview(sbml_example_presimulation_module):
print(str(e))


def test_rdata_ids(sbml_example_presimulation_module):
"""Test that rdata IDs are correctly set."""
model_module = sbml_example_presimulation_module
model = model_module.get_model()

model.set_timepoints([0, 1, 2])
rdata = model.simulate()

assert isinstance(rdata.free_parameter_ids, tuple)
assert rdata.free_parameter_ids == model.get_free_parameter_ids()
assert rdata.fixed_parameter_ids == model.get_fixed_parameter_ids()
assert rdata.state_ids == model.get_state_ids()
assert rdata.state_ids_solver == model.get_state_ids_solver()
assert rdata.observable_ids == model.get_observable_ids()
assert rdata.expression_ids == model.get_expression_ids()


def test_python_exceptions(sbml_example_presimulation_module):
"""Test that C++ exceptions are correctly caught and re-raised in Python."""
from amici.sim.sundials import run_simulation
Expand Down
34 changes: 32 additions & 2 deletions src/rdata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ ReturnData::ReturnData(Solver const& solver, Model const& model)
model.get_second_order_mode(), solver.get_sensitivity_order(),
solver.get_sensitivity_method(),
solver.get_return_data_reporting_mode(), model.has_quadratic_llh(),
model.get_add_sigma_residuals(), model.get_minimum_sigma_residuals()
model.get_add_sigma_residuals(), model.get_minimum_sigma_residuals(),
model.get_free_parameter_ids(), model.get_free_parameter_names(),
model.get_fixed_parameter_ids(), model.get_fixed_parameter_names(),
model.get_state_ids(), model.get_state_names(),
model.get_state_ids_solver(), model.get_state_names_solver(),
model.get_observable_ids(), model.get_observable_names(),
model.get_expression_ids(), model.get_expression_names()
) {}

ReturnData::ReturnData(
Expand All @@ -34,7 +40,19 @@ ReturnData::ReturnData(
std::vector<ParameterScaling> pscale_, SecondOrderMode o2mode_,
SensitivityOrder sensi_, SensitivityMethod sensi_meth_,
RDataReporting rdrm_, bool quadratic_llh_, bool sigma_res_,
realtype sigma_offset_
realtype sigma_offset_,
std::span<std::string_view const> free_parameter_ids_,
std::span<std::string_view const> free_parameter_names_,
std::span<std::string_view const> fixed_parameter_ids_,
std::span<std::string_view const> fixed_parameter_names_,
std::span<std::string_view const> state_ids_,
std::span<std::string_view const> state_names_,
std::span<std::string_view const> state_ids_solver_,
std::span<std::string_view const> state_names_solver_,
std::span<std::string_view const> observable_ids_,
std::span<std::string_view const> observable_names_,
std::span<std::string_view const> expression_ids_,
std::span<std::string_view const> expression_names_
)
: ModelDimensions(model_dimensions_)
, ts(std::move(ts_))
Expand All @@ -49,6 +67,18 @@ ReturnData::ReturnData(
, rdata_reporting(rdrm_)
, sigma_res(sigma_res_)
, plist(plist_)
, free_parameter_ids(free_parameter_ids_)
, free_parameter_names(free_parameter_names_)
, fixed_parameter_ids(fixed_parameter_ids_)
, fixed_parameter_names(fixed_parameter_names_)
, state_ids(state_ids_)
, state_names(state_names_)
, state_ids_solver(state_ids_solver_)
, state_names_solver(state_names_solver_)
, observable_ids(observable_ids_)
, observable_names(observable_names_)
, expression_ids(expression_ids_)
, expression_names(expression_names_)
, sigma_offset(sigma_offset_)
, nroots_(ne) {
model_dimensions_.validate();
Expand Down
6 changes: 3 additions & 3 deletions tests/cpp/testfunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ extern std::unique_ptr<amici::Model> get_model();
} // namespace generic_model

std::map<std::string, std::vector<std::string_view>> var_names {
{"p", {"p0", "p1", "p2", "p3", "p4"}},
{"k", {"k0", "k1", "k2"}},
{"p", {"p0", "p1", "p2", "p3", "p4", "p5"}},
{"k", {"k0", "k1", "k2", "k3", "k4", "p5"}},
{"x", {"x0", "x1", "x2", "x3", "x4", "x5"}},
{"y", {"y0", "y1", "y2"}}
{"y", {"y0", "y1", "y2", "y3", "y4", "y5"}}
};

std::span<std::string_view const> getVariableNames(std::string const& name, int length) {
Expand Down
Loading