Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/amici/misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,18 @@
};
#endif

/**
* @brief The sign function.
*
* @param x The value to determine the sign of.
* @return -1, 0, or 1 depending on the sign of x.
*/
template <typename T>
int sign(T x) {
return (T(0) < x) - (x < T(0));

Check warning on line 355 in include/amici/misc.h

View check run for this annotation

Codecov / codecov/patch

include/amici/misc.h#L354-L355

Added lines #L354 - L355 were not covered by tests
}


} // namespace amici

#endif // AMICI_MISC_H
11 changes: 11 additions & 0 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,17 @@
*/
void updateHeaviside(std::vector<int> const& rootsfound);

/**
* @brief Disable the event with index `ie` because it just triggered.
*
* Not to be called by user code.
*
* @param ie Event index.
*/
void register_root(int const ie, int direction) {
state_.root_enabled.at(ie) = false;
state_.root_last_sign.at(ie) = direction;
}

Check warning on line 1375 in include/amici/model.h

View check run for this annotation

Codecov / codecov/patch

include/amici/model.h#L1375

Added line #L1375 was not covered by tests
/**
* @brief Check if the given array has only finite elements.
*
Expand Down
16 changes: 15 additions & 1 deletion include/amici/model_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
stotal_cl.resize((dim.nx_rdata - dim.nx_solver) * dim.np, 0.0);
unscaledParameters.resize(dim.np);
fixedParameters.resize(dim.nk);
root_enabled.resize(dim.ne, true);
root_last_sign.resize(dim.ne, 0);
}

/**
Expand Down Expand Up @@ -56,14 +58,26 @@
* (dimension: nplist)
*/
std::vector<int> plist;

/**
* Flags indicating whether a root function element is enabled
* (dimension: `ne`)
*/
std::vector<bool> root_enabled;

/**
* The sign of the root function elements at the last root function call
* (dimension: `ne`).
*/
std::vector<int> root_last_sign;
};

inline bool operator==(ModelState const& a, ModelState const& b) {
return is_equal(a.h, b.h) && is_equal(a.total_cl, b.total_cl)
&& is_equal(a.stotal_cl, b.stotal_cl)
&& is_equal(a.unscaledParameters, b.unscaledParameters)
&& is_equal(a.fixedParameters, b.fixedParameters)
&& a.plist == b.plist;
&& a.plist == b.plist && a.root_enabled == b.root_enabled && a.root_last_sign == b.root_last_sign;

Check warning on line 80 in include/amici/model_state.h

View check run for this annotation

Codecov / codecov/patch

include/amici/model_state.h#L80

Added line #L80 was not covered by tests
}

/**
Expand Down
2 changes: 2 additions & 0 deletions include/amici/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@
ar & m.state_.unscaledParameters;
ar & m.state_.fixedParameters;
ar & m.state_.plist;
ar & m.state_.root_enabled;
ar & m.state_.root_last_sign;

Check warning on line 149 in include/amici/serialization.h

View check run for this annotation

Codecov / codecov/patch

include/amici/serialization.h#L148-L149

Added lines #L148 - L149 were not covered by tests
ar & m.x0data_;
ar & m.sx0data_;
ar & m.nmaxevent_;
Expand Down
44 changes: 44 additions & 0 deletions python/tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,3 +1072,47 @@ def test_event_uses_values_from_trigger_time(tempdir):
)

# TODO: test ASA after https://github.com/AMICI-dev/AMICI/pull/1539


def test_simultaneous_events(tempdir):
"""Test simultaneously firing events with different trigger functions."""
from amici.antimony_import import antimony2amici

model_name = "test_simultaneous_events"
antimony2amici(
r"""
target1_0 = 1
target1 = target1_0
one = 1
target1' = one
two = 2
target2_0 = two
target2 = target2_0
target2' = 1
some_time = time
some_time' = 1
trigger_time = 1000
E1: at some_time >= trigger_time, priority=10, fromTrigger=false:
target1 = target1 + 10;
E2: at time >= trigger_time, priority=20, fromTrigger=false:
target2 = target2 + 10;
""",
model_name=model_name,
output_dir=tempdir,
)

model_module = import_model_module(model_name, tempdir)

model = model_module.get_model()
model.setTimepoints([0, 2])
solver = model.getSolver()
solver.setRelativeTolerance(1e-6)
solver.setAbsoluteTolerance(1e-6)
solver.setSensitivityOrder(SensitivityOrder.first)
solver.setSensitivityMethod(SensitivityMethod.forward)

rdata = amici.runAmiciSimulation(model, solver)
assert rdata.status == amici.AMICI_SUCCESS
assert_allclose(rdata.by_id("target1"), [1.0, 13.0])
assert_allclose(rdata.by_id("target2"), [2.0, 14.0])
5 changes: 5 additions & 0 deletions src/forwardproblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@
? std::optional<SimulationState>(get_simulation_state())
: std::nullopt)}
);

}
if(ws_->roots_found.at(ie) != 0) {
model_->register_root(ie, ws_->roots_found.at(ie));
}
}

Expand Down Expand Up @@ -533,6 +537,7 @@
} else {
ws_->roots_found.at(ie) = -1;
}
model_->register_root(ie, ws_->roots_found.at(ie));

Check warning on line 540 in src/forwardproblem.cpp

View check run for this annotation

Codecov / codecov/patch

src/forwardproblem.cpp#L540

Added line #L540 was not covered by tests
secondevent++;
} else {
ws_->roots_found.at(ie) = 0;
Expand Down
2 changes: 2 additions & 0 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,8 @@ void Model::initEvents(
roots_found.at(ie) = 1;
}
}
state_.root_enabled[ie] = rootvals[ie] != 0;
state_.root_last_sign[ie] = events_[ie].get_initial_value()?1:-1;
}
}

Expand Down
13 changes: 13 additions & 0 deletions src/model_dae.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), N_VGetArrayPointerConst(dx)
);

for (int ie = 0; ie < ne; ++ie) {
if (!state_.root_enabled[ie]) {
if (root[ie] < 0.0) {
// If the disabled root function becomes negative,
// re-enable it.
state_.root_enabled[ie] = true;

Check warning on line 119 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L119

Added line #L119 was not covered by tests
} else {
// If the root function is disabled, mask it
root[ie] = 1.0;

Check warning on line 122 in src/model_dae.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_dae.cpp#L122

Added line #L122 was not covered by tests
}
}
}
}

void Model_DAE::fxdot(
Expand Down
13 changes: 13 additions & 0 deletions src/model_ode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), state_.total_cl.data()
);

for (int ie = 0; ie < ne; ++ie) {
auto sgn = sign(root[ie]);
if (!state_.root_enabled[ie] && sgn != 0 && sgn != state_.root_last_sign[ie]) {
// The sign flipped, so we re-enable the root function
state_.root_enabled[ie] = true;

Check warning on line 105 in src/model_ode.cpp

View check run for this annotation

Codecov / codecov/patch

src/model_ode.cpp#L105

Added line #L105 was not covered by tests
}

if(!state_.root_enabled[ie]) {
// If the root function is disabled, mask it
root[ie] = (state_.root_last_sign[ie] > 0) ? 1.0 : -1.0;
}
}
}

void Model_ODE::fxdot(
Expand Down
Loading