Skip to content

Commit 8e89cda

Browse files
authored
Refactor SteadystateProblem, simplify solver creation (#2810)
Avoid unnecessary entanglement of forward/bwd solver setup. Less complex control flow.
1 parent 386d62b commit 8e89cda

File tree

3 files changed

+29
-65
lines changed

3 files changed

+29
-65
lines changed

include/amici/steadystateproblem.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -493,21 +493,6 @@ class SteadystateProblem {
493493
*/
494494
void runSteadystateSimulationBwd(Solver const& solver, Model& model);
495495

496-
/**
497-
* @brief Create and initialize a CVodeSolver instance for
498-
* preequilibration simulation.
499-
* @param solver Solver instance
500-
* @param model Model instance.
501-
* @param forwardSensis flag switching on integration with FSA
502-
* @param backward flag switching on quadrature computation
503-
* @param t0 Initial time for the steady state simulation.
504-
* @return A unique pointer to the created Solver instance.
505-
*/
506-
std::unique_ptr<Solver> createSteadystateSimSolver(
507-
Solver const& solver, Model& model, bool forwardSensis, bool backward,
508-
realtype t0
509-
) const;
510-
511496
/**
512497
* @brief Initialize forward computation
513498
* @param it Index of the current output time point.

src/newton_solver.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ NewtonSolver::NewtonSolver(
3030
break;
3131
default:
3232
throw NewtonFailure(
33-
AMICI_NOT_IMPLEMENTED, "Unknown linear solver type"
33+
AMICI_NOT_IMPLEMENTED,
34+
"Invalid solver for steady state simulation"
3435
);
3536
}
3637
} catch (NewtonFailure const&) {

src/steadystateproblem.cpp

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -348,18 +348,27 @@ SteadyStateStatus SteadystateProblem::findSteadyStateBySimulation(
348348
) {
349349
try {
350350
if (it < 0) {
351-
// Preequilibration? -> Create a new solver instance for simulation
352-
bool integrateSensis = requires_state_sensitivities(
353-
model, solver, it, SteadyStateContext::solverCreation
354-
);
355-
auto newtonSimSolver = createSteadystateSimSolver(
356-
solver, model, integrateSensis, false, t0
357-
);
358-
runSteadystateSimulationFwd(*newtonSimSolver, model);
351+
// Preequilibration -> Create a new solver instance for simulation
352+
auto sim_solver = std::unique_ptr<Solver>(solver.clone());
353+
sim_solver->logger = solver.logger;
354+
355+
// do we need sensitivities?
356+
if (requires_state_sensitivities(
357+
model, solver, it, SteadyStateContext::solverCreation
358+
)) {
359+
// need forward to compute sx0
360+
sim_solver->setSensitivityMethod(SensitivityMethod::forward);
361+
} else {
362+
sim_solver->setSensitivityMethod(SensitivityMethod::none);
363+
sim_solver->setSensitivityOrder(SensitivityOrder::none);
364+
}
365+
sim_solver->setup(t0, &model, state_.x, state_.dx, state_.sx, sdx_);
366+
runSteadystateSimulationFwd(*sim_solver, model);
359367
} else {
360-
// Solver was already created, use this one
368+
// Postequilibration -> Solver was already created, use that one
361369
runSteadystateSimulationFwd(solver, model);
362370
}
371+
363372
return SteadyStateStatus::success;
364373
} catch (IntegrationFailure const& ex) {
365374
switch (ex.error_code) {
@@ -498,11 +507,18 @@ void SteadystateProblem::getQuadratureBySimulation(
498507
// xQ was written in getQuadratureByLinSolve() -> set to zero
499508
xQ_.zero();
500509

501-
auto simSolver = createSteadystateSimSolver(solver, model, false, true, t0);
510+
auto sim_solver = std::unique_ptr<Solver>(solver.clone());
511+
sim_solver->logger = solver.logger;
512+
sim_solver->setSensitivityMethod(SensitivityMethod::none);
513+
sim_solver->setSensitivityOrder(SensitivityOrder::none);
514+
sim_solver->setup(t0, &model, xB_, xB_, state_.sx, sdx_);
515+
sim_solver->setupSteadystate(
516+
t0, &model, state_.x, state_.dx, xB_, xB_, xQ_
517+
);
502518

503519
// perform integration and quadrature
504520
try {
505-
runSteadystateSimulationBwd(*simSolver, model);
521+
runSteadystateSimulationBwd(*sim_solver, model);
506522
hasQuadrature_ = true;
507523
} catch (NewtonFailure const&) {
508524
hasQuadrature_ = false;
@@ -779,44 +795,6 @@ void SteadystateProblem::runSteadystateSimulationBwd(
779795
}
780796
}
781797

782-
std::unique_ptr<Solver> SteadystateProblem::createSteadystateSimSolver(
783-
Solver const& solver, Model& model, bool forwardSensis, bool backward,
784-
realtype t0
785-
) const {
786-
switch (solver.getLinearSolver()) {
787-
case LinearSolver::dense:
788-
case LinearSolver::KLU:
789-
break;
790-
default:
791-
throw AmiException("Invalid solver for steady state simulation");
792-
}
793-
794-
auto sim_solver = std::unique_ptr<Solver>(solver.clone());
795-
796-
sim_solver->logger = solver.logger;
797-
798-
// do we need sensitivities?
799-
if (forwardSensis) {
800-
// need forward to compute sx0
801-
sim_solver->setSensitivityMethod(SensitivityMethod::forward);
802-
} else {
803-
sim_solver->setSensitivityMethod(SensitivityMethod::none);
804-
sim_solver->setSensitivityOrder(SensitivityOrder::none);
805-
}
806-
// use x and sx as dummies for dx and sdx
807-
// (they won't get touched in a CVodeSolver)
808-
if (backward) {
809-
sim_solver->setup(t0, &model, xB_, xB_, state_.sx, sdx_);
810-
sim_solver->setupSteadystate(
811-
t0, &model, state_.x, state_.dx, xB_, xB_, xQ_
812-
);
813-
} else {
814-
sim_solver->setup(t0, &model, state_.x, state_.dx, state_.sx, sdx_);
815-
}
816-
817-
return sim_solver;
818-
}
819-
820798
void SteadystateProblem::getAdjointUpdates(
821799
Model& model, ExpData const& edata, std::vector<realtype>& dJydx
822800
) {

0 commit comments

Comments
 (0)