Skip to content

Commit 98c782d

Browse files
authored
Combine code for sparse model functions and their index files (#2159)
Combine code for sparse model functions and their index files, i.e. generate only a single file instead of 3 individual files for content, rowvals, and colptrs, respectively. Advantage: Faster import of smaller models and fewer files. For a toy model, this reduced the build steps from 44 to 28, and reduced build time by >20% on my computer. Disadvantage: None found, so I don't think it worth adding an option for (not) combining those files. For larger models, there shouldn't be any impact. The extra time for compiling the index arrays should be negligible compared to computing the contents. Related to #2119 Here a test for a large model (N=1): | File | Size | Compilation time (s) | |--------------|--------:|---------------------:| | dwdx | 22.4MiB | 3413.64 | | dwdx_colptrs | 2.0KiB | 2.79 | | dwdx_rowvals | 65.6KiB | 2.66 | | *combined* | | 3416.79 | I'd consider this time increase negligible.
1 parent 8ead4b0 commit 98c782d

File tree

1 file changed

+27
-21
lines changed

1 file changed

+27
-21
lines changed

python/sdist/amici/de_export.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
Set,
3333
Tuple,
3434
Union,
35+
Literal
3536
)
3637

3738
import numpy as np
@@ -2838,9 +2839,6 @@ def _generate_c_code(self) -> None:
28382839
if func_info.generate_body:
28392840
dec = log_execution_time(f"writing {func_name}.cpp", logger)
28402841
dec(self._write_function_file)(func_name)
2841-
if func_name in sparse_functions and func_info.body:
2842-
self._write_function_index(func_name, "colptrs")
2843-
self._write_function_index(func_name, "rowvals")
28442842

28452843
for name in self.model.sym_names():
28462844
# only generate for those that have nontrivial implementation,
@@ -3040,16 +3038,32 @@ def _write_function_file(self, function: str) -> None:
30403038
else:
30413039
equations = self.model.eq(function)
30423040

3041+
# function body
3042+
if function == "create_splines":
3043+
body = self._get_create_splines_body()
3044+
else:
3045+
body = self._get_function_body(function, equations)
3046+
if not body:
3047+
return
3048+
3049+
# colptrs / rowvals for sparse matrices
3050+
if function in sparse_functions:
3051+
lines = self._generate_function_index(function, "colptrs")
3052+
lines.extend(self._generate_function_index(function, "rowvals"))
3053+
lines.append("\n\n")
3054+
else:
3055+
lines = []
3056+
30433057
# function header
3044-
lines = [
3058+
lines.extend([
30453059
'#include "amici/symbolic_functions.h"',
30463060
'#include "amici/defines.h"',
30473061
'#include "sundials/sundials_types.h"',
30483062
"",
30493063
"#include <gsl/gsl-lite.hpp>",
30503064
"#include <algorithm>",
30513065
"",
3052-
]
3066+
])
30533067
if function == "create_splines":
30543068
lines += ['#include "amici/splinefunctions.h"', "#include <vector>"]
30553069

@@ -3096,14 +3110,6 @@ def _write_function_file(self, function: str) -> None:
30963110
]
30973111
)
30983112

3099-
# function body
3100-
if function == "create_splines":
3101-
body = self._get_create_splines_body()
3102-
else:
3103-
body = self._get_function_body(function, equations)
3104-
if not body:
3105-
return
3106-
31073113
if self.assume_pow_positivity and func_info.assume_pow_positivity:
31083114
pow_rx = re.compile(r"(^|\W)std::pow\(")
31093115
body = [
@@ -3137,16 +3143,20 @@ def _write_function_file(self, function: str) -> None:
31373143
with open(filename, "w") as fileout:
31383144
fileout.write("\n".join(lines))
31393145

3140-
def _write_function_index(self, function: str, indextype: str) -> None:
3146+
def _generate_function_index(
3147+
self, function: str, indextype: Literal["colptrs", "rowvals"]
3148+
) -> List[str]:
31413149
"""
3142-
Generate equations and write the C++ code for the function
3143-
``function``.
3150+
Generate equations and C++ code for the function ``function``.
31443151
31453152
:param function:
31463153
name of the function to be written (see ``self.functions``)
31473154
31483155
:param indextype:
31493156
type of index {'colptrs', 'rowvals'}
3157+
3158+
:returns:
3159+
The code lines for the respective function index file
31503160
"""
31513161
if indextype == "colptrs":
31523162
values = self.model.colptrs(function)
@@ -3233,11 +3243,7 @@ def _write_function_index(self, function: str, indextype: str) -> None:
32333243
]
32343244
)
32353245

3236-
filename = f"{function}_{indextype}.cpp"
3237-
filename = os.path.join(self.model_path, filename)
3238-
3239-
with open(filename, "w") as fileout:
3240-
fileout.write("\n".join(lines))
3246+
return lines
32413247

32423248
def _get_function_body(self, function: str, equations: sp.Matrix) -> List[str]:
32433249
"""

0 commit comments

Comments
 (0)