Skip to content

Commit a3dc1cc

Browse files
authored
Refactor: Avoid xQ dummy arguments (#2811)
Overload `Solver::writeSolution` to avoid confusing dummy arguments for `xQ`. This also saves some unnecessary copying.
1 parent 8e89cda commit a3dc1cc

File tree

4 files changed

+26
-5
lines changed

4 files changed

+26
-5
lines changed

include/amici/solver.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,17 @@ class Solver {
656656
AmiVector& xQ
657657
) const;
658658

659+
/**
660+
* @brief write solution from forward simulation
661+
* @param t time
662+
* @param x state
663+
* @param dx derivative state
664+
* @param sx state sensitivity
665+
*/
666+
void writeSolution(
667+
realtype* t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx
668+
) const;
669+
659670
/**
660671
* @brief write solution from backward simulation
661672
* @param t time

src/forwardproblem.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ void EventHandlingSimulator::run(
112112
int const status = solver_->run(next_t_stop);
113113
// sx will be copied from solver on demand if sensitivities
114114
// are computed
115-
solver_->writeSolution(&t_, ws_->x, ws_->dx, ws_->sx, ws_->dx);
115+
solver_->writeSolution(&t_, ws_->x, ws_->dx, ws_->sx);
116116

117117
if (status == AMICI_ILL_INPUT) {
118118
// clustering of roots => turn off root-finding
@@ -210,7 +210,7 @@ void ForwardProblem::handlePresimulation() {
210210

211211
std::vector<realtype> const timepoints{model->t0()};
212212
pre_simulator_.run(t_, edata, timepoints);
213-
solver->writeSolution(&t_, ws_.x, ws_.dx, ws_.sx, ws_.dx);
213+
solver->writeSolution(&t_, ws_.x, ws_.dx, ws_.sx);
214214
}
215215

216216
void ForwardProblem::handleMainSimulation() {
@@ -562,7 +562,7 @@ ForwardProblem::getAdjointUpdates(Model& model, ExpData const& edata) {
562562

563563
SimulationState EventHandlingSimulator::get_simulation_state() {
564564
if (std::isfinite(solver_->gett())) {
565-
solver_->writeSolution(&t_, ws_->x, ws_->dx, ws_->sx, ws_->dx);
565+
solver_->writeSolution(&t_, ws_->x, ws_->dx, ws_->sx);
566566
}
567567
auto state = SimulationState();
568568
state.t = t_;

src/solver.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,16 @@ void Solver::writeSolution(
13181318
dx.copy(getDerivativeState(*t));
13191319
}
13201320

1321+
void Solver::writeSolution(
1322+
realtype* t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx
1323+
) const {
1324+
*t = gett();
1325+
if (sens_initialized_)
1326+
sx.copy(getStateSensitivity(*t));
1327+
x.copy(getState(*t));
1328+
dx.copy(getDerivativeState(*t));
1329+
}
1330+
13211331
void Solver::writeSolutionB(
13221332
realtype* t, AmiVector& xB, AmiVector& dxB, AmiVector& xQB, int const which
13231333
) const {

src/steadystateproblem.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ void SteadystateProblem::initializeForwardProblem(
424424
solver.setup(t0, &model, state_.x, state_.dx, state_.sx, sdx_);
425425
} else {
426426
// The solver was run before, extract current state from solver.
427-
solver.writeSolution(&state_.t, state_.x, state_.dx, state_.sx, xQ_);
427+
solver.writeSolution(&state_.t, state_.x, state_.dx, state_.sx);
428428
}
429429

430430
state_.t = t0;
@@ -726,7 +726,7 @@ void SteadystateProblem::runSteadystateSimulationFwd(
726726
// direction w.r.t. current t.
727727
solver.step(std::max(state_.t, 1.0) * 10);
728728

729-
solver.writeSolution(&state_.t, state_.x, state_.dx, state_.sx, xQ_);
729+
solver.writeSolution(&state_.t, state_.x, state_.dx, state_.sx);
730730
flagUpdatedState();
731731
}
732732

0 commit comments

Comments
 (0)