Skip to content

Commit a8a10f6

Browse files
authored
Event handling during pre-equilibration with FSA before ASA simulation (#2881)
Implement event handling during pre-equilibration if steady-state sensitivities are computed with FSA, and ASA is used afterwards. (No ASA preeq + ASA main sim yet.) * Store discontinuities during pre-equilibration (store `PeriodResult` instead of only `SimulationState`) * Make sure to use the same `EventHandlingSimulator` for initial events and later events, with the correct sensitivity method Storing the results is not so pretty yet. To be cleaned up once all methods work. Related to #2775.
1 parent 0e5444e commit a8a10f6

File tree

6 files changed

+122
-61
lines changed

6 files changed

+122
-61
lines changed

include/amici/backwardproblem.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ class BackwardProblem {
335335
/** current time */
336336
realtype t_;
337337

338-
/** array containing the time-points of discontinuities*/
339-
std::vector<Discontinuity> discs_;
338+
/** The discontinuities encountered during the main simulation. */
339+
std::vector<Discontinuity> discs_main_;
340340

341341
/**
342342
* state derivative of data likelihood

include/amici/forwardproblem.h

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -383,9 +383,11 @@ class SteadystateProblem {
383383
* @param ws Workspace for forward simulation
384384
* @param solver Solver instance
385385
* @param model Model instance
386+
* @param is_preeq Whether this is a pre-equilibration (`true`) or
387+
* post-equilibration problem (`false`).
386388
*/
387389
explicit SteadystateProblem(
388-
FwdSimWorkspace* ws, Solver const& solver, Model& model
390+
FwdSimWorkspace* ws, Solver const& solver, Model& model, bool is_preeq
389391
);
390392

391393
/**
@@ -399,29 +401,30 @@ class SteadystateProblem {
399401
* @param it Index of the current output time point.
400402
* @param t0 Initial time for the steady state simulation.
401403
*/
402-
void
403-
workSteadyStateProblem(Solver& solver, int it, realtype t0);
404+
void workSteadyStateProblem(Solver& solver, int it, realtype t0);
404405

405406
/**
406407
* @brief Return the stored SimulationState.
407408
* @return stored SimulationState
408409
*/
409410
[[nodiscard]] SimulationState const& getFinalSimulationState() const {
410-
return final_state_;
411+
return period_result_.final_state_;
411412
}
412413

413414
/**
414415
* @brief Return state at steady state
415416
* @return x
416417
*/
417-
[[nodiscard]] AmiVector const& getState() const { return final_state_.x; }
418+
[[nodiscard]] AmiVector const& getState() const {
419+
return period_result_.final_state_.x;
420+
}
418421

419422
/**
420423
* @brief Return state sensitivity at steady state
421424
* @return sx
422425
*/
423426
[[nodiscard]] AmiVectorArray const& getStateSensitivity() const {
424-
return final_state_.sx;
427+
return period_result_.final_state_.sx;
425428
}
426429

427430
/**
@@ -444,7 +447,9 @@ class SteadystateProblem {
444447
* @brief Get model time at which steady state was found through simulation.
445448
* @return Time at which steady state was found (model time units).
446449
*/
447-
[[nodiscard]] realtype getSteadyStateTime() const { return final_state_.t; }
450+
[[nodiscard]] realtype getSteadyStateTime() const {
451+
return period_result_.final_state_.t;
452+
}
448453

449454
/**
450455
* @brief Get the weighted root mean square of the residuals.
@@ -546,15 +551,20 @@ class SteadystateProblem {
546551
*/
547552
void updateRightHandSide();
548553

554+
/** Whether this is a pre- or post-equilibration problem */
555+
bool is_preeq_;
556+
549557
/** Workspace for forward simulation */
550558
FwdSimWorkspace* ws_;
559+
551560
/** WRMS computer for x */
552561
WRMSComputer wrms_computer_x_;
562+
553563
/** weighted root-mean-square error */
554564
realtype wrms_{NAN};
555565

556-
/** The simulation state at the end of the forward problem. */
557-
SimulationState final_state_;
566+
/** Results for this period. */
567+
PeriodResult period_result_;
558568

559569
/** stores diagnostic information about employed number of steps */
560570
std::vector<int> numsteps_{std::vector<int>(3, 0)};
@@ -605,6 +615,9 @@ class SteadystateProblem {
605615

606616
/** The model to equilibrate */
607617
Model* model_{nullptr};
618+
619+
/** Simulator for event handling */
620+
std::optional<EventHandlingSimulator> simulator_;
608621
};
609622

610623
/**

python/tests/test_preequilibration.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
skip_on_valgrind,
1414
)
1515
from amici.gradient_check import check_derivatives
16+
from amici import SensitivityMethod, SensitivityOrder
1617

1718
pytestmark = pytest.mark.filterwarnings(
1819
# https://github.com/AMICI-dev/AMICI/issues/18
@@ -760,15 +761,18 @@ def test_preequilibration_events(tempdir):
760761
bolus2 = 1
761762
bolus3 = 1
762763
bolus4 = 1
763-
at some_time >= 0, t0 = false: target1 = target1 + bolus1
764-
at time >= 0, t0 = false: target2 = target2 + bolus2
764+
# E1 & E2 will both trigger during pre-equilibration and main
765+
# simulation (Heaviside is reset after pre-equilibration)
766+
E1: at some_time >= 0, t0 = false: target1 = target1 + bolus1
767+
E2: at time >= 0, t0 = false: target2 = target2 + bolus2
765768
# requires early time point
766769
# https://github.com/AMICI-dev/AMICI/issues/2804
767770
trigger_time2 = 1e-3
768-
# will trigger only during preequilibration (some_time is not reset)
769-
at some_time >= trigger_time2: target3 = target3 + bolus3
771+
# E3 will trigger only during preequilibration
772+
# (some_time is not reset and trigger initial value is `true`)
773+
E3: at some_time >= trigger_time2: target3 = target3 + bolus3
770774
# will trigger during preequilibration and main simulation
771-
at time >= trigger_time2: target4 = target4 + bolus4
775+
E4: at time >= trigger_time2: target4 = target4 + bolus4
772776
end
773777
"""
774778
module_name = "test_preequilibration_events"
@@ -798,9 +802,15 @@ def test_preequilibration_events(tempdir):
798802

799803
# Integration-only preequilibration should handle all events
800804
amici_model = model_module.getModel()
805+
amici_model.setSteadyStateSensitivityMode(
806+
amici.SteadyStateSensitivityMode.integrationOnly
807+
)
808+
amici_model.setSteadyStateComputationMode(
809+
amici.SteadyStateSensitivityMode.integrationOnly
810+
)
801811
rdata = amici.runAmiciSimulation(amici_model, amici_solver, edata)
802812
assert rdata.status == amici.AMICI_SUCCESS
803-
assert rdata.preeq_t > 1e-3
813+
assert rdata.preeq_t > 1e-3 # verifies that integration was done
804814
assert rdata.x_ss[target1_idx] == 1
805815
assert rdata.x_ss[target2_idx] == 1
806816
assert rdata.x_ss[target3_idx] == 1
@@ -814,17 +824,20 @@ def test_preequilibration_events(tempdir):
814824
edata.fixedParametersPreequilibration = [1.0]
815825
edata.fixedParameters = [0.0]
816826

817-
amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward)
818-
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)
827+
amici_solver.setSensitivityOrder(SensitivityOrder.first)
819828

820-
for sensi_meth in [
821-
amici.SensitivityMethod.forward,
822-
# amici.SensitivityMethod.adjoint,
823-
]:
829+
for sensi_meth, sensi_meth_preeq in (
830+
(SensitivityMethod.forward, SensitivityMethod.forward),
831+
(SensitivityMethod.adjoint, SensitivityMethod.forward),
832+
# TODO https://github.com/AMICI-dev/AMICI/issues/2775
833+
# (SensitivityMethod.adjoint, SensitivityMethod.adjoint),
834+
):
824835
amici_solver.setSensitivityMethod(sensi_meth)
836+
amici_solver.setSensitivityMethodPreequilibration(sensi_meth_preeq)
825837

826838
# amici_model.requireSensitivitiesForAllParameters()
827-
# FIXME: sensitivities w.r.t. trigger time are off
839+
# FIXME: finite differences w.r.t. trigger time are off
840+
# need different epsilon for trigger time
828841
amici_model.setParameterList(
829842
[
830843
i

src/backwardproblem.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ BackwardProblem::BackwardProblem(ForwardProblem& fwd)
1515
, solver_(fwd.solver)
1616
, edata_(fwd.edata)
1717
, t_(fwd.getTime())
18-
, discs_(fwd.getDiscontinuities())
18+
, discs_main_(fwd.getDiscontinuities())
1919
, dJydx_(fwd.getAdjointUpdates(*model_, *edata_))
2020
, dJzdx_(fwd.getDJzdx())
2121
, preeq_problem_(fwd.getPreequilibrationProblem())
@@ -45,8 +45,8 @@ void BackwardProblem::workBackwardProblem() {
4545

4646
// initialize state vectors, depending on postequilibration
4747
model_->initializeB(ws_.xB_, ws_.dxB_, ws_.xQB_, it < model_->nt() - 1);
48-
ws_.discs_ = discs_;
49-
ws_.nroots_ = compute_nroots(discs_, model_->ne, model_->nMaxEvent());
48+
ws_.discs_ = discs_main_;
49+
ws_.nroots_ = compute_nroots(discs_main_, model_->ne, model_->nMaxEvent());
5050
simulator_.run(
5151
t_, model_->t0(), it, model_->getTimepoints(), &dJydx_, &dJzdx_
5252
);

0 commit comments

Comments
 (0)