Skip to content

Commit 83c432c

Browse files
authored
Refactor SimulationState, add SolutionState (#2887)
Often `SimulationState` is passed around where only a subset of the data is needed. Extract that to a separate `SolutionState`. This makes storing intermediate results a bit easier and avoids missing individual components.
1 parent 29a9c79 commit 83c432c

File tree

11 files changed

+262
-270
lines changed

11 files changed

+262
-270
lines changed

include/amici/backwardproblem.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ class SteadyStateBackwardProblem {
151151
* @param ws Workspace for backward simulation
152152
*/
153153
SteadyStateBackwardProblem(
154-
Solver const& solver, Model& model, SimulationState& final_state,
154+
Solver const& solver, Model& model, SolutionState& final_state,
155155
gsl::not_null<BwdSimWorkspace*> ws
156156
);
157157

@@ -248,7 +248,7 @@ class SteadyStateBackwardProblem {
248248
AmiVector xQ_;
249249

250250
/** Final state from pre/post-equilibration forward problem */
251-
SimulationState& final_state_;
251+
SolutionState& final_state_;
252252

253253
/** Newton solver */
254254
NewtonSolver newton_solver_;

include/amici/forwardproblem.h

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,10 @@ struct FwdSimWorkspace {
113113
FwdSimWorkspace(
114114
gsl::not_null<Model*> const& model, gsl::not_null<Solver*> solver
115115
)
116-
: x(model->nx_solver, solver->getSunContext())
116+
: sol(NAN, model->nx_solver, model->nplist(), solver->getSunContext())
117117
, x_old(model->nx_solver, solver->getSunContext())
118-
, dx(model->nx_solver, solver->getSunContext())
119118
, xdot(model->nx_solver, solver->getSunContext())
120119
, xdot_old(model->nx_solver, solver->getSunContext())
121-
, sx(model->nx_solver, model->nplist(), solver->getSunContext())
122120
, sdx(model->nx_solver, model->nplist(), solver->getSunContext())
123121
, stau(model->nplist())
124122
, roots_found(model->ne, 0)
@@ -127,27 +125,18 @@ struct FwdSimWorkspace {
127125
, rootvals(gsl::narrow<decltype(rootvals)::size_type>(model->ne), 0.0)
128126

129127
{}
130-
/** current simulation time */
131-
realtype t{NAN};
132-
133-
/** state vector (dimension: nx_solver) */
134-
AmiVector x;
128+
/** Current solution state */
129+
SolutionState sol;
135130

136131
/** old state vector (dimension: nx_solver) */
137132
AmiVector x_old;
138133

139-
/** differential state vector (dimension: nx_solver) */
140-
AmiVector dx;
141-
142134
/** time derivative state vector (dimension: nx_solver) */
143135
AmiVector xdot;
144136

145137
/** old time derivative state vector (dimension: nx_solver) */
146138
AmiVector xdot_old;
147139

148-
/** sensitivity state vector array (dimension: nx_cl x nplist, row-major) */
149-
AmiVectorArray sx;
150-
151140
/** differential sensitivity state vector array
152141
* (dimension: nx_cl x nplist, row-major) */
153142
AmiVectorArray sdx;
@@ -362,7 +351,7 @@ class EventHandlingSimulator {
362351
}))
363352
return;
364353

365-
result.discs.emplace_back(ws_->t);
354+
result.discs.emplace_back(ws_->sol.t);
366355
store_event(edata);
367356
}
368357

@@ -428,15 +417,15 @@ class SteadystateProblem {
428417
* @return x
429418
*/
430419
[[nodiscard]] AmiVector const& getState() const {
431-
return period_result_.final_state_.x;
420+
return period_result_.final_state_.sol.x;
432421
}
433422

434423
/**
435424
* @brief Return state sensitivity at steady state
436425
* @return sx
437426
*/
438427
[[nodiscard]] AmiVectorArray const& getStateSensitivity() const {
439-
return period_result_.final_state_.sx;
428+
return period_result_.final_state_.sol.sx;
440429
}
441430

442431
/**
@@ -460,7 +449,7 @@ class SteadystateProblem {
460449
* @return Time at which steady state was found (model time units).
461450
*/
462451
[[nodiscard]] realtype getSteadyStateTime() const {
463-
return period_result_.final_state_.t;
452+
return period_result_.final_state_.sol.t;
464453
}
465454

466455
/**
@@ -504,7 +493,6 @@ class SteadystateProblem {
504493

505494
/**
506495
* @brief Try to determine the steady state by using Newton's method.
507-
* @param model Model instance.
508496
* @param newton_retry Flag indicating whether Newton's method is being
509497
* relaunched.
510498
*/
@@ -672,18 +660,12 @@ class ForwardProblem {
672660
*/
673661
std::vector<realtype> getAdjointUpdates(Model& model, ExpData const& edata);
674662

675-
/**
676-
* @brief Accessor for t
677-
* @return t
678-
*/
679-
[[nodiscard]] realtype getTime() const { return t_; }
680-
681663
/**
682664
* @brief Accessor for sx
683665
* @return sx
684666
*/
685667
[[nodiscard]] AmiVectorArray const& getStateSensitivity() const {
686-
return ws_.sx;
668+
return ws_.sol.sx;
687669
}
688670

689671
/**
@@ -713,7 +695,7 @@ class ForwardProblem {
713695
* @return time point
714696
*/
715697
[[nodiscard]] realtype getFinalTime() const {
716-
return main_simulator_.result.final_state_.t;
698+
return main_simulator_.result.final_state_.sol.t;
717699
}
718700

719701
/**
@@ -732,7 +714,8 @@ class ForwardProblem {
732714
*/
733715
[[nodiscard]] SimulationState const&
734716
getSimulationStateTimepoint(int const it) const {
735-
if (model->getTimepoint(it) == main_simulator_.result.initial_state_.t)
717+
if (model->getTimepoint(it)
718+
== main_simulator_.result.initial_state_.sol.t)
736719
return getInitialSimulationState();
737720
auto const map_iter = main_simulator_.result.timepoint_states_.find(
738721
model->getTimepoint(it)
@@ -877,9 +860,6 @@ class ForwardProblem {
877860
* (dimension nJ x nx x nMaxEvent, ordering =?) */
878861
std::vector<realtype> dJzdx_;
879862

880-
/** current time */
881-
realtype t_;
882-
883863
/** flag to indicate whether solver was preeinitialized via preequilibration
884864
*/
885865
bool preequilibrated_{false};

include/amici/model_state.h

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -462,12 +462,11 @@ struct ModelStateDerived {
462462
};
463463

464464
/**
465-
* @brief implements an exchange format to store and transfer the state of a
466-
* simulation at a specific timepoint.
465+
* @brief Container for the IVP solution state at a specific timepoint.
467466
*/
468-
struct SimulationState {
467+
struct SolutionState {
469468
/** timepoint */
470-
realtype t;
469+
realtype t{NAN};
471470
/**
472471
* partial state vector, excluding states eliminated from conservation laws
473472
*/
@@ -482,8 +481,35 @@ struct SimulationState {
482481
* conservation laws
483482
*/
484483
AmiVectorArray sx;
484+
485+
SolutionState() = default;
486+
487+
/**
488+
* @brief Constructor.
489+
* @param t_ Current timepoint.
490+
* @param nx_solver Number of solver state variables.
491+
* @param nplist Number of parameter w.r.t. which to compute sensitivities.
492+
* @param ctx SUNDIALS context.
493+
*/
494+
SolutionState(
495+
realtype t_, long int nx_solver, long int nplist, SUNContext ctx
496+
)
497+
: t(t_)
498+
, x(nx_solver, ctx)
499+
, dx(nx_solver, ctx)
500+
, sx(nx_solver, nplist, ctx) {}
501+
};
502+
503+
/**
504+
* @brief implements an exchange format to store and transfer the state of a
505+
* simulation at a specific timepoint.
506+
*/
507+
struct SimulationState {
508+
/** Solution state */
509+
SolutionState sol;
510+
485511
/** state of the model that was used for simulation */
486-
ModelState state;
512+
ModelState mod;
487513
};
488514

489515
} // namespace amici

include/amici/newton_solver.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define amici_newton_solver_h
33

44
#include "amici/defines.h"
5+
#include "amici/model_state.h"
56
#include "amici/sundials_linsol_wrapper.h"
67
#include "amici/vector.h"
78

@@ -46,6 +47,15 @@ struct DEStateView {
4647
: t(t_)
4748
, x(x_)
4849
, dx(dx_) {}
50+
51+
/**
52+
* @brief Construct a DEStateView from a SolutionState reference.
53+
* @param sol
54+
*/
55+
DEStateView(SolutionState& sol)
56+
: t(sol.t)
57+
, x(sol.x)
58+
, dx(sol.dx) {}
4959
};
5060

5161
/**

include/amici/rdata.h

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -642,26 +642,20 @@ class ReturnData : public ModelDimensions {
642642
template <class T>
643643
void
644644
storeJacobianAndDerivativeInReturnData(T const& problem, Model& model) {
645-
auto simulation_state = problem.getFinalSimulationState();
646-
model.setModelState(simulation_state.state);
647-
645+
auto const& simulation_state = problem.getFinalSimulationState();
646+
model.setModelState(simulation_state.mod);
647+
auto const& sol = simulation_state.sol;
648648
sundials::Context sunctx;
649649
AmiVector xdot(nx_solver, sunctx);
650650
if (!this->xdot.empty() || !this->J.empty())
651-
model.fxdot(
652-
simulation_state.t, simulation_state.x, simulation_state.dx,
653-
xdot
654-
);
651+
model.fxdot(sol.t, sol.x, sol.dx, xdot);
655652

656653
if (!this->xdot.empty())
657654
writeSlice(xdot, this->xdot);
658655

659656
if (!this->J.empty()) {
660657
SUNMatrixWrapper J(nx_solver, nx_solver, sunctx);
661-
model.fJ(
662-
simulation_state.t, 0.0, simulation_state.x,
663-
simulation_state.dx, xdot, J
664-
);
658+
model.fJ(sol.t, 0.0, sol.x, sol.dx, xdot, J);
665659
// CVODES uses colmajor, so we need to transform to rowmajor
666660
for (int ix = 0; ix < model.nx_solver; ix++)
667661
for (int jx = 0; jx < model.nx_solver; jx++)
@@ -674,13 +668,11 @@ class ReturnData : public ModelDimensions {
674668
* @brief Residual function
675669
* @param it time index
676670
* @param model model that was used for forward/backward simulation
677-
* @param simulation_state simulation state the timepoint `it`
671+
* @param sol Solution state the timepoint `it`
678672
* @param edata ExpData instance containing observable data
679673
*/
680-
void fres(
681-
int it, Model& model, SimulationState const& simulation_state,
682-
ExpData const& edata
683-
);
674+
void
675+
fres(int it, Model& model, SolutionState const& sol, ExpData const& edata);
684676

685677
/**
686678
* @brief Chi-squared function
@@ -693,25 +685,21 @@ class ReturnData : public ModelDimensions {
693685
* @brief Residual sensitivity function
694686
* @param it time index
695687
* @param model model that was used for forward/backward simulation
696-
* @param simulation_state simulation state the timepoint `it`
688+
* @param sol solution state the timepoint `it`
697689
* @param edata ExpData instance containing observable data
698690
*/
699-
void fsres(
700-
int it, Model& model, SimulationState const& simulation_state,
701-
ExpData const& edata
702-
);
691+
void
692+
fsres(int it, Model& model, SolutionState const& sol, ExpData const& edata);
703693

704694
/**
705695
* @brief Fisher information matrix function
706696
* @param it time index
707697
* @param model model that was used for forward/backward simulation
708-
* @param simulation_state simulation state the timepoint `it`
698+
* @param sol Solution state the timepoint `it`
709699
* @param edata ExpData instance containing observable data
710700
*/
711-
void fFIM(
712-
int it, Model& model, SimulationState const& simulation_state,
713-
ExpData const& edata
714-
);
701+
void
702+
fFIM(int it, Model& model, SolutionState const& sol, ExpData const& edata);
715703

716704
/**
717705
* @brief Set likelihood, state variables, outputs and respective
@@ -755,54 +743,49 @@ class ReturnData : public ModelDimensions {
755743
* the model state was set appropriately
756744
* @param it timepoint index
757745
* @param model model that was used in forward solve
758-
* @param simulation_state simulation state the timepoint `it`
746+
* @param sol solution state the timepoint `it`
759747
* @param edata ExpData instance carrying experimental data
760748
*/
761749
void getDataOutput(
762-
int it, Model& model, SimulationState const& simulation_state,
763-
ExpData const* edata
750+
int it, Model& model, SolutionState const& sol, ExpData const* edata
764751
);
765752

766753
/**
767754
* @brief Extracts data information for forward sensitivity analysis,
768755
* expects that the model state was set appropriately
769756
* @param it index of current timepoint
770757
* @param model model that was used in forward solve
771-
* @param simulation_state simulation state the timepoint `it`
758+
* @param sol Solution state the timepoint `it`
772759
* @param edata ExpData instance carrying experimental data
773760
*/
774761
void getDataSensisFSA(
775-
int it, Model& model, SimulationState const& simulation_state,
776-
ExpData const* edata
762+
int it, Model& model, SolutionState const& sol, ExpData const* edata
777763
);
778764

779765
/**
780766
* @brief Extracts output information for events, expects that the model
781767
* state was set appropriately
782-
* @param t event timepoint
783768
* @param rootidx information about which roots fired
784769
* (1 indicating fired, 0/-1 for not)
785770
* @param model model that was used in forward solve
786-
* @param simulation_state simulation state the timepoint `it`
771+
* @param sol Solution state the timepoint `it`
787772
* @param edata ExpData instance carrying experimental data
788773
*/
789774
void getEventOutput(
790-
realtype t, std::vector<int> const& rootidx, Model& model,
791-
SimulationState const& simulation_state, ExpData const* edata
775+
std::vector<int> const& rootidx, Model& model, SolutionState const& sol,
776+
ExpData const* edata
792777
);
793778

794779
/**
795780
* @brief Extracts event information for forward sensitivity analysis,
796781
* expects the model state was set appropriately
797782
* @param ie index of event type
798-
* @param t event timepoint
799783
* @param model model that was used in forward solve
800-
* @param simulation_state simulation state the timepoint `it`
784+
* @param sol Solution state the timepoint `it`
801785
* @param edata ExpData instance carrying experimental data
802786
*/
803787
void getEventSensisFSA(
804-
int ie, realtype t, Model& model,
805-
SimulationState const& simulation_state, ExpData const* edata
788+
int ie, Model& model, SolutionState const& sol, ExpData const* edata
806789
);
807790

808791
/**
@@ -826,13 +809,13 @@ class ReturnData : public ModelDimensions {
826809
* (llhS0), if no preequilibration was run or if forward sensitivities were
827810
* used
828811
* @param model model that was used for forward/backward simulation
829-
* @param simulation_state simulation state the timepoint `it`
812+
* @param sol Solution state the timepoint `it`
830813
* @param llhS0 contribution to likelihood for initial state sensitivities
831814
* @param xB vector with final adjoint state
832815
* (excluding conservation laws)
833816
*/
834817
void handleSx0Forward(
835-
Model const& model, SimulationState const& simulation_state,
818+
Model const& model, SolutionState const& sol,
836819
std::vector<realtype>& llhS0, AmiVector const& xB
837820
) const;
838821
};

0 commit comments

Comments
 (0)