Skip to content

Commit 31b0a07

Browse files
committed
Let froot depend on w
WIP This should hopefully allow for more efficient computation of `root`-derivatives by avoiding flattening `w` into `root` which, so far, made computing `drootdt_total` prohibitively expensive in case of large `w` dependencies in `root`.
1 parent ebbc2ec commit 31b0a07

File tree

12 files changed

+161
-70
lines changed

12 files changed

+161
-70
lines changed

cmake/AmiciFindBLAS.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ if(DEFINED ENV{AMICI_BLAS_USE_SCIPY_OPENBLAS})
1515
"Using AMICI_BLAS_USE_SCIPY_OPENBLAS=${AMICI_BLAS_USE_SCIPY_OPENBLAS} from environment variable."
1616
)
1717
set(AMICI_BLAS_USE_SCIPY_OPENBLAS $ENV{AMICI_BLAS_USE_SCIPY_OPENBLAS})
18+
elseif(NOT DEFINED AMICI_BLAS_USE_SCIPY_OPENBLAS
19+
AND NOT AMICI_PYTHON_BUILD_EXT_ONLY)
20+
# If were are not building the Python extension, it's unlikely that we want to
21+
# use scipy-openblas
22+
set(AMICI_BLAS_USE_SCIPY_OPENBLAS FALSE)
1823
endif()
1924

2025
if((${BLAS} STREQUAL "MKL" OR DEFINED ENV{MKLROOT})

include/amici/abstract_model.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ class AbstractModel {
299299
* @param p parameter vector
300300
* @param k constant vector
301301
* @param h Heaviside vector
302+
* @param w vector with helper variables
302303
* @param dx time derivative of state (DAE only)
303304
* @param tcl total abundances for conservation laws
304305
* @param sx current state sensitivity
@@ -307,8 +308,9 @@ class AbstractModel {
307308
*/
308309
virtual void fstau(
309310
realtype* stau, realtype t, realtype const* x, realtype const* p,
310-
realtype const* k, realtype const* h, realtype const* dx,
311-
realtype const* tcl, realtype const* sx, int ip, int ie
311+
realtype const* k, realtype const* h, realtype const* w,
312+
realtype const* dx, realtype const* tcl, realtype const* sx, int ip,
313+
int ie
312314
);
313315

314316
/**
@@ -542,6 +544,7 @@ class AbstractModel {
542544
* @param p parameter vector
543545
* @param k constant vector
544546
* @param h Heaviside vector
547+
* @param w vector with helper variables
545548
* @param dx time derivative of state (DAE only)
546549
* @param ie event index
547550
* @param xdot new model right hand side
@@ -552,9 +555,10 @@ class AbstractModel {
552555
*/
553556
virtual void fdeltaxB(
554557
realtype* deltaxB, realtype t, realtype const* x, realtype const* p,
555-
realtype const* k, realtype const* h, realtype const* dx, int ie,
556-
realtype const* xdot, realtype const* xdot_old, realtype const* x_old,
557-
realtype const* xB, realtype const* tcl
558+
realtype const* k, realtype const* h, realtype const* w,
559+
realtype const* dx, int ie, realtype const* xdot,
560+
realtype const* xdot_old, realtype const* x_old, realtype const* xB,
561+
realtype const* tcl
558562
);
559563

560564
/**
@@ -565,6 +569,7 @@ class AbstractModel {
565569
* @param p parameter vector
566570
* @param k constant vector
567571
* @param h Heaviside vector
572+
* @param w vector with helper variables
568573
* @param dx time derivative of state (DAE only)
569574
* @param ip sensitivity index
570575
* @param ie event index
@@ -575,9 +580,9 @@ class AbstractModel {
575580
*/
576581
virtual void fdeltaqB(
577582
realtype* deltaqB, realtype t, realtype const* x, realtype const* p,
578-
realtype const* k, realtype const* h, realtype const* dx, int ip,
579-
int ie, realtype const* xdot, realtype const* xdot_old,
580-
realtype const* x_old, realtype const* xB
583+
realtype const* k, realtype const* h, realtype const* w,
584+
realtype const* dx, int ip, int ie, realtype const* xdot,
585+
realtype const* xdot_old, realtype const* x_old, realtype const* xB
581586
);
582587

583588
/**

include/amici/model_dae.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,13 @@ class Model_DAE : public Model {
359359
* @param p parameter vector
360360
* @param k constants vector
361361
* @param h Heaviside vector
362+
* @param w vector with helper variables
362363
* @param dx Vector with the derivative states
363364
**/
364365
virtual void froot(
365366
realtype* root, realtype t, realtype const* x, double const* p,
366-
double const* k, realtype const* h, realtype const* dx
367+
double const* k, realtype const* h, realtype const* w,
368+
realtype const* dx
367369
);
368370

369371
/**

include/amici/model_ode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,13 @@ class Model_ODE : public Model {
316316
* @param p parameter vector
317317
* @param k constants vector
318318
* @param h Heaviside vector
319+
* @param w vector with helper variables
319320
* @param tcl total abundances for conservation laws
320321
**/
321322
virtual void froot(
322323
realtype* root, realtype t, realtype const* x, realtype const* p,
323-
realtype const* k, realtype const* h, realtype const* tcl
324+
realtype const* k, realtype const* h, realtype const* w,
325+
realtype const* tcl
324326
);
325327

326328
/**

python/sdist/amici/de_model.py

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ def __init__(
296296
] = {
297297
"sroot": {
298298
"eq": "root",
299+
# TODO?
300+
# "chainvars": ["x", "w"],
299301
"chainvars": ["x"],
300302
"var": "p",
301303
"dxdz_name": "sx",
@@ -528,17 +530,9 @@ def get_rate(symbol: sp.Symbol):
528530

529531
for component in chain(
530532
self.observables(),
531-
self.events(),
532533
self._algebraic_equations,
533534
):
534535
if rate_ofs := component.get_val().find(rate_of_func):
535-
if isinstance(component, Event):
536-
# TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates
537-
# see, e.g., sbml test case 01293
538-
raise SBMLException(
539-
"AMICI does currently not support rateOf(.) inside event trigger functions."
540-
)
541-
542536
if isinstance(component, AlgebraicEquation):
543537
# TODO IDACalcIC fails with
544538
# "The linesearch algorithm failed: step too small or too many backtracks."
@@ -1340,10 +1334,10 @@ def parse_events(self) -> None:
13401334
# add roots of heaviside functions
13411335
self.add_component(root)
13421336

1343-
# Substitute 'w' expressions into root expressions, to avoid rewriting
1344-
# 'root.cpp' and 'stau.cpp' headers to include 'w.h'.
1345-
for event in self.events():
1346-
event.set_val(event.get_val().subs(w_toposorted))
1337+
# # Substitute 'w' expressions into root expressions, to avoid rewriting
1338+
# # 'root.cpp' and 'stau.cpp' headers to include 'w.h'.
1339+
# for event in self.events():
1340+
# event.set_val(event.get_val().subs(w_toposorted))
13471341

13481342
# re-order events - first those that require root tracking, then the others
13491343
constant_syms = set(self.sym("k")) | set(self.sym("p"))
@@ -1668,18 +1662,23 @@ def _compute_equation(self, name: str) -> None:
16681662
self._eqs[name] = smart_jacobian(self.eq("root"), time_symbol)
16691663

16701664
elif name == "drootdt_total":
1665+
# root(t, x(t), w(t, x(t)))
1666+
# drootdt_total = drootdt + drootdx * dxdt + drootdw * dwdt_total
1667+
# dwdt_total = dwdt + dwdx * dxdt
16711668
self._eqs[name] = self.eq("drootdt")
1672-
# backsubstitution of optimized right-hand side terms into RHS
1673-
# calling subs() is costly. We can skip it if we don't have any
1674-
# state-dependent roots.
1669+
1670+
xdot = self.eq("xdot")
16751671
if self.num_states_solver() and not smart_is_zero_matrix(
16761672
drootdx := self.eq("drootdx")
16771673
):
1678-
w_sorted = toposort_symbols(
1679-
dict(zip(self.sym("w"), self.eq("w"), strict=True))
1680-
)
1681-
tmp_xdot = smart_subs_dict(self.eq("xdot"), w_sorted)
1682-
self._eqs[name] += smart_multiply(drootdx, tmp_xdot)
1674+
self._eqs[name] += smart_multiply(drootdx, xdot)
1675+
1676+
drootdw = self.eq("drootdw")
1677+
dwdt = self.eq("dwdt")
1678+
dwdx = self.eq("dwdx")
1679+
dwdt_total = dwdt + smart_multiply(dwdx, xdot)
1680+
1681+
self._eqs[name] += smart_multiply(drootdw, dwdt_total)
16831682

16841683
elif name == "deltax":
16851684
# fill boluses for Heaviside functions, as empty state updates
@@ -1763,16 +1762,26 @@ def _compute_equation(self, name: str) -> None:
17631762
]
17641763

17651764
elif name == "dtaudx":
1765+
# TODO drootdx + drootdw * dwdx
17661766
self._eqs[name] = [
1767-
self.eq("drootdx")[ie, :] / self.eq("drootdt_total")[ie]
1767+
(
1768+
self.eq("drootdx")[ie, :]
1769+
+ self.eq("drootdw")[ie, :] * self.eq("dwdx")
1770+
)
1771+
/ self.eq("drootdt_total")[ie]
17681772
if not self.eq("drootdt_total")[ie].is_zero
17691773
else sp.zeros(*self.eq("drootdx")[ie, :].shape)
17701774
for ie in range(self.num_events())
17711775
]
17721776

17731777
elif name == "dtaudp":
1778+
# TODO drootdp + drootdw * dwdp
17741779
self._eqs[name] = [
1775-
self.eq("drootdp")[ie, :] / self.eq("drootdt_total")[ie]
1780+
(
1781+
self.eq("drootdp")[ie, :]
1782+
+ self.eq("drootdw")[ie, :] * self.eq("dwdp")
1783+
)
1784+
/ self.eq("drootdt_total")[ie]
17761785
if not self.eq("drootdt_total")[ie].is_zero
17771786
else sp.zeros(*self.eq("drootdp")[ie, :].shape)
17781787
for ie in range(self.num_events())
@@ -1922,6 +1931,9 @@ def _compute_equation(self, name: str) -> None:
19221931
smart_jacobian(self.eq("w")[self.num_cons_law() :, :], x)
19231932
)
19241933

1934+
elif name == "dwdt":
1935+
self._eqs[name] = smart_jacobian(self.eq("w"), time_symbol)
1936+
19251937
elif name == "iroot":
19261938
self._eqs[name] = sp.Matrix(
19271939
[
@@ -2110,7 +2122,9 @@ def _derivative(self, eq: str, var: str, name: str = None) -> None:
21102122
"attach this model."
21112123
)
21122124

2113-
if name == "dydw" and not smart_is_zero_matrix(derivative):
2125+
elif name in ("dydw", "drootdw") and not smart_is_zero_matrix(
2126+
derivative
2127+
):
21142128
dwdw = self.eq("dwdw")
21152129
# h(k) = d{eq}dw*dwdw^k* (k=1)
21162130
h = smart_multiply(derivative, dwdw)
@@ -2407,6 +2421,8 @@ def _expr_is_time_dependent(self, expr: sp.Expr) -> bool:
24072421
:returns:
24082422
Whether the expression is time-dependent.
24092423
"""
2424+
# TODO: handle w-dependency
2425+
24102426
# `expr.free_symbols` will be different to `self._states.keys()`, so
24112427
# it's easier to compare as `str`.
24122428
expr_syms = {str(sym) for sym in expr.free_symbols}
@@ -2515,13 +2531,9 @@ def _process_heavisides(
25152531
tmp_roots_old = self._collect_heaviside_roots((dxdt,))
25162532
# substitute 'w' symbols in the root expression by their equations,
25172533
# because currently,
2518-
# 1) root functions must not depend on 'w'
2519-
# 2) the check for time-dependence currently assumes only state
2534+
# # 1) root functions must not depend on 'w'
2535+
# FIXME 2) the check for time-dependence currently assumes only state
25202536
# variables are implicitly time-dependent
2521-
tmp_roots_old = [
2522-
(a.subs(w_toposorted), b.subs(w_toposorted))
2523-
for a, b in tmp_roots_old
2524-
]
25252537
for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old):
25262538
# we want unique identifiers for the roots
25272539
tmp_root_new = self._get_unique_root(tmp_root_old, roots)

python/sdist/amici/exporters/sundials/cxx_functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def var_in_signature(self, varname: str, ode: bool = True) -> bool:
144144
"root": _FunctionInfo(
145145
"realtype *root, const realtype t, const realtype *x, "
146146
"const realtype *p, const realtype *k, const realtype *h, "
147-
"const realtype *tcl"
147+
"const realtype *w, const realtype *tcl"
148148
),
149149
"dwdp": _FunctionInfo(
150150
"realtype *dwdp, const realtype t, const realtype *x, "
@@ -288,7 +288,7 @@ def var_in_signature(self, varname: str, ode: bool = True) -> bool:
288288
"stau": _FunctionInfo(
289289
"realtype *stau, const realtype t, const realtype *x, "
290290
"const realtype *p, const realtype *k, const realtype *h, "
291-
"const realtype *dx, "
291+
"const realtype *w, const realtype *dx, "
292292
"const realtype *tcl, const realtype *sx, const int ip, "
293293
"const int ie"
294294
),
@@ -313,15 +313,15 @@ def var_in_signature(self, varname: str, ode: bool = True) -> bool:
313313
"deltaxB": _FunctionInfo(
314314
"realtype *deltaxB, const realtype t, const realtype *x, "
315315
"const realtype *p, const realtype *k, const realtype *h, "
316-
"const realtype *dx, "
316+
"const realtype *w, const realtype *dx, "
317317
"const int ie, const realtype *xdot, const realtype *xdot_old, "
318318
"const realtype *x_old, "
319319
"const realtype *xB, const realtype *tcl"
320320
),
321321
"deltaqB": _FunctionInfo(
322322
"realtype *deltaqB, const realtype t, const realtype *x, "
323323
"const realtype *p, const realtype *k, const realtype *h, "
324-
"const realtype *dx, "
324+
"const realtype *w, const realtype *dx, "
325325
"const int ip, const int ie, const realtype *xdot, "
326326
"const realtype *xdot_old, const realtype *x_old, const realtype *xB"
327327
),

python/sdist/amici/gradient_check.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,17 +133,22 @@ def check_finite_difference(
133133
else:
134134
raise NotImplementedError()
135135

136-
_check_close(
137-
sensi,
138-
fd,
139-
atol=atol,
140-
rtol=rtol,
141-
field=field,
142-
ip=ip,
143-
parameter_id=model.get_parameter_ids()[ip]
144-
if model.has_parameter_ids()
145-
else None,
146-
)
136+
try:
137+
_check_close(
138+
sensi,
139+
fd,
140+
atol=atol,
141+
rtol=rtol,
142+
field=field,
143+
ip=ip,
144+
parameter_id=model.get_parameter_ids()[ip]
145+
if model.has_parameter_ids()
146+
else None,
147+
)
148+
except AssertionError as e:
149+
sm = SensitivityMethod(solver.get_sensitivity_method())
150+
e.add_note(f"Sensitivity method was {sm!r}")
151+
raise e
147152

148153
solver.set_sensitivity_order(og_sensitivity_order)
149154
model.set_parameters(og_parameters)

python/tests/test_events.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,3 +1294,58 @@ def test_gh2926(tempdir):
12941294
rdata = amici.run_simulation(model, solver)
12951295
assert rdata.status == amici.AMICI_SUCCESS
12961296
assert rdata.by_id("x1").tolist() == [1.0, 1.0, 2.0, 2.0]
1297+
1298+
1299+
@skip_on_valgrind
1300+
def test_event_with_w_dependent_trigger(tempdir):
1301+
"""Test sensitivities for events with trigger depending on
1302+
cascading expressions in `w`."""
1303+
1304+
model_name = "test_event_with_w_dependent_trigger"
1305+
model = antimony2amici(
1306+
r"""
1307+
one = 1
1308+
two = 2
1309+
a := two^2 - 2 # 2
1310+
b := a^2 - 1 # 3
1311+
c := b * 2 + x / 10 # 6
1312+
d := c + a + (a - 1) * time / 10 # -> d = 8 + time / 5
1313+
1314+
x = 0
1315+
x' = 1
1316+
target = 0
1317+
1318+
E1: at x >= d: # triggers at time = 8 + time / 5 <=> time = 10
1319+
target = x + one;
1320+
""",
1321+
model_name=model_name,
1322+
output_dir=tempdir,
1323+
)
1324+
1325+
model.set_timepoints([0, 5, 9, 11])
1326+
solver = model.create_solver()
1327+
solver.set_sensitivity_order(SensitivityOrder.first)
1328+
solver.set_sensitivity_method(SensitivityMethod.forward)
1329+
1330+
# generate synthetic measurements
1331+
rdata = amici.run_simulation(model, solver)
1332+
assert rdata.status == amici.AMICI_SUCCESS
1333+
# check that event triggered correctly
1334+
assert np.isclose(rdata.by_id("target")[-1], 11.0)
1335+
edata = amici.ExpData(rdata, 1, 0)
1336+
1337+
# check sensitivities against finite differences
1338+
1339+
for sens_method in (
1340+
SensitivityMethod.forward,
1341+
SensitivityMethod.adjoint,
1342+
):
1343+
solver.set_sensitivity_method(sens_method)
1344+
check_derivatives(
1345+
model,
1346+
solver,
1347+
edata=edata,
1348+
atol=1e-6,
1349+
rtol=1e-6,
1350+
epsilon=1e-8,
1351+
)

0 commit comments

Comments
 (0)