diff --git a/CHANGELOG.md b/CHANGELOG.md index f938db63a2..b80892c1e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -109,6 +109,10 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni * The default directory for model import changed, and a base directory can now be specified via the `AMICI_MODELS_ROOT` environment variable. See `amici.get_model_dir` for details. +* IDs and names of model entities are now not only accessible via `Model`, but + also via `ReturnData` + (`ReturnData.{free_parameter_ids,observable_ids,...}`, + `ReturnData.{free_parameter_names,observable_names,...}`). **Fixes** diff --git a/include/amici/rdata.h b/include/amici/rdata.h index 2f200641fc..cefd882ded 100644 --- a/include/amici/rdata.h +++ b/include/amici/rdata.h @@ -9,6 +9,8 @@ #include "amici/vector.h" #include +#include +#include namespace amici { class ReturnData; @@ -28,7 +30,7 @@ void serialize(Archive& ar, amici::ReturnData& r, unsigned int version); namespace amici { /** - * @brief Stores all data to be returned by amici::runAmiciSimulation. + * @brief Stores all data to be returned by amici::run_simulation. * * NOTE: multi-dimensional arrays are stored in row-major order (C-style) */ @@ -57,6 +59,18 @@ class ReturnData : public ModelDimensions { * @param sigma_res_ indicates whether additional residuals are to be added * for each sigma * @param sigma_offset_ offset to ensure real-valuedness of sigma residuals + * @param free_parameter_ids_ IDs of the free parameters + * @param free_parameter_names_ Names of the free parameters + * @param fixed_parameter_ids_ IDs of the fixed parameters + * @param fixed_parameter_names_ Names of the fixed parameters + * @param state_ids_ IDs of state variables + * @param state_names_ Names of state variables + * @param state_ids_solver_ IDs of solver state variables + * @param state_names_solver_ Names of solver state variables + * @param observable_ids_ IDs of observables + * @param observable_names_ Names of observables + * @param expression_ids_ IDs of expressions + * @param expression_names_ Names of expressions */ ReturnData( std::vector ts_, ModelDimensions const& model_dimensions_, @@ -64,7 +78,19 @@ class ReturnData : public ModelDimensions { std::vector pscale_, SecondOrderMode o2mode_, SensitivityOrder sensi_, SensitivityMethod sensi_meth_, RDataReporting rdrm_, bool quadratic_llh_, bool sigma_res_, - realtype sigma_offset_ + realtype sigma_offset_, + std::span free_parameter_ids_, + std::span free_parameter_names_, + std::span fixed_parameter_ids_, + std::span fixed_parameter_names_, + std::span state_ids_, + std::span state_names_, + std::span state_ids_solver_, + std::span state_names_solver_, + std::span observable_ids_, + std::span observable_names_, + std::span expression_ids_, + std::span expression_names_ ); /** @@ -112,7 +138,7 @@ class ReturnData : public ModelDimensions { * (shape `nx_solver` x `nx_solver`, row-major) evaluated at `t_last`. * * The corresponding state variable IDs can be obtained from - * `Model::getStateIdsSolver()`. + * `state_ids_solver()`. */ std::vector J; @@ -124,7 +150,7 @@ class ReturnData : public ModelDimensions { * at timepoints `ReturnData::ts` (shape `nt` x `nw`, row major). * * The corresponding expression IDs can be obtained from - * `Model::getExpressionIds()`. + * `expression_ids`. */ std::vector w; @@ -171,7 +197,7 @@ class ReturnData : public ModelDimensions { * (shape `nt` x `nx_rdata`, row-major). * * The corresponding state variable IDs can be obtained from - * `Model::getStateIds()`. + * `state_ids`. */ std::vector x; @@ -184,7 +210,7 @@ class ReturnData : public ModelDimensions { * (shape `nt` x `nplist` x `nx_rdata`, row-major). * * The corresponding state variable IDs can be obtained from - * `Model::getStateIds()`. + * `state_ids`. */ std::vector sx; @@ -195,7 +221,7 @@ class ReturnData : public ModelDimensions { * (shape `nt` x `ny`, row-major). * * The corresponding observable IDs can be obtained from - * `Model::getObservableIds()`. + * `observable_ids`. */ std::vector y; @@ -211,7 +237,7 @@ class ReturnData : public ModelDimensions { * (shape `nt` x `nplist` x `ny`, row-major). * * The corresponding observable IDs can be obtained from - * `Model::getObservableIds()`. + * `observable_ids`. */ std::vector sy; @@ -412,7 +438,7 @@ class ReturnData : public ModelDimensions { * @brief Initial state of the main simulation (shape `nx_rdata`). * * The corresponding state variable IDs can be obtained from - * `Model::getStateIds()`. + * `state_ids`. */ std::vector x0; @@ -422,7 +448,7 @@ class ReturnData : public ModelDimensions { * The values of the state variables at the pre-equilibration steady state * (shape `nx_rdata`). * The corresponding state variable IDs can be obtained from - * `Model::getStateIds()`. + * `state_ids`. */ std::vector x_ss; @@ -526,10 +552,46 @@ class ReturnData : public ModelDimensions { * @brief Indices of the parameters w.r.t. which sensitivities were * computed. * - * The indices refer to parameter IDs in Model::getParameterIds(). + * The indices refer to parameter IDs in `free_parameter_ids`. */ std::vector plist; + /** IDs of the free parameters */ + std::span free_parameter_ids; + + /** Names of the free parameters */ + std::span free_parameter_names; + + /** IDs of the fixed parameters */ + std::span fixed_parameter_ids; + + /** Names of the fixed parameters */ + std::span fixed_parameter_names; + + /** IDs of state variables */ + std::span state_ids; + + /** Names of state variables */ + std::span state_names; + + /** IDs of solver state variables */ + std::span state_ids_solver; + + /** Names of solver state variables */ + std::span state_names_solver; + + /** IDs of observables */ + std::span observable_ids; + + /** Names of observables */ + std::span observable_names; + + /** IDs of expressions */ + std::span expression_ids; + + /** Names of expressions */ + std::span expression_names; + protected: /** offset for sigma_residuals */ realtype sigma_offset{0.0}; diff --git a/python/sdist/amici/sim/sundials/_swig_wrappers.py b/python/sdist/amici/sim/sundials/_swig_wrappers.py index 564f05b5e0..6d834e6f0a 100644 --- a/python/sdist/amici/sim/sundials/_swig_wrappers.py +++ b/python/sdist/amici/sim/sundials/_swig_wrappers.py @@ -21,7 +21,6 @@ AmiciExpDataVector, AmiciModel, AmiciSolver, - RDataReporting, SensitivityMethod, SensitivityOrder, Solver, @@ -82,8 +81,6 @@ def run_simulation( _get_ptr(solver), _get_ptr(edata), _get_ptr(model) ) _log_simulation(rdata) - if solver.get_return_data_reporting_mode() == RDataReporting.full: - _ids_and_names_to_rdata(rdata, model) return ReturnDataView(rdata) @@ -129,8 +126,6 @@ def run_simulations( ) for rdata in rdata_ptr_list: _log_simulation(rdata) - if solver.get_return_data_reporting_mode() == RDataReporting.full: - _ids_and_names_to_rdata(rdata, model) return [ReturnDataView(r) for r in rdata_ptr_list] @@ -269,29 +264,6 @@ def _log_simulation(rdata: amici_swig.ReturnData): ) -def _ids_and_names_to_rdata( - rdata: amici_swig.ReturnData, model: amici_swig.Model -): - """Copy entity IDs and names from a Model to ReturnData.""" - for entity_type in ( - "state", - "observable", - "expression", - "free_parameter", - "fixed_parameter", - ): - for name_or_id in ("ids", "names"): - names_or_ids = getattr(model, f"get_{entity_type}_{name_or_id}")() - setattr( - rdata, - f"{entity_type.lower()}_{name_or_id.lower()}", - names_or_ids, - ) - - rdata.state_ids_solver = model.get_state_ids_solver() - rdata.state_names_solver = model.get_state_names_solver() - - @contextlib.contextmanager def _solver_settings(solver, sensi_method=None, sensi_order=None): """Context manager to temporarily apply solver settings.""" diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index 1cf4a638f1..3b97bdc09c 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -604,6 +604,23 @@ def test_rdataview(sbml_example_presimulation_module): print(str(e)) +def test_rdata_ids(sbml_example_presimulation_module): + """Test that rdata IDs are correctly set.""" + model_module = sbml_example_presimulation_module + model = model_module.get_model() + + model.set_timepoints([0, 1, 2]) + rdata = model.simulate() + + assert isinstance(rdata.free_parameter_ids, tuple) + assert rdata.free_parameter_ids == model.get_free_parameter_ids() + assert rdata.fixed_parameter_ids == model.get_fixed_parameter_ids() + assert rdata.state_ids == model.get_state_ids() + assert rdata.state_ids_solver == model.get_state_ids_solver() + assert rdata.observable_ids == model.get_observable_ids() + assert rdata.expression_ids == model.get_expression_ids() + + def test_python_exceptions(sbml_example_presimulation_module): """Test that C++ exceptions are correctly caught and re-raised in Python.""" from amici.sim.sundials import run_simulation diff --git a/src/rdata.cpp b/src/rdata.cpp index dd9eef0756..bae76cdef0 100644 --- a/src/rdata.cpp +++ b/src/rdata.cpp @@ -25,7 +25,13 @@ ReturnData::ReturnData(Solver const& solver, Model const& model) model.get_second_order_mode(), solver.get_sensitivity_order(), solver.get_sensitivity_method(), solver.get_return_data_reporting_mode(), model.has_quadratic_llh(), - model.get_add_sigma_residuals(), model.get_minimum_sigma_residuals() + model.get_add_sigma_residuals(), model.get_minimum_sigma_residuals(), + model.get_free_parameter_ids(), model.get_free_parameter_names(), + model.get_fixed_parameter_ids(), model.get_fixed_parameter_names(), + model.get_state_ids(), model.get_state_names(), + model.get_state_ids_solver(), model.get_state_names_solver(), + model.get_observable_ids(), model.get_observable_names(), + model.get_expression_ids(), model.get_expression_names() ) {} ReturnData::ReturnData( @@ -34,7 +40,19 @@ ReturnData::ReturnData( std::vector pscale_, SecondOrderMode o2mode_, SensitivityOrder sensi_, SensitivityMethod sensi_meth_, RDataReporting rdrm_, bool quadratic_llh_, bool sigma_res_, - realtype sigma_offset_ + realtype sigma_offset_, + std::span free_parameter_ids_, + std::span free_parameter_names_, + std::span fixed_parameter_ids_, + std::span fixed_parameter_names_, + std::span state_ids_, + std::span state_names_, + std::span state_ids_solver_, + std::span state_names_solver_, + std::span observable_ids_, + std::span observable_names_, + std::span expression_ids_, + std::span expression_names_ ) : ModelDimensions(model_dimensions_) , ts(std::move(ts_)) @@ -49,6 +67,18 @@ ReturnData::ReturnData( , rdata_reporting(rdrm_) , sigma_res(sigma_res_) , plist(plist_) + , free_parameter_ids(free_parameter_ids_) + , free_parameter_names(free_parameter_names_) + , fixed_parameter_ids(fixed_parameter_ids_) + , fixed_parameter_names(fixed_parameter_names_) + , state_ids(state_ids_) + , state_names(state_names_) + , state_ids_solver(state_ids_solver_) + , state_names_solver(state_names_solver_) + , observable_ids(observable_ids_) + , observable_names(observable_names_) + , expression_ids(expression_ids_) + , expression_names(expression_names_) , sigma_offset(sigma_offset_) , nroots_(ne) { model_dimensions_.validate(); diff --git a/tests/cpp/testfunctions.cpp b/tests/cpp/testfunctions.cpp index f683f2b6a5..f68aedf880 100644 --- a/tests/cpp/testfunctions.cpp +++ b/tests/cpp/testfunctions.cpp @@ -24,10 +24,10 @@ extern std::unique_ptr get_model(); } // namespace generic_model std::map> var_names { - {"p", {"p0", "p1", "p2", "p3", "p4"}}, - {"k", {"k0", "k1", "k2"}}, + {"p", {"p0", "p1", "p2", "p3", "p4", "p5"}}, + {"k", {"k0", "k1", "k2", "k3", "k4", "p5"}}, {"x", {"x0", "x1", "x2", "x3", "x4", "x5"}}, - {"y", {"y0", "y1", "y2"}} + {"y", {"y0", "y1", "y2", "y3", "y4", "y5"}} }; std::span getVariableNames(std::string const& name, int length) {