Skip to content

Commit 915cc0a

Browse files
authored
Clean up SteadystateProblem (#2808)
* Remove unused members * Construct some objects closer to where they are required. Avoids extra allocations and simplifies the class.
1 parent 2922a4b commit 915cc0a

File tree

3 files changed

+43
-46
lines changed

3 files changed

+43
-46
lines changed

include/amici/steadystateproblem.h

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,10 @@ class SteadystateProblem {
472472
/**
473473
* @brief Checks convergence for state sensitivities
474474
* @param model Model instance
475+
* @param wrms_computer_sx WRMSComputer instance for state sensitivities
475476
* @return weighted root mean squared residuals of the RHS
476477
*/
477-
realtype getWrmsFSA(Model& model);
478+
realtype getWrmsFSA(Model& model, WRMSComputer& wrms_computer_sx);
478479

479480
/**
480481
* @brief Launch simulation if Newton solver or linear system solve
@@ -526,12 +527,6 @@ class SteadystateProblem {
526527

527528
/** WRMS computer for x */
528529
WRMSComputer wrms_computer_x_;
529-
/** WRMS computer for xQB */
530-
WRMSComputer wrms_computer_xQB_;
531-
/** WRMS computer for sx */
532-
WRMSComputer wrms_computer_sx_;
533-
/** old state vector */
534-
AmiVector x_old_;
535530
/** time derivative state vector */
536531
AmiVector xdot_;
537532
/** state differential sensitivities */
@@ -542,9 +537,6 @@ class SteadystateProblem {
542537
AmiVector xQ_;
543538
/** quadrature state vector */
544539
AmiVector xQB_;
545-
/** time-derivative of quadrature state vector */
546-
AmiVector xQBdot_;
547-
548540
/** weighted root-mean-square error */
549541
realtype wrms_{NAN};
550542

@@ -582,10 +574,6 @@ class SteadystateProblem {
582574
* checks during simulation and Newton's method.
583575
*/
584576
bool newton_step_conv_{false};
585-
/**
586-
* whether sensitivities should be checked for convergence to steady state
587-
*/
588-
bool check_sensi_conv_{true};
589577

590578
/** flag indicating whether xdot_ has been computed for the current state */
591579
bool xdot_updated_{false};

include/amici/vector.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class AmiVector {
4040
AmiVector() = default;
4141

4242
/**
43-
* @brief Construct empty vector of given size
43+
* @brief Construct zero-initialized vector of the given size.
4444
*
4545
* Creates an std::vector<realtype> and attaches the
4646
* data pointer to a newly created N_Vector_Serial.

src/steadystateproblem.cpp

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -123,23 +123,11 @@ SteadystateProblem::SteadystateProblem(Solver const& solver, Model& model)
123123
solver.getRelativeToleranceSteadyState(),
124124
AmiVector(model.get_steadystate_mask(), solver.getSunContext())
125125
)
126-
, wrms_computer_xQB_(
127-
model.nplist(), solver.getSunContext(),
128-
solver.getAbsoluteToleranceQuadratures(),
129-
solver.getRelativeToleranceQuadratures(), AmiVector()
130-
)
131-
, wrms_computer_sx_(
132-
model.nx_solver, solver.getSunContext(),
133-
solver.getAbsoluteToleranceSteadyStateSensi(),
134-
solver.getRelativeToleranceSteadyStateSensi(),
135-
AmiVector(model.get_steadystate_mask(), solver.getSunContext())
136-
)
137126
, xdot_(model.nx_solver, solver.getSunContext())
138127
, sdx_(model.nx_solver, model.nplist(), solver.getSunContext())
139128
, xB_(model.nJ * model.nx_solver, solver.getSunContext())
140129
, xQ_(model.nJ * model.nx_solver, solver.getSunContext())
141130
, xQB_(model.nplist(), solver.getSunContext())
142-
, xQBdot_(model.nplist(), solver.getSunContext())
143131
, state_(
144132
{.t = INFINITY,
145133
.x = AmiVector(model.nx_solver, solver.getSunContext()),
@@ -158,8 +146,7 @@ SteadystateProblem::SteadystateProblem(Solver const& solver, Model& model)
158146
solver.getNewtonDampingFactorLowerBound(), solver.getNewtonMaxSteps(),
159147
solver.getNewtonStepSteadyStateCheck()
160148
)
161-
, newton_step_conv_(solver.getNewtonStepSteadyStateCheck())
162-
, check_sensi_conv_(solver.getSensiSteadyStateCheck()) {
149+
, newton_step_conv_(solver.getNewtonStepSteadyStateCheck()) {
163150
// Check for compatibility of options
164151
if (solver.getSensitivityMethod() == SensitivityMethod::forward
165152
&& solver.getSensitivityMethodPreequilibration()
@@ -252,7 +239,6 @@ void SteadystateProblem::workSteadyStateBackwardProblem(
252239
// initialize quadratures
253240
xQ_.zero();
254241
xQB_.zero();
255-
xQBdot_.zero();
256242

257243
// Compute quadratures, track computation time
258244
CpuTimer cpu_timer;
@@ -627,7 +613,8 @@ realtype SteadystateProblem::getWrmsState(Model& model) {
627613
return wrms_computer_x_.wrms(xdot_, state_.x);
628614
}
629615

630-
realtype SteadystateProblem::getWrmsFSA(Model& model) {
616+
realtype
617+
SteadystateProblem::getWrmsFSA(Model& model, WRMSComputer& wrms_computer_sx) {
631618
// Forward sensitivities: Compute weighted error norm for their RHS
632619
realtype wrms = 0.0;
633620

@@ -643,7 +630,7 @@ realtype SteadystateProblem::getWrmsFSA(Model& model) {
643630
if (newton_step_conv_) {
644631
newton_solver_.solveLinearSystem(xdot_);
645632
}
646-
wrms = wrms_computer_sx_.wrms(xdot_, state_.sx[ip]);
633+
wrms = wrms_computer_sx.wrms(xdot_, state_.sx[ip]);
647634
// ideally this function would report the maximum of all wrms over
648635
// all ip, but for practical purposes we can just report the wrms for
649636
// the first ip where we know that the convergence threshold is not
@@ -685,6 +672,28 @@ void SteadystateProblem::runSteadystateSimulationFwd(
685672
sensitivity_method = SensitivityMethod::none;
686673
}
687674

675+
// function for sensitivity convergence check or dummy
676+
std::function<bool()> sensi_converged;
677+
if (solver.getSensiSteadyStateCheck()
678+
&& sensitivity_method == SensitivityMethod::forward) {
679+
sensi_converged =
680+
[&,
681+
wrms_computer_sx = WRMSComputer(
682+
model.nx_solver, solver.getSunContext(),
683+
solver.getAbsoluteToleranceSteadyStateSensi(),
684+
solver.getRelativeToleranceSteadyStateSensi(),
685+
AmiVector(model.get_steadystate_mask(), solver.getSunContext())
686+
)]() mutable -> bool {
687+
updateSensiSimulation(solver);
688+
// getWrms needs to be called before getWrmsFSA
689+
// such that the linear system is prepared for newton-type
690+
// convergence check
691+
return getWrmsFSA(model, wrms_computer_sx) < conv_thresh;
692+
};
693+
} else {
694+
sensi_converged = []() { return true; };
695+
}
696+
688697
int& sim_steps = numsteps_.at(1);
689698
int convergence_check_frequency = newton_step_conv_ ? 25 : 1;
690699

@@ -693,18 +702,8 @@ void SteadystateProblem::runSteadystateSimulationFwd(
693702
// Check for convergence (already before simulation, since we might
694703
// start in steady state)
695704
wrms_ = getWrmsState(model);
696-
if (wrms_ < conv_thresh) {
697-
if (check_sensi_conv_
698-
&& sensitivity_method == SensitivityMethod::forward) {
699-
updateSensiSimulation(solver);
700-
// getWrms needs to be called before getWrmsFSA
701-
// such that the linear system is prepared for newton-type
702-
// convergence check
703-
if (getWrmsFSA(model) < conv_thresh)
704-
break; // converged
705-
} else {
706-
break; // converged
707-
}
705+
if (wrms_ < conv_thresh && sensi_converged()) {
706+
break;
708707
}
709708
}
710709

@@ -751,6 +750,16 @@ void SteadystateProblem::runSteadystateSimulationBwd(
751750

752751
int& sim_steps = numstepsB_;
753752

753+
// WRMS computer for xQB
754+
WRMSComputer wrms_computer_xQB_(
755+
model.nplist(), solver.getSunContext(),
756+
solver.getAbsoluteToleranceQuadratures(),
757+
solver.getRelativeToleranceQuadratures(), AmiVector()
758+
);
759+
760+
// time-derivative of quadrature state vector
761+
AmiVector xQBdot(model.nplist(), solver.getSunContext());
762+
754763
int convergence_check_frequency = newton_step_conv_ ? 25 : 1;
755764
auto max_steps = (solver.getMaxStepsBackwardProblem() > 0)
756765
? solver.getMaxStepsBackwardProblem()
@@ -766,8 +775,8 @@ void SteadystateProblem::runSteadystateSimulationBwd(
766775
// converge to zero at all. So we need xQBdot, hence compute xQB
767776
// first.
768777
computeQBfromQ(model, xQ_, xQB_, state_);
769-
computeQBfromQ(model, xB_, xQBdot_, state_);
770-
wrms_ = wrms_computer_xQB_.wrms(xQBdot_, xQB_);
778+
computeQBfromQ(model, xB_, xQBdot, state_);
779+
wrms_ = wrms_computer_xQB_.wrms(xQBdot, xQB_);
771780
if (wrms_ < conv_thresh) {
772781
break; // converged
773782
}

0 commit comments

Comments
 (0)