Skip to content

Commit 2ccb30d

Browse files
committed
Let froot depend on w
WIP This should hopefully allow for more efficient computation of `root`-derivatives by avoiding flattening `w` into `root` which, so far, made computing `drootdt_total` prohibitively expensive in case of large `w` dependencies in `root`.
1 parent b0b2684 commit 2ccb30d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+561
-193
lines changed

cmake/AmiciFindBLAS.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ if(DEFINED ENV{AMICI_BLAS_USE_SCIPY_OPENBLAS})
1515
"Using AMICI_BLAS_USE_SCIPY_OPENBLAS=${AMICI_BLAS_USE_SCIPY_OPENBLAS} from environment variable."
1616
)
1717
set(AMICI_BLAS_USE_SCIPY_OPENBLAS $ENV{AMICI_BLAS_USE_SCIPY_OPENBLAS})
18+
elseif(NOT DEFINED AMICI_BLAS_USE_SCIPY_OPENBLAS
19+
AND NOT AMICI_PYTHON_BUILD_EXT_ONLY)
20+
# If were are not building the Python extension, it's unlikely that we want to
21+
# use scipy-openblas
22+
set(AMICI_BLAS_USE_SCIPY_OPENBLAS FALSE)
1823
endif()
1924

2025
if((${BLAS} STREQUAL "MKL" OR DEFINED ENV{MKLROOT})

include/amici/abstract_model.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ class AbstractModel {
299299
* @param p parameter vector
300300
* @param k constant vector
301301
* @param h Heaviside vector
302+
* @param w vector with helper variables
302303
* @param dx time derivative of state (DAE only)
303304
* @param tcl total abundances for conservation laws
304305
* @param sx current state sensitivity
@@ -307,8 +308,9 @@ class AbstractModel {
307308
*/
308309
virtual void fstau(
309310
realtype* stau, realtype t, realtype const* x, realtype const* p,
310-
realtype const* k, realtype const* h, realtype const* dx,
311-
realtype const* tcl, realtype const* sx, int ip, int ie
311+
realtype const* k, realtype const* h, realtype const* w,
312+
realtype const* dx, realtype const* tcl, realtype const* sx, int ip,
313+
int ie
312314
);
313315

314316
/**
@@ -542,6 +544,7 @@ class AbstractModel {
542544
* @param p parameter vector
543545
* @param k constant vector
544546
* @param h Heaviside vector
547+
* @param w vector with helper variables
545548
* @param dx time derivative of state (DAE only)
546549
* @param ie event index
547550
* @param xdot new model right hand side
@@ -552,9 +555,10 @@ class AbstractModel {
552555
*/
553556
virtual void fdeltaxB(
554557
realtype* deltaxB, realtype t, realtype const* x, realtype const* p,
555-
realtype const* k, realtype const* h, realtype const* dx, int ie,
556-
realtype const* xdot, realtype const* xdot_old, realtype const* x_old,
557-
realtype const* xB, realtype const* tcl
558+
realtype const* k, realtype const* h, realtype const* w,
559+
realtype const* dx, int ie, realtype const* xdot,
560+
realtype const* xdot_old, realtype const* x_old, realtype const* xB,
561+
realtype const* tcl
558562
);
559563

560564
/**
@@ -565,6 +569,7 @@ class AbstractModel {
565569
* @param p parameter vector
566570
* @param k constant vector
567571
* @param h Heaviside vector
572+
* @param w vector with helper variables
568573
* @param dx time derivative of state (DAE only)
569574
* @param ip sensitivity index
570575
* @param ie event index
@@ -575,9 +580,9 @@ class AbstractModel {
575580
*/
576581
virtual void fdeltaqB(
577582
realtype* deltaqB, realtype t, realtype const* x, realtype const* p,
578-
realtype const* k, realtype const* h, realtype const* dx, int ip,
579-
int ie, realtype const* xdot, realtype const* xdot_old,
580-
realtype const* x_old, realtype const* xB
583+
realtype const* k, realtype const* h, realtype const* w,
584+
realtype const* dx, int ip, int ie, realtype const* xdot,
585+
realtype const* xdot_old, realtype const* x_old, realtype const* xB
581586
);
582587

583588
/**

include/amici/model_dae.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,13 @@ class Model_DAE : public Model {
359359
* @param p parameter vector
360360
* @param k constants vector
361361
* @param h Heaviside vector
362+
* @param w vector with helper variables
362363
* @param dx Vector with the derivative states
363364
**/
364365
virtual void froot(
365366
realtype* root, realtype t, realtype const* x, double const* p,
366-
double const* k, realtype const* h, realtype const* dx
367+
double const* k, realtype const* h, realtype const* w,
368+
realtype const* dx
367369
);
368370

369371
/**

include/amici/model_ode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,13 @@ class Model_ODE : public Model {
316316
* @param p parameter vector
317317
* @param k constants vector
318318
* @param h Heaviside vector
319+
* @param w vector with helper variables
319320
* @param tcl total abundances for conservation laws
320321
**/
321322
virtual void froot(
322323
realtype* root, realtype t, realtype const* x, realtype const* p,
323-
realtype const* k, realtype const* h, realtype const* tcl
324+
realtype const* k, realtype const* h, realtype const* w,
325+
realtype const* tcl
324326
);
325327

326328
/**

models/model_calvetti_py/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ cmake_policy(VERSION 3.22...3.31)
44

55
project(model_calvetti_py)
66

7+
message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
8+
message(STATUS "CMAKE_BINARY_DIR: ${CMAKE_BINARY_DIR}")
9+
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
10+
message(STATUS "CMAKE_VERSION: ${CMAKE_VERSION}")
11+
message(STATUS "CMAKE_COMMAND: ${CMAKE_COMMAND}")
12+
713
set(CMAKE_CXX_STANDARD 20)
814
set(CMAKE_CXX_STANDARD_REQUIRED ON)
915
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

models/model_calvetti_py/model_calvetti_py.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ extern void dJydy_rowvals_model_calvetti_py(SUNMatrixWrapper &rowvals, int index
3838

3939

4040

41-
extern void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl);
41+
extern void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl);
4242

4343

4444

@@ -200,10 +200,10 @@ class Model_model_calvetti_py : public amici::Model_DAE {
200200
void fdeltasx(realtype *deltasx, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *sx, const realtype *stau, const realtype *tcl, const realtype *x_old) override {}
201201

202202

203-
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
203+
void fdeltaxB(realtype *deltaxB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB, const realtype *tcl) override {}
204204

205205

206-
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {}
206+
void fdeltaqB(realtype *deltaqB, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const int ip, const int ie, const realtype *xdot, const realtype *xdot_old, const realtype *x_old, const realtype *xB) override {}
207207

208208

209209
void fdrzdp(realtype *drzdp, const int ie, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const int ip) override {}
@@ -323,8 +323,8 @@ class Model_model_calvetti_py : public amici::Model_DAE {
323323
void fdzdx(realtype *dzdx, const int ie, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h) override {}
324324

325325

326-
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl) override {
327-
root_model_calvetti_py(root, t, x, p, k, h, tcl);
326+
void froot(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl) override {
327+
root_model_calvetti_py(root, t, x, p, k, h, w, tcl);
328328
}
329329

330330

@@ -339,7 +339,7 @@ class Model_model_calvetti_py : public amici::Model_DAE {
339339
void fsigmaz(realtype *sigmaz, const realtype t, const realtype *p, const realtype *k) override {}
340340

341341

342-
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {}
342+
void fstau(realtype *stau, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *dx, const realtype *tcl, const realtype *sx, const int ip, const int ie) override {}
343343

344344
void fsx0(realtype *sx0, const realtype t, const realtype *x, const realtype *p, const realtype *k, const int ip) override {}
345345

@@ -557,7 +557,7 @@ class Model_model_calvetti_py : public amici::Model_DAE {
557557
* @return AMICI git commit hash
558558
*/
559559
std::string get_amici_commit() const override {
560-
return "f005fac9e2de7c3c90be2ac55d4ad165471ed1e7";
560+
return "unknown";
561561
}
562562

563563
bool has_quadratic_llh() const override {

models/model_calvetti_py/root.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
#include "x.h"
66
#include "k.h"
77
#include "h.h"
8+
#include "w.h"
89

910
namespace amici {
1011
namespace model_model_calvetti_py {
1112

12-
void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *tcl){
13+
void root_model_calvetti_py(realtype *root, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, const realtype *w, const realtype *tcl){
1314
root[0] = t - 10;
1415
root[1] = 10 - t;
1516
root[2] = 12 - t;

models/model_calvetti_py/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""AMICI model package setup"""
22

3+
import importlib.metadata
34
import os
45
import sys
56
from pathlib import Path
@@ -8,7 +9,6 @@
89
from amici.custom_commands import AmiciBuildCMakeExtension
910
from cmake_build_extension import CMakeExtension
1011
from setuptools import find_namespace_packages, setup
11-
import importlib.metadata
1212

1313

1414
def get_extension() -> CMakeExtension:

models/model_calvetti_py/swig/model_calvetti_py.i

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ import sysconfig
88
from pathlib import Path
99
1010
ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
11+
extension_path = Path(__file__).parent / f'_model_calvetti_py{ext_suffix}'
1112
_model_calvetti_py = amici._module_from_path(
1213
'model_calvetti_py._model_calvetti_py' if __package__ or '.' in __name__
1314
else '_model_calvetti_py',
14-
Path(__file__).parent / f'_model_calvetti_py{ext_suffix}',
15+
extension_path,
1516
)
1617
1718
def _get_import_time():
@@ -36,6 +37,28 @@ if t_imported < t_modified:
3637

3738
%module(package="model_calvetti_py",moduleimport=MODULEIMPORT) model_calvetti_py
3839

40+
// store swig version
41+
%constant int SWIG_VERSION_MAJOR = (SWIG_VERSION >> 16);
42+
%constant int SWIG_VERSION_MINOR = ((SWIG_VERSION >> 8) & 0xff);
43+
%constant int SWIG_VERSION_PATCH = (SWIG_VERSION & 0xff);
44+
45+
%pythoncode %{
46+
# SWIG version used to build the model extension as `(major, minor, patch)`
47+
_SWIG_VERSION = (SWIG_VERSION_MAJOR, SWIG_VERSION_MINOR, SWIG_VERSION_PATCH)
48+
49+
if (amici_swig := amici.amici._SWIG_VERSION) != (model_swig := _SWIG_VERSION):
50+
import warnings
51+
warnings.warn(
52+
f"SWIG version mismatch between amici ({amici_swig}) and model "
53+
f"({model_swig}). This may lead to unexpected behavior. "
54+
"In that case, please recompile the model with swig=="
55+
f"{amici_swig[0]}.{amici_swig[1]}.{amici_swig[2]} or rebuild amici "
56+
f"with swig=={model_swig[0]}.{model_swig[1]}.{model_swig[2]}.",
57+
RuntimeWarning,
58+
stacklevel=2,
59+
)
60+
%}
61+
3962
%pythoncode %{
4063
# the model-package __init__.py module (will be set during import)
4164
_model_module = None
@@ -56,7 +79,7 @@ using namespace amici;
5679
// store the time a module was imported
5780
%{
5881
#include <chrono>
59-
static std::chrono::time_point<std::chrono::system_clock> _module_import_time;
82+
static std::chrono::time_point<std::chrono::system_clock> _module_import_time = std::chrono::system_clock::now();
6083

6184
static double _get_import_time() {
6285
auto epoch = _module_import_time.time_since_epoch();
@@ -67,7 +90,9 @@ static double _get_import_time() {
6790
static double _get_import_time();
6891

6992
%init %{
70-
_module_import_time = std::chrono::system_clock::now();
93+
// NOTE: from SWIG 4.4.0 onwards, %init code is executed every time the
94+
// module is executed - not only on first import
95+
// This code ends up in `SWIG_mod_exec`.
7196
%}
7297

7398

models/model_dirac_py/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ cmake_policy(VERSION 3.22...3.31)
44

55
project(model_dirac_py)
66

7+
message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
8+
message(STATUS "CMAKE_BINARY_DIR: ${CMAKE_BINARY_DIR}")
9+
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
10+
message(STATUS "CMAKE_VERSION: ${CMAKE_VERSION}")
11+
message(STATUS "CMAKE_COMMAND: ${CMAKE_COMMAND}")
12+
713
set(CMAKE_CXX_STANDARD 20)
814
set(CMAKE_CXX_STANDARD_REQUIRED ON)
915
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

0 commit comments

Comments
 (0)