Skip to content

Commit 35c3492

Browse files
authored
Let froot depend on w (#3031)
This should hopefully allow for more efficient computation of `root`-derivatives by avoiding flattening `w` into `root` which, so far, made computing `drootdt_total` prohibitively expensive in case of large `w` dependencies in `root`.
1 parent c84db32 commit 35c3492

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+372
-226
lines changed

include/amici/abstract_model.h

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ class AbstractModel {
299299
* @param p parameter vector
300300
* @param k constant vector
301301
* @param h Heaviside vector
302+
* @param w vector with helper variables
302303
* @param dx time derivative of state (DAE only)
303304
* @param tcl total abundances for conservation laws
304305
* @param sx current state sensitivity
@@ -307,8 +308,9 @@ class AbstractModel {
307308
*/
308309
virtual void fstau(
309310
realtype* stau, realtype t, realtype const* x, realtype const* p,
310-
realtype const* k, realtype const* h, realtype const* dx,
311-
realtype const* tcl, realtype const* sx, int ip, int ie
311+
realtype const* k, realtype const* h, realtype const* w,
312+
realtype const* dx, realtype const* tcl, realtype const* sx, int ip,
313+
int ie
312314
);
313315

314316
/**
@@ -542,6 +544,7 @@ class AbstractModel {
542544
* @param p parameter vector
543545
* @param k constant vector
544546
* @param h Heaviside vector
547+
* @param w vector with helper variables
545548
* @param dx time derivative of state (DAE only)
546549
* @param ie event index
547550
* @param xdot new model right hand side
@@ -552,9 +555,10 @@ class AbstractModel {
552555
*/
553556
virtual void fdeltaxB(
554557
realtype* deltaxB, realtype t, realtype const* x, realtype const* p,
555-
realtype const* k, realtype const* h, realtype const* dx, int ie,
556-
realtype const* xdot, realtype const* xdot_old, realtype const* x_old,
557-
realtype const* xB, realtype const* tcl
558+
realtype const* k, realtype const* h, realtype const* w,
559+
realtype const* dx, int ie, realtype const* xdot,
560+
realtype const* xdot_old, realtype const* x_old, realtype const* xB,
561+
realtype const* tcl
558562
);
559563

560564
/**
@@ -565,6 +569,7 @@ class AbstractModel {
565569
* @param p parameter vector
566570
* @param k constant vector
567571
* @param h Heaviside vector
572+
* @param w vector with helper variables
568573
* @param dx time derivative of state (DAE only)
569574
* @param ip sensitivity index
570575
* @param ie event index
@@ -575,9 +580,9 @@ class AbstractModel {
575580
*/
576581
virtual void fdeltaqB(
577582
realtype* deltaqB, realtype t, realtype const* x, realtype const* p,
578-
realtype const* k, realtype const* h, realtype const* dx, int ip,
579-
int ie, realtype const* xdot, realtype const* xdot_old,
580-
realtype const* x_old, realtype const* xB
583+
realtype const* k, realtype const* h, realtype const* w,
584+
realtype const* dx, int ip, int ie, realtype const* xdot,
585+
realtype const* xdot_old, realtype const* x_old, realtype const* xB
581586
);
582587

583588
/**
@@ -1059,11 +1064,13 @@ class AbstractModel {
10591064
* @brief Compute explicit roots of the model.
10601065
* @param p parameter vector
10611066
* @param k constant vector
1067+
* @param w vector with helper variables
10621068
* @return A vector of length ne_solver, each containing a vector of
10631069
* explicit roots for the corresponding event.
10641070
*/
10651071
virtual std::vector<std::vector<realtype>> fexplicit_roots(
1066-
[[maybe_unused]] realtype const* p, [[maybe_unused]] realtype const* k
1072+
[[maybe_unused]] realtype const* p, [[maybe_unused]] realtype const* k,
1073+
[[maybe_unused]] realtype const* w
10671074
) = 0;
10681075
};
10691076

include/amici/model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1591,7 +1591,8 @@ class Model : public AbstractModel, public ModelDimensions {
15911591
}
15921592

15931593
[[nodiscard]] std::vector<std::vector<realtype>> fexplicit_roots(
1594-
[[maybe_unused]] realtype const* p, [[maybe_unused]] realtype const* k
1594+
[[maybe_unused]] realtype const* p, [[maybe_unused]] realtype const* k,
1595+
[[maybe_unused]] realtype const* w
15951596
) override {
15961597
if (ne != ne_solver) {
15971598
throw AmiException(

include/amici/model_dae.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,13 @@ class Model_DAE : public Model {
359359
* @param p parameter vector
360360
* @param k constants vector
361361
* @param h Heaviside vector
362+
* @param w vector with helper variables
362363
* @param dx Vector with the derivative states
363364
**/
364365
virtual void froot(
365366
realtype* root, realtype t, realtype const* x, double const* p,
366-
double const* k, realtype const* h, realtype const* dx
367+
double const* k, realtype const* h, realtype const* w,
368+
realtype const* dx
367369
);
368370

369371
/**
@@ -444,7 +446,7 @@ class Model_DAE : public Model {
444446
* @param x Vector with the states
445447
* @param p parameter vector
446448
* @param k constants vector
447-
* @param h heavyside vector
449+
* @param h Heaviside vector
448450
* @param dx Vector with the derivative states
449451
* @param w vector with helper variables
450452
*/

include/amici/model_ode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,13 @@ class Model_ODE : public Model {
316316
* @param p parameter vector
317317
* @param k constants vector
318318
* @param h Heaviside vector
319+
* @param w vector with helper variables
319320
* @param tcl total abundances for conservation laws
320321
**/
321322
virtual void froot(
322323
realtype* root, realtype t, realtype const* x, realtype const* p,
323-
realtype const* k, realtype const* h, realtype const* tcl
324+
realtype const* k, realtype const* h, realtype const* w,
325+
realtype const* tcl
324326
);
325327

326328
/**

models/model_calvetti_py/explicit_roots.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
#include <algorithm>
55
#include <vector>
66
#include "k.h"
7+
#include "w.h"
78

89
namespace amici {
910
namespace model_model_calvetti_py {
1011

11-
std::vector<std::vector<realtype>> explicit_roots_model_calvetti_py(const realtype *p, const realtype *k){
12+
std::vector<std::vector<realtype>> explicit_roots_model_calvetti_py(const realtype *p, const realtype *k, const realtype *w){
1213
return {
1314
{10},
1415
{10},

models/model_calvetti_py/model_calvetti_py.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ extern void dJydy_rowvals_model_calvetti_py(SUNMatrixWrapper &rowvals, int index
3838

3939

4040

41-
extern void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl);
41+
extern void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl);
4242

4343

4444

@@ -99,7 +99,7 @@ extern void x_solver_model_calvetti_py(realtype *x_solver, const realtype *x_rda
9999
extern std::vector<HermiteSpline> create_splines_model_calvetti_py(const realtype *p, const realtype *k);
100100

101101

102-
extern std::vector<std::vector<realtype>> explicit_roots_model_calvetti_py(const realtype *p, const realtype *k);
102+
extern std::vector<std::vector<realtype>> explicit_roots_model_calvetti_py(const realtype *p, const realtype *k, const realtype *w);
103103
/**
104104
* @brief AMICI-generated model subclass.
105105
*/
@@ -200,10 +200,10 @@ class Model_model_calvetti_py : public amici::Model_DAE {
200200
void fdeltasx(realtype *deltasx, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *sx, const realtype *stau, const realtype *tcl, const realtype *x_old) override {}
201201

202202

203-
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
203+
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
204204

205205

206-
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {}
206+
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {}
207207

208208

209209
void fdrzdp(realtype *drzdp, const int ie, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const int ip) override {}
@@ -323,8 +323,8 @@ class Model_model_calvetti_py : public amici::Model_DAE {
323323
void fdzdx(realtype *dzdx, const int ie, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h) override {}
324324

325325

326-
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl) override {
327-
root_model_calvetti_py(root, t, x, p, k, h, tcl);
326+
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl) override {
327+
root_model_calvetti_py(root, t, x, p, k, h, w, tcl);
328328
}
329329

330330

@@ -339,7 +339,7 @@ class Model_model_calvetti_py : public amici::Model_DAE {
339339
void fsigmaz(realtype *sigmaz, const realtype t, const realtype *p, const realtype *k) override {}
340340

341341

342-
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {}
342+
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {}
343343

344344
void fsx0(realtype *sx0, const realtype t, const realtype *x, const realtype *p, const realtype *k, const int ip) override {}
345345

@@ -411,8 +411,8 @@ class Model_model_calvetti_py : public amici::Model_DAE {
411411
void fdtotal_cldx_rdata_rowvals(SUNMatrixWrapper &rowvals) override {}
412412

413413

414-
std::vector<std::vector<realtype>> fexplicit_roots(const realtype *p, const realtype *k) override {
415-
return explicit_roots_model_calvetti_py(p, k);
414+
std::vector<std::vector<realtype>> fexplicit_roots(const realtype *p, const realtype *k, const realtype *w) override {
415+
return explicit_roots_model_calvetti_py(p, k, w);
416416
}
417417

418418

@@ -557,7 +557,7 @@ class Model_model_calvetti_py : public amici::Model_DAE {
557557
* @return AMICI git commit hash
558558
*/
559559
std::string get_amici_commit() const override {
560-
return "40190b46b1b398e321314ded4169fe910b37c484";
560+
return "3fb84cd5df12639f17b179d681e8ba4b5be8a160";
561561
}
562562

563563
bool has_quadratic_llh() const override {

models/model_calvetti_py/root.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
#include "x.h"
66
#include "k.h"
77
#include "h.h"
8+
#include "w.h"
89

910
namespace amici {
1011
namespace model_model_calvetti_py {
1112

12-
void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl){
13+
void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl){
1314
root[0] = t - 10;
1415
root[1] = 10 - t;
1516
root[2] = 12 - t;

models/model_dirac_py/deltaqB.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "x.h"
66
#include "p.h"
77
#include "h.h"
8+
#include "w.h"
89
#include "xdot.h"
910
#include "xdot_old.h"
1011
#include "x_old.h"
@@ -13,7 +14,7 @@
1314
namespace amici {
1415
namespace model_model_dirac_py {
1516

16-
void deltaqB_model_dirac_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB){
17+
void deltaqB_model_dirac_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB){
1718
switch(ie) {
1819
case 0:
1920
switch(ip) {

models/model_dirac_py/explicit_roots.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
namespace amici {
99
namespace model_model_dirac_py {
1010

11-
std::vector<std::vector<realtype>> explicit_roots_model_dirac_py(const realtype *p, const realtype *k){
11+
std::vector<std::vector<realtype>> explicit_roots_model_dirac_py(const realtype *p, const realtype *k, const realtype *w){
1212
return {
1313
{p2}
1414
};

models/model_dirac_py/model_dirac_py.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ extern void dJydy_rowvals_model_dirac_py(SUNMatrixWrapper &rowvals, int index);
3838

3939

4040

41-
extern void root_model_dirac_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl);
41+
extern void root_model_dirac_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl);
4242

4343

4444

@@ -77,11 +77,11 @@ extern void xdot_model_dirac_py(realtype *xdot, const realtype t, const realtype
7777
extern void y_model_dirac_py(realtype *y, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w);
7878

7979

80-
extern void stau_model_dirac_py(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie);
80+
extern void stau_model_dirac_py(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie);
8181
extern void deltax_model_dirac_py(double *deltax, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old);
8282
extern void deltasx_model_dirac_py(realtype *deltasx, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *sx, const realtype *stau, const realtype *tcl, const realtype *x_old);
8383

84-
extern void deltaqB_model_dirac_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB);
84+
extern void deltaqB_model_dirac_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB);
8585

8686
extern void x_solver_model_dirac_py(realtype *x_solver, const realtype *x_rdata);
8787

@@ -99,7 +99,7 @@ extern void x_solver_model_dirac_py(realtype *x_solver, const realtype *x_rdata)
9999
extern std::vector<HermiteSpline> create_splines_model_dirac_py(const realtype *p, const realtype *k);
100100

101101

102-
extern std::vector<std::vector<realtype>> explicit_roots_model_dirac_py(const realtype *p, const realtype *k);
102+
extern std::vector<std::vector<realtype>> explicit_roots_model_dirac_py(const realtype *p, const realtype *k, const realtype *w);
103103
/**
104104
* @brief AMICI-generated model subclass.
105105
*/
@@ -201,11 +201,11 @@ class Model_model_dirac_py : public amici::Model_ODE {
201201
}
202202

203203

204-
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
204+
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
205205

206206

207-
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {
208-
deltaqB_model_dirac_py(deltaqB, t, x, p, k, h, dx, ip, ie, xdot, xdot_old, x_old, xB);
207+
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {
208+
deltaqB_model_dirac_py(deltaqB, t, x, p, k, h, w, dx, ip, ie, xdot, xdot_old, x_old, xB);
209209
}
210210

211211

@@ -314,8 +314,8 @@ class Model_model_dirac_py : public amici::Model_ODE {
314314
void fdzdx(realtype *dzdx, const int ie, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h) override {}
315315

316316

317-
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl) override {
318-
root_model_dirac_py(root, t, x, p, k, h, tcl);
317+
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl) override {
318+
root_model_dirac_py(root, t, x, p, k, h, w, tcl);
319319
}
320320

321321

@@ -330,8 +330,8 @@ class Model_model_dirac_py : public amici::Model_ODE {
330330
void fsigmaz(realtype *sigmaz, const realtype t, const realtype *p, const realtype *k) override {}
331331

332332

333-
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {
334-
stau_model_dirac_py(stau, t, x, p, k, h, dx, tcl, sx, ip, ie);
333+
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {
334+
stau_model_dirac_py(stau, t, x, p, k, h, w, dx, tcl, sx, ip, ie);
335335
}
336336

337337
void fsx0(realtype *sx0, const realtype t, const realtype *x, const realtype *p, const realtype *k, const int ip) override {}
@@ -398,8 +398,8 @@ class Model_model_dirac_py : public amici::Model_ODE {
398398
void fdtotal_cldx_rdata_rowvals(SUNMatrixWrapper &rowvals) override {}
399399

400400

401-
std::vector<std::vector<realtype>> fexplicit_roots(const realtype *p, const realtype *k) override {
402-
return explicit_roots_model_dirac_py(p, k);
401+
std::vector<std::vector<realtype>> fexplicit_roots(const realtype *p, const realtype *k, const realtype *w) override {
402+
return explicit_roots_model_dirac_py(p, k, w);
403403
}
404404

405405

@@ -544,7 +544,7 @@ class Model_model_dirac_py : public amici::Model_ODE {
544544
* @return AMICI git commit hash
545545
*/
546546
std::string get_amici_commit() const override {
547-
return "40190b46b1b398e321314ded4169fe910b37c484";
547+
return "3fb84cd5df12639f17b179d681e8ba4b5be8a160";
548548
}
549549

550550
bool has_quadratic_llh() const override {

0 commit comments

Comments
 (0)