Skip to content

Commit 5272301

Browse files
committed
Let froot depend on w
WIP This should hopefully allow for more efficient computation of `root`-derivatives by avoiding flattening `w` into `root` which, so far, made computing `drootdt_total` prohibitively expensive in case of large `w` dependencies in `root`.
1 parent e75c348 commit 5272301

File tree

35 files changed

+257
-158
lines changed

35 files changed

+257
-158
lines changed

include/amici/abstract_model.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ class AbstractModel {
299299
* @param p parameter vector
300300
* @param k constant vector
301301
* @param h Heaviside vector
302+
* @param w vector with helper variables
302303
* @param dx time derivative of state (DAE only)
303304
* @param tcl total abundances for conservation laws
304305
* @param sx current state sensitivity
@@ -307,8 +308,9 @@ class AbstractModel {
307308
*/
308309
virtual void fstau(
309310
realtype* stau, realtype t, realtype const* x, realtype const* p,
310-
realtype const* k, realtype const* h, realtype const* dx,
311-
realtype const* tcl, realtype const* sx, int ip, int ie
311+
realtype const* k, realtype const* h, realtype const* w,
312+
realtype const* dx, realtype const* tcl, realtype const* sx, int ip,
313+
int ie
312314
);
313315

314316
/**
@@ -542,6 +544,7 @@ class AbstractModel {
542544
* @param p parameter vector
543545
* @param k constant vector
544546
* @param h Heaviside vector
547+
* @param w vector with helper variables
545548
* @param dx time derivative of state (DAE only)
546549
* @param ie event index
547550
* @param xdot new model right hand side
@@ -552,9 +555,10 @@ class AbstractModel {
552555
*/
553556
virtual void fdeltaxB(
554557
realtype* deltaxB, realtype t, realtype const* x, realtype const* p,
555-
realtype const* k, realtype const* h, realtype const* dx, int ie,
556-
realtype const* xdot, realtype const* xdot_old, realtype const* x_old,
557-
realtype const* xB, realtype const* tcl
558+
realtype const* k, realtype const* h, realtype const* w,
559+
realtype const* dx, int ie, realtype const* xdot,
560+
realtype const* xdot_old, realtype const* x_old, realtype const* xB,
561+
realtype const* tcl
558562
);
559563

560564
/**
@@ -565,6 +569,7 @@ class AbstractModel {
565569
* @param p parameter vector
566570
* @param k constant vector
567571
* @param h Heaviside vector
572+
* @param w vector with helper variables
568573
* @param dx time derivative of state (DAE only)
569574
* @param ip sensitivity index
570575
* @param ie event index
@@ -575,9 +580,9 @@ class AbstractModel {
575580
*/
576581
virtual void fdeltaqB(
577582
realtype* deltaqB, realtype t, realtype const* x, realtype const* p,
578-
realtype const* k, realtype const* h, realtype const* dx, int ip,
579-
int ie, realtype const* xdot, realtype const* xdot_old,
580-
realtype const* x_old, realtype const* xB
583+
realtype const* k, realtype const* h, realtype const* w,
584+
realtype const* dx, int ip, int ie, realtype const* xdot,
585+
realtype const* xdot_old, realtype const* x_old, realtype const* xB
581586
);
582587

583588
/**

include/amici/model_dae.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,13 @@ class Model_DAE : public Model {
359359
* @param p parameter vector
360360
* @param k constants vector
361361
* @param h Heaviside vector
362+
* @param w vector with helper variables
362363
* @param dx Vector with the derivative states
363364
**/
364365
virtual void froot(
365366
realtype* root, realtype t, realtype const* x, double const* p,
366-
double const* k, realtype const* h, realtype const* dx
367+
double const* k, realtype const* h, realtype const* w,
368+
realtype const* dx
367369
);
368370

369371
/**

include/amici/model_ode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,13 @@ class Model_ODE : public Model {
316316
* @param p parameter vector
317317
* @param k constants vector
318318
* @param h Heaviside vector
319+
* @param w vector with helper variables
319320
* @param tcl total abundances for conservation laws
320321
**/
321322
virtual void froot(
322323
realtype* root, realtype t, realtype const* x, realtype const* p,
323-
realtype const* k, realtype const* h, realtype const* tcl
324+
realtype const* k, realtype const* h, realtype const* w,
325+
realtype const* tcl
324326
);
325327

326328
/**

models/model_calvetti_py/model_calvetti_py.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ extern void dJydy_rowvals_model_calvetti_py(SUNMatrixWrapper &rowvals, int index
3838

3939

4040

41-
extern void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl);
41+
extern void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl);
4242

4343

4444

@@ -200,10 +200,10 @@ class Model_model_calvetti_py : public amici::Model_DAE {
200200
void fdeltasx(realtype *deltasx, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *sx, const realtype *stau, const realtype *tcl, const realtype *x_old) override {}
201201

202202

203-
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
203+
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
204204

205205

206-
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {}
206+
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {}
207207

208208

209209
void fdrzdp(realtype *drzdp, const int ie, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const int ip) override {}
@@ -323,8 +323,8 @@ class Model_model_calvetti_py : public amici::Model_DAE {
323323
void fdzdx(realtype *dzdx, const int ie, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h) override {}
324324

325325

326-
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl) override {
327-
root_model_calvetti_py(root, t, x, p, k, h, tcl);
326+
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl) override {
327+
root_model_calvetti_py(root, t, x, p, k, h, w, tcl);
328328
}
329329

330330

@@ -339,7 +339,7 @@ class Model_model_calvetti_py : public amici::Model_DAE {
339339
void fsigmaz(realtype *sigmaz, const realtype t, const realtype *p, const realtype *k) override {}
340340

341341

342-
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {}
342+
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {}
343343

344344
void fsx0(realtype *sx0, const realtype t, const realtype *x, const realtype *p, const realtype *k, const int ip) override {}
345345

@@ -557,7 +557,7 @@ class Model_model_calvetti_py : public amici::Model_DAE {
557557
* @return AMICI git commit hash
558558
*/
559559
std::string get_amici_commit() const override {
560-
return "b0b2684b4b67db9eadf5e47d4f87f8fe74dd9070";
560+
return "unknown";
561561
}
562562

563563
bool has_quadratic_llh() const override {

models/model_calvetti_py/root.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
#include "x.h"
66
#include "k.h"
77
#include "h.h"
8+
#include "w.h"
89

910
namespace amici {
1011
namespace model_model_calvetti_py {
1112

12-
void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl){
13+
void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl){
1314
root[0] = t - 10;
1415
root[1] = 10 - t;
1516
root[2] = 12 - t;

models/model_dirac_py/deltaqB.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "x.h"
66
#include "p.h"
77
#include "h.h"
8+
#include "w.h"
89
#include "xdot.h"
910
#include "xdot_old.h"
1011
#include "x_old.h"
@@ -13,7 +14,7 @@
1314
namespace amici {
1415
namespace model_model_dirac_py {
1516

16-
void deltaqB_model_dirac_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB){
17+
void deltaqB_model_dirac_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB){
1718
switch(ie) {
1819
case 0:
1920
switch(ip) {

models/model_dirac_py/model_dirac_py.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ extern void dJydy_rowvals_model_dirac_py(SUNMatrixWrapper &rowvals, int index);
3838

3939

4040

41-
extern void root_model_dirac_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl);
41+
extern void root_model_dirac_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl);
4242

4343

4444

@@ -77,11 +77,11 @@ extern void xdot_model_dirac_py(realtype *xdot, const realtype t, const realtype
7777
extern void y_model_dirac_py(realtype *y, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w);
7878

7979

80-
extern void stau_model_dirac_py(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie);
80+
extern void stau_model_dirac_py(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie);
8181
extern void deltax_model_dirac_py(double *deltax, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old);
8282
extern void deltasx_model_dirac_py(realtype *deltasx, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *sx, const realtype *stau, const realtype *tcl, const realtype *x_old);
8383

84-
extern void deltaqB_model_dirac_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB);
84+
extern void deltaqB_model_dirac_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB);
8585

8686
extern void x_solver_model_dirac_py(realtype *x_solver, const realtype *x_rdata);
8787

@@ -201,11 +201,11 @@ class Model_model_dirac_py : public amici::Model_ODE {
201201
}
202202

203203

204-
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
204+
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
205205

206206

207-
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {
208-
deltaqB_model_dirac_py(deltaqB, t, x, p, k, h, dx, ip, ie, xdot, xdot_old, x_old, xB);
207+
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {
208+
deltaqB_model_dirac_py(deltaqB, t, x, p, k, h, w, dx, ip, ie, xdot, xdot_old, x_old, xB);
209209
}
210210

211211

@@ -314,8 +314,8 @@ class Model_model_dirac_py : public amici::Model_ODE {
314314
void fdzdx(realtype *dzdx, const int ie, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h) override {}
315315

316316

317-
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl) override {
318-
root_model_dirac_py(root, t, x, p, k, h, tcl);
317+
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl) override {
318+
root_model_dirac_py(root, t, x, p, k, h, w, tcl);
319319
}
320320

321321

@@ -330,8 +330,8 @@ class Model_model_dirac_py : public amici::Model_ODE {
330330
void fsigmaz(realtype *sigmaz, const realtype t, const realtype *p, const realtype *k) override {}
331331

332332

333-
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {
334-
stau_model_dirac_py(stau, t, x, p, k, h, dx, tcl, sx, ip, ie);
333+
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {
334+
stau_model_dirac_py(stau, t, x, p, k, h, w, dx, tcl, sx, ip, ie);
335335
}
336336

337337
void fsx0(realtype *sx0, const realtype t, const realtype *x, const realtype *p, const realtype *k, const int ip) override {}
@@ -544,7 +544,7 @@ class Model_model_dirac_py : public amici::Model_ODE {
544544
* @return AMICI git commit hash
545545
*/
546546
std::string get_amici_commit() const override {
547-
return "b0b2684b4b67db9eadf5e47d4f87f8fe74dd9070";
547+
return "unknown";
548548
}
549549

550550
bool has_quadratic_llh() const override {

models/model_dirac_py/root.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
namespace amici {
1010
namespace model_model_dirac_py {
1111

12-
void root_model_dirac_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl){
12+
void root_model_dirac_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl){
1313
root[0] = -p2 + t;
1414
}
1515

models/model_dirac_py/stau.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
namespace amici {
1111
namespace model_model_dirac_py {
1212

13-
void stau_model_dirac_py(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie){
13+
void stau_model_dirac_py(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie){
1414
switch(ie) {
1515
case 0:
1616
switch(ip) {

models/model_events_py/deltaqB.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "p.h"
77
#include "k.h"
88
#include "h.h"
9+
#include "w.h"
910
#include "xdot.h"
1011
#include "xdot_old.h"
1112
#include "x_old.h"
@@ -14,7 +15,7 @@
1415
namespace amici {
1516
namespace model_model_events_py {
1617

17-
void deltaqB_model_events_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB){
18+
void deltaqB_model_events_py(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB){
1819
switch(ie) {
1920
case 2:
2021
case 3:

0 commit comments

Comments
 (0)