From 8f7aea3414cc705224c2c454f1ed62ff6f44b9b0 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 30 Jun 2025 21:49:38 +0200 Subject: [PATCH] ... --- include/amici/misc.h | 12 ++++++++++ include/amici/model.h | 11 +++++++++ include/amici/model_state.h | 16 ++++++++++++- include/amici/serialization.h | 2 ++ python/tests/test_events.py | 44 +++++++++++++++++++++++++++++++++++ src/forwardproblem.cpp | 5 ++++ src/model.cpp | 2 ++ src/model_dae.cpp | 13 +++++++++++ src/model_ode.cpp | 13 +++++++++++ 9 files changed, 117 insertions(+), 1 deletion(-) diff --git a/include/amici/misc.h b/include/amici/misc.h index a9821c1cfc..a1302f4d4a 100644 --- a/include/amici/misc.h +++ b/include/amici/misc.h @@ -344,6 +344,18 @@ class CpuTimer { }; #endif +/** + * @brief The sign function. + * + * @param x The value to determine the sign of. + * @return -1, 0, or 1 depending on the sign of x. + */ +template +int sign(T x) { + return (T(0) < x) - (x < T(0)); +} + + } // namespace amici #endif // AMICI_MISC_H diff --git a/include/amici/model.h b/include/amici/model.h index 2aabadc472..51f9a74d78 100644 --- a/include/amici/model.h +++ b/include/amici/model.h @@ -1362,6 +1362,17 @@ class Model : public AbstractModel, public ModelDimensions { */ void updateHeaviside(std::vector const& rootsfound); + /** + * @brief Disable the event with index `ie` because it just triggered. + * + * Not to be called by user code. + * + * @param ie Event index. + */ + void register_root(int const ie, int direction) { + state_.root_enabled.at(ie) = false; + state_.root_last_sign.at(ie) = direction; + } /** * @brief Check if the given array has only finite elements. * diff --git a/include/amici/model_state.h b/include/amici/model_state.h index 7e5a0c6582..a6d2fd7e8f 100644 --- a/include/amici/model_state.h +++ b/include/amici/model_state.h @@ -29,6 +29,8 @@ struct ModelState { stotal_cl.resize((dim.nx_rdata - dim.nx_solver) * dim.np, 0.0); unscaledParameters.resize(dim.np); fixedParameters.resize(dim.nk); + root_enabled.resize(dim.ne, true); + root_last_sign.resize(dim.ne, 0); } /** @@ -56,6 +58,18 @@ struct ModelState { * (dimension: nplist) */ std::vector plist; + + /** + * Flags indicating whether a root function element is enabled + * (dimension: `ne`) + */ + std::vector root_enabled; + + /** + * The sign of the root function elements at the last root function call + * (dimension: `ne`). + */ + std::vector root_last_sign; }; inline bool operator==(ModelState const& a, ModelState const& b) { @@ -63,7 +77,7 @@ inline bool operator==(ModelState const& a, ModelState const& b) { && is_equal(a.stotal_cl, b.stotal_cl) && is_equal(a.unscaledParameters, b.unscaledParameters) && is_equal(a.fixedParameters, b.fixedParameters) - && a.plist == b.plist; + && a.plist == b.plist && a.root_enabled == b.root_enabled && a.root_last_sign == b.root_last_sign; } /** diff --git a/include/amici/serialization.h b/include/amici/serialization.h index 4aa00c975d..be07cb3e0d 100644 --- a/include/amici/serialization.h +++ b/include/amici/serialization.h @@ -145,6 +145,8 @@ void serialize(Archive& ar, amici::Model& m, unsigned int const /*version*/) { ar & m.state_.unscaledParameters; ar & m.state_.fixedParameters; ar & m.state_.plist; + ar & m.state_.root_enabled; + ar & m.state_.root_last_sign; ar & m.x0data_; ar & m.sx0data_; ar & m.nmaxevent_; diff --git a/python/tests/test_events.py b/python/tests/test_events.py index c02b27a7f7..1945b439ba 100644 --- a/python/tests/test_events.py +++ b/python/tests/test_events.py @@ -1072,3 +1072,47 @@ def test_event_uses_values_from_trigger_time(tempdir): ) # TODO: test ASA after https://github.com/AMICI-dev/AMICI/pull/1539 + + +def test_simultaneous_events(tempdir): + """Test simultaneously firing events with different trigger functions.""" + from amici.antimony_import import antimony2amici + + model_name = "test_simultaneous_events" + antimony2amici( + r""" + target1_0 = 1 + target1 = target1_0 + one = 1 + target1' = one + two = 2 + target2_0 = two + target2 = target2_0 + target2' = 1 + some_time = time + some_time' = 1 + trigger_time = 1000 + + E1: at some_time >= trigger_time, priority=10, fromTrigger=false: + target1 = target1 + 10; + E2: at time >= trigger_time, priority=20, fromTrigger=false: + target2 = target2 + 10; + """, + model_name=model_name, + output_dir=tempdir, + ) + + model_module = import_model_module(model_name, tempdir) + + model = model_module.get_model() + model.setTimepoints([0, 2]) + solver = model.getSolver() + solver.setRelativeTolerance(1e-6) + solver.setAbsoluteTolerance(1e-6) + solver.setSensitivityOrder(SensitivityOrder.first) + solver.setSensitivityMethod(SensitivityMethod.forward) + + rdata = amici.runAmiciSimulation(model, solver) + assert rdata.status == amici.AMICI_SUCCESS + assert_allclose(rdata.by_id("target1"), [1.0, 13.0]) + assert_allclose(rdata.by_id("target2"), [2.0, 14.0]) diff --git a/src/forwardproblem.cpp b/src/forwardproblem.cpp index 4f3cd1be3e..c829948a94 100644 --- a/src/forwardproblem.cpp +++ b/src/forwardproblem.cpp @@ -359,6 +359,10 @@ void EventHandlingSimulator::handle_event( ? std::optional(get_simulation_state()) : std::nullopt)} ); + + } + if(ws_->roots_found.at(ie) != 0) { + model_->register_root(ie, ws_->roots_found.at(ie)); } } @@ -533,6 +537,7 @@ int EventHandlingSimulator::detect_secondary_events() { } else { ws_->roots_found.at(ie) = -1; } + model_->register_root(ie, ws_->roots_found.at(ie)); secondevent++; } else { ws_->roots_found.at(ie) = 0; diff --git a/src/model.cpp b/src/model.cpp index cfb69189fe..a8c8dbb096 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -437,6 +437,8 @@ void Model::initEvents( roots_found.at(ie) = 1; } } + state_.root_enabled[ie] = rootvals[ie] != 0; + state_.root_last_sign[ie] = events_[ie].get_initial_value()?1:-1; } } diff --git a/src/model_dae.cpp b/src/model_dae.cpp index 6e4a4a8f07..25e923a45e 100644 --- a/src/model_dae.cpp +++ b/src/model_dae.cpp @@ -110,6 +110,19 @@ void Model_DAE::froot( state_.unscaledParameters.data(), state_.fixedParameters.data(), state_.h.data(), N_VGetArrayPointerConst(dx) ); + + for (int ie = 0; ie < ne; ++ie) { + if (!state_.root_enabled[ie]) { + if (root[ie] < 0.0) { + // If the disabled root function becomes negative, + // re-enable it. + state_.root_enabled[ie] = true; + } else { + // If the root function is disabled, mask it + root[ie] = 1.0; + } + } + } } void Model_DAE::fxdot( diff --git a/src/model_ode.cpp b/src/model_ode.cpp index b4908dbe63..8a2892059d 100644 --- a/src/model_ode.cpp +++ b/src/model_ode.cpp @@ -97,6 +97,19 @@ void Model_ODE::froot(realtype t, const_N_Vector x, gsl::span root) { state_.unscaledParameters.data(), state_.fixedParameters.data(), state_.h.data(), state_.total_cl.data() ); + + for (int ie = 0; ie < ne; ++ie) { + auto sgn = sign(root[ie]); + if (!state_.root_enabled[ie] && sgn != 0 && sgn != state_.root_last_sign[ie]) { + // The sign flipped, so we re-enable the root function + state_.root_enabled[ie] = true; + } + + if(!state_.root_enabled[ie]) { + // If the root function is disabled, mask it + root[ie] = (state_.root_last_sign[ie] > 0) ? 1.0 : -1.0; + } + } } void Model_ODE::fxdot(