Skip to content

Commit 18b27f3

Browse files
authored
Refactor model/solver handling in SteadyStateBackwardProblem (#2869)
Store solver and model, use preequilibration solver, simplify parameter lists and initialization.
1 parent 8d85846 commit 18b27f3

File tree

3 files changed

+58
-74
lines changed

3 files changed

+58
-74
lines changed

include/amici/steadystateproblem.h

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,12 @@ class SteadystateProblem {
308308
*/
309309
[[nodiscard]] bool checkSteadyStateSuccess() const;
310310

311+
/**
312+
* @brief Get the preequilibration solver.
313+
* @return The preequilibration solver.
314+
*/
315+
[[nodiscard]] Solver const* get_solver() const { return solver_; }
316+
311317
private:
312318
/**
313319
* @brief Handle the computation of the steady state.
@@ -467,15 +473,10 @@ class SteadyStateBackwardProblem {
467473
* Integrates over the adjoint state backward in time by solving a linear
468474
* system of equations, which gives the analytical solution.
469475
*
470-
* @param solver The solver instance
471-
* @param model The model instance
472476
* @param xB0 Initial adjoint state vector.
473-
* @param is_preeq Flag indicating whether this is a preequilibration.
474477
* @param t0 Initial time for the steady state simulation.
475478
*/
476-
void
477-
run(Solver const& solver, Model& model, AmiVector const& xB0, bool is_preeq,
478-
realtype t0);
479+
void run(AmiVector const& xB0, realtype t0);
479480

480481
/**
481482
* @brief Return the quadratures from pre- or postequilibration
@@ -521,37 +522,27 @@ class SteadyStateBackwardProblem {
521522
* @brief Launch backward simulation if Newton solver or linear system solve
522523
* fail or are disabled.
523524
* @param solver Solver instance.
524-
* @param model Model instance.
525525
*/
526-
void run_simulation(Solver const& solver, Model& model);
526+
void run_simulation(Solver const& solver);
527527

528528
/**
529529
* @brief Compute quadratures in adjoint mode
530-
* @param solver Solver instance.
531-
* @param model Model instance.
532530
* @param t0 Initial time for the steady state simulation.
533531
*/
534-
void compute_steady_state_quadrature(
535-
Solver const& solver, Model& model, realtype t0
536-
);
532+
void compute_steady_state_quadrature(realtype t0);
537533

538534
/**
539535
* @brief Compute the quadrature in steady state backward mode by
540536
* solving the linear system defined by the backward Jacobian.
541-
* @param model Model instance.
542537
*/
543-
void compute_quadrature_by_lin_solve(Model& model);
538+
void compute_quadrature_by_lin_solve();
544539

545540
/**
546541
* @brief Computes the quadrature in steady state backward mode by
547542
* numerical integration of xB forward in time.
548-
* @param solver Solver instance.
549-
* @param model Model instance.
550543
* @param t0 Initial time for the steady state simulation.
551544
*/
552-
void compute_quadrature_by_simulation(
553-
Solver const& solver, Model& model, realtype t0
554-
);
545+
void compute_quadrature_by_simulation(realtype t0);
555546

556547
/** CPU time for solving the backward problem (milliseconds) */
557548
double cpu_timeB_{0.0};
@@ -582,6 +573,9 @@ class SteadyStateBackwardProblem {
582573
* checks during simulation.
583574
*/
584575
bool newton_step_conv_{false};
576+
577+
Model* model_{nullptr};
578+
Solver const* solver_{nullptr};
585579
};
586580

587581
} // namespace amici

src/backwardproblem.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,31 @@ void BackwardProblem::workBackwardProblem() {
6565
}
6666

6767
// handle pre-equilibration
68-
if (preeq_problem_) {
68+
if (preeq_problem_
69+
&& preeq_problem_->get_solver()->getSensitivityMethodPreequilibration()
70+
== SensitivityMethod::adjoint) {
71+
auto preeq_solver = preeq_problem_->get_solver();
72+
6973
ConditionContext cc2(
7074
model_, edata_, FixedParameterContext::preequilibration
7175
);
7276
auto const t0
7377
= std::isnan(model_->t0Preeq()) ? model_->t0() : model_->t0Preeq();
7478
auto final_state = preeq_problem_->getFinalSimulationState();
79+
80+
// If we need to reinitialize solver states, this won't work yet.
81+
if (model_->nx_reinit() > 0)
82+
throw NewtonFailure(
83+
AMICI_NOT_IMPLEMENTED,
84+
"Adjoint preequilibration with reinitialization of "
85+
"non-constant states is not yet implemented. Stopping."
86+
);
87+
88+
// only preequilibrations needs a reInit, postequilibration does not
89+
preeq_solver->updateAndReinitStatesAndSensitivities(model_);
90+
7591
preeq_problem_bwd_.emplace(*solver_, *model_, final_state);
76-
preeq_problem_bwd_->run(*solver_, *model_, ws_.xB_, true, t0);
92+
preeq_problem_bwd_->run(ws_.xB_, t0);
7793
}
7894
}
7995

@@ -92,7 +108,7 @@ void BackwardProblem::handlePostequilibration() {
92108

93109
auto final_state = posteq_problem_->getFinalSimulationState();
94110
posteq_problem_bwd_.emplace(*solver_, *model_, final_state);
95-
posteq_problem_bwd_->run(*solver_, *model_, ws_.xB_, false, model_->t0());
111+
posteq_problem_bwd_->run(ws_.xB_, model_->t0());
96112
ws_.xQB_ = posteq_problem_bwd_->getEquilibrationQuadratures();
97113
}
98114

src/steadystateproblem.cpp

Lines changed: 25 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -245,32 +245,11 @@ SteadyStateBackwardProblem::SteadyStateBackwardProblem(
245245
, newton_solver_(
246246
NewtonSolver(model, solver.getLinearSolver(), solver.getSunContext())
247247
)
248-
, newton_step_conv_(solver.getNewtonStepSteadyStateCheck()) {}
249-
250-
void SteadyStateBackwardProblem::run(
251-
Solver const& solver, Model& model, AmiVector const& xB0, bool is_preeq,
252-
realtype t0
253-
) {
254-
if (is_preeq) {
255-
if (solver.getSensitivityMethodPreequilibration()
256-
!= SensitivityMethod::adjoint) {
257-
// if not adjoint mode, there's nothing to do
258-
return;
259-
}
260-
261-
// If we need to reinitialize solver states, this won't work yet.
262-
if (model.nx_reinit() > 0)
263-
throw NewtonFailure(
264-
AMICI_NOT_IMPLEMENTED,
265-
"Adjoint preequilibration with reinitialization of "
266-
"non-constant states is not yet implemented. Stopping."
267-
);
268-
269-
// only preequilibrations needs a reInit,
270-
// postequilibration does not
271-
solver.updateAndReinitStatesAndSensitivities(&model);
272-
}
248+
, newton_step_conv_(solver.getNewtonStepSteadyStateCheck())
249+
, model_(&model)
250+
, solver_(&solver) {}
273251

252+
void SteadyStateBackwardProblem::run(AmiVector const& xB0, realtype t0) {
274253
newton_solver_.reinitialize();
275254
xB_.copy(xB0);
276255

@@ -280,7 +259,7 @@ void SteadyStateBackwardProblem::run(
280259

281260
// Compute quadratures, track computation time
282261
CpuTimer cpu_timer;
283-
compute_steady_state_quadrature(solver, model, t0);
262+
compute_steady_state_quadrature(t0);
284263
cpu_timeB_ = cpu_timer.elapsed_milliseconds();
285264
}
286265

@@ -440,30 +419,28 @@ SteadyStateStatus SteadystateProblem::findSteadyStateBySimulation(
440419
}
441420
}
442421

443-
void SteadyStateBackwardProblem::compute_steady_state_quadrature(
444-
Solver const& solver, Model& model, realtype t0
445-
) {
422+
void SteadyStateBackwardProblem::compute_steady_state_quadrature(realtype t0) {
446423
// This routine computes the quadratures:
447424
// xQB = Integral[ xB(x(t), t, p) * dxdot/dp(x(t), t, p) | dt ]
448425
// As we're in steady state, we have x(t) = x_ss (x_steadystate), hence
449426
// xQB = Integral[ xB(x_ss, t, p) | dt ] * dxdot/dp(x_ss, t, p)
450427
// We therefore compute the integral over xB first and then do a
451428
// matrix-vector multiplication.
452429

453-
auto const sensitivityMode = model.getSteadyStateSensitivityMode();
430+
auto const sensitivityMode = model_->getSteadyStateSensitivityMode();
454431

455432
// Try to compute the analytical solution for quadrature algebraically
456433
if (sensitivityMode == SteadyStateSensitivityMode::newtonOnly
457434
|| sensitivityMode
458435
== SteadyStateSensitivityMode::integrateIfNewtonFails)
459-
compute_quadrature_by_lin_solve(model);
436+
compute_quadrature_by_lin_solve();
460437

461438
// Perform simulation if necessary
462439
if (sensitivityMode == SteadyStateSensitivityMode::integrationOnly
463440
|| (sensitivityMode
464441
== SteadyStateSensitivityMode::integrateIfNewtonFails
465442
&& !hasQuadrature()))
466-
compute_quadrature_by_simulation(solver, model, t0);
443+
compute_quadrature_by_simulation(t0);
467444

468445
// If the analytic solution and integration did not work, throw
469446
if (!hasQuadrature())
@@ -474,7 +451,7 @@ void SteadyStateBackwardProblem::compute_steady_state_quadrature(
474451
);
475452
}
476453

477-
void SteadyStateBackwardProblem::compute_quadrature_by_lin_solve(Model& model) {
454+
void SteadyStateBackwardProblem::compute_quadrature_by_lin_solve() {
478455
// Computes the integral over the adjoint state xB:
479456
// If the Jacobian has full rank, this has an analytical solution, since
480457
// d/dt[ xB(t) ] = JB^T(x(t), p) xB(t) = JB^T(x_ss, p) xB(t)
@@ -491,12 +468,13 @@ void SteadyStateBackwardProblem::compute_quadrature_by_lin_solve(Model& model) {
491468
try {
492469
// compute integral over xB and write to xQ
493470
newton_solver_.prepareLinearSystemB(
494-
model, {final_state_.t, final_state_.x, final_state_.dx}
471+
*model_, {final_state_.t, final_state_.x, final_state_.dx}
495472
);
496473
newton_solver_.solveLinearSystem(xQ_);
497474
// Compute the quadrature as the inner product xQ * dxdotdp
498475
computeQBfromQ(
499-
model, xQ_, xQB_, {final_state_.t, final_state_.x, final_state_.dx}
476+
*model_, xQ_, xQB_,
477+
{final_state_.t, final_state_.x, final_state_.dx}
500478
);
501479
has_quadrature_ = true;
502480

@@ -507,9 +485,7 @@ void SteadyStateBackwardProblem::compute_quadrature_by_lin_solve(Model& model) {
507485
}
508486
}
509487

510-
void SteadyStateBackwardProblem::compute_quadrature_by_simulation(
511-
Solver const& solver, Model& model, realtype t0
512-
) {
488+
void SteadyStateBackwardProblem::compute_quadrature_by_simulation(realtype t0) {
513489
// If the Jacobian is singular, the integral over xB must be computed
514490
// by usual integration over time, but simplifications can be applied:
515491
// x is not time-dependent, no forward trajectory is needed.
@@ -519,18 +495,18 @@ void SteadyStateBackwardProblem::compute_quadrature_by_simulation(
519495
// xQ was written in getQuadratureByLinSolve() -> set to zero
520496
xQ_.zero();
521497

522-
auto sim_solver = std::unique_ptr<Solver>(solver.clone());
523-
sim_solver->logger = solver.logger;
498+
auto sim_solver = std::unique_ptr<Solver>(solver_->clone());
499+
sim_solver->logger = solver_->logger;
524500
sim_solver->setSensitivityMethod(SensitivityMethod::none);
525501
sim_solver->setSensitivityOrder(SensitivityOrder::none);
526-
sim_solver->setup(t0, &model, xB_, xB_, final_state_.sx, final_state_.sx);
502+
sim_solver->setup(t0, model_, xB_, xB_, final_state_.sx, final_state_.sx);
527503
sim_solver->setupSteadystate(
528-
t0, &model, final_state_.x, final_state_.dx, xB_, xB_, xQ_
504+
t0, model_, final_state_.x, final_state_.dx, xB_, xB_, xQ_
529505
);
530506

531507
// perform integration and quadrature
532508
try {
533-
run_simulation(*sim_solver, model);
509+
run_simulation(*sim_solver);
534510
has_quadrature_ = true;
535511
} catch (NewtonFailure const&) {
536512
has_quadrature_ = false;
@@ -681,10 +657,8 @@ void SteadystateProblem::runSteadystateSimulationFwd(Model& model) {
681657
updateSensiSimulation();
682658
}
683659

684-
void SteadyStateBackwardProblem::run_simulation(
685-
Solver const& solver, Model& model
686-
) {
687-
if (model.nx_solver == 0)
660+
void SteadyStateBackwardProblem::run_simulation(Solver const& solver) {
661+
if (model_->nx_solver == 0)
688662
return;
689663

690664
if (newton_step_conv_) {
@@ -699,13 +673,13 @@ void SteadyStateBackwardProblem::run_simulation(
699673

700674
// WRMS computer for xQB
701675
WRMSComputer wrms_computer_xQB_(
702-
model.nplist(), solver.getSunContext(),
676+
model_->nplist(), solver.getSunContext(),
703677
solver.getAbsoluteToleranceQuadratures(),
704678
solver.getRelativeToleranceQuadratures(), AmiVector()
705679
);
706680

707681
// time-derivative of quadrature state vector
708-
AmiVector xQBdot(model.nplist(), solver.getSunContext());
682+
AmiVector xQBdot(model_->nplist(), solver.getSunContext());
709683

710684
int const convergence_check_frequency = newton_step_conv_ ? 25 : 1;
711685
auto max_steps = (solver.getMaxStepsBackwardProblem() > 0)
@@ -722,11 +696,11 @@ void SteadyStateBackwardProblem::run_simulation(
722696
// converge to zero at all. So we need xQBdot, hence compute xQB
723697
// first.
724698
computeQBfromQ(
725-
model, xQ_, xQB_,
699+
*model_, xQ_, xQB_,
726700
{final_state_.t, final_state_.x, final_state_.dx}
727701
);
728702
computeQBfromQ(
729-
model, xB_, xQBdot,
703+
*model_, xB_, xQBdot,
730704
{final_state_.t, final_state_.x, final_state_.dx}
731705
);
732706
auto wrms = wrms_computer_xQB_.wrms(xQBdot, xQB_);

0 commit comments

Comments
 (0)