Skip to content

Commit 0d49041

Browse files
authored
Decouple JAX & C++ code generation (#2615)
* refactor * correct doc * remove PK * fix tests * fix notebook * fix parameter ids * fix jax test * fix notebook * reviews * fixup
1 parent bd3bd91 commit 0d49041

File tree

13 files changed

+832
-417
lines changed

13 files changed

+832
-417
lines changed

python/sdist/amici/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,6 @@ def get_model(self) -> amici.Model:
141141
"""Create a model instance."""
142142
...
143143

144-
def get_jax_model(self) -> JAXModel: ...
145-
146144
AmiciModel = Union[amici.Model, amici.ModelPtr]
147145
else:
148146
ModelModule = ModuleType
Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
"""AMICI-generated module for model TPL_MODELNAME"""
22

3-
import datetime
4-
import os
53
import sys
64
from pathlib import Path
7-
from typing import TYPE_CHECKING
85
import amici
96

107

11-
if TYPE_CHECKING:
12-
from amici.jax import JAXModel
13-
148
# Ensure we are binary-compatible, see #556
159
if "TPL_AMICI_VERSION" != amici.__version__:
1610
raise amici.AmiciVersionError(
@@ -38,28 +32,4 @@
3832
# when the model package is imported via `import`
3933
TPL_MODELNAME._model_module = sys.modules[__name__]
4034

41-
42-
def get_jax_model() -> "JAXModel":
43-
# If the model directory was meanwhile overwritten, this would load the
44-
# new version, which would not match the previously imported extension.
45-
# This is not allowed, as it would lead to inconsistencies.
46-
jax_py_file = Path(__file__).parent / "jax.py"
47-
jax_py_file = jax_py_file.resolve()
48-
t_imported = TPL_MODELNAME._get_import_time() # noqa: protected-access
49-
t_modified = os.path.getmtime(jax_py_file)
50-
if t_imported < t_modified:
51-
t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat()
52-
t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat()
53-
raise RuntimeError(
54-
f"Refusing to import {jax_py_file} which was changed since "
55-
f"TPL_MODELNAME was imported. This is to avoid inconsistencies "
56-
"between the different model implementations.\n"
57-
f"Imported at {t_imp_str}\nModified at {t_mod_str}.\n"
58-
"Import the module with a different name or restart the "
59-
"Python kernel."
60-
)
61-
jax = amici._module_from_path("jax", jax_py_file)
62-
return jax.JAXModel_TPL_MODELNAME()
63-
64-
6535
__version__ = "TPL_PACKAGE_VERSION"

python/sdist/amici/de_export.py

Lines changed: 20 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
TYPE_CHECKING,
2222
Literal,
2323
)
24-
from itertools import chain
2524

2625
import sympy as sp
2726

@@ -56,7 +55,6 @@
5655
AmiciCxxCodePrinter,
5756
get_switch_statement,
5857
)
59-
from .jaxcodeprinter import AmiciJaxCodePrinter
6058
from .de_model import DEModel
6159
from .de_model_components import *
6260
from .import_utils import (
@@ -146,10 +144,7 @@ class DEExporter:
146144
If the given model uses special functions, this set contains hints for
147145
model building.
148146
149-
:ivar _code_printer_jax:
150-
Code printer to generate JAX code
151-
152-
:ivar _code_printer_cpp:
147+
:ivar _code_printer:
153148
Code printer to generate C++ code
154149
155150
:ivar generate_sensitivity_code:
@@ -218,15 +213,14 @@ def __init__(
218213
self.set_name(model_name)
219214
self.set_paths(outdir)
220215

221-
self._code_printer_cpp = AmiciCxxCodePrinter()
222-
self._code_printer_jax = AmiciJaxCodePrinter()
216+
self._code_printer = AmiciCxxCodePrinter()
223217
for fun in CUSTOM_FUNCTIONS:
224-
self._code_printer_cpp.known_functions[fun["sympy"]] = fun["c++"]
218+
self._code_printer.known_functions[fun["sympy"]] = fun["c++"]
225219

226220
# Signatures and properties of generated model functions (see
227221
# include/amici/model.h for details)
228222
self.model: DEModel = de_model
229-
self._code_printer_cpp.known_functions.update(
223+
self._code_printer.known_functions.update(
230224
splines.spline_user_functions(
231225
self.model._splines, self._get_index("p")
232226
)
@@ -249,7 +243,6 @@ def generate_model_code(self) -> None:
249243
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
250244
):
251245
self._prepare_model_folder()
252-
self._generate_jax_code()
253246
self._generate_c_code()
254247
self._generate_m_code()
255248

@@ -277,121 +270,6 @@ def _prepare_model_folder(self) -> None:
277270
if os.path.isfile(file_path):
278271
os.remove(file_path)
279272

280-
@log_execution_time("generating jax code", logger)
281-
def _generate_jax_code(self) -> None:
282-
try:
283-
from amici.jax.model import JAXModel
284-
except ImportError:
285-
logger.warning(
286-
"Could not import JAXModel. JAX code will not be generated."
287-
)
288-
return
289-
290-
eq_names = (
291-
"xdot",
292-
"w",
293-
"x0",
294-
"y",
295-
"sigmay",
296-
"Jy",
297-
"x_solver",
298-
"x_rdata",
299-
"total_cl",
300-
)
301-
sym_names = ("x", "tcl", "w", "my", "y", "sigmay", "x_rdata")
302-
303-
indent = 8
304-
305-
def jnp_array_str(array) -> str:
306-
elems = ", ".join(str(s) for s in array)
307-
308-
return f"jnp.array([{elems}])"
309-
310-
# replaces Heaviside variables with corresponding functions
311-
subs_heaviside = dict(
312-
zip(
313-
self.model.sym("h"),
314-
[sp.Heaviside(x) for x in self.model.eq("root")],
315-
strict=True,
316-
)
317-
)
318-
# replaces observables with a generic my variable
319-
subs_observables = dict(
320-
zip(
321-
self.model.sym("my"),
322-
[sp.Symbol("my")] * len(self.model.sym("my")),
323-
strict=True,
324-
)
325-
)
326-
327-
tpl_data = {
328-
# assign named variable using corresponding algebraic formula (function body)
329-
**{
330-
f"{eq_name.upper()}_EQ": "\n".join(
331-
self._code_printer_jax._get_sym_lines(
332-
(str(strip_pysb(s)) for s in self.model.sym(eq_name)),
333-
self.model.eq(eq_name).subs(
334-
{**subs_heaviside, **subs_observables}
335-
),
336-
indent,
337-
)
338-
)[indent:] # remove indent for first line
339-
for eq_name in eq_names
340-
},
341-
# create jax array from concatenation of named variables
342-
**{
343-
f"{eq_name.upper()}_RET": jnp_array_str(
344-
strip_pysb(s) for s in self.model.sym(eq_name)
345-
)
346-
if self.model.sym(eq_name)
347-
else "jnp.array([])"
348-
for eq_name in eq_names
349-
},
350-
# assign named variables from a jax array
351-
**{
352-
f"{sym_name.upper()}_SYMS": "".join(
353-
str(strip_pysb(s)) + ", " for s in self.model.sym(sym_name)
354-
)
355-
if self.model.sym(sym_name)
356-
else "_"
357-
for sym_name in sym_names
358-
},
359-
# tuple of variable names (ids as they are unique)
360-
**{
361-
f"{sym_name.upper()}_IDS": "".join(
362-
f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name)
363-
)
364-
if self.model.sym(sym_name)
365-
else "tuple()"
366-
for sym_name in ("p", "k", "y", "x")
367-
},
368-
**{
369-
# in jax model we do not need to distinguish between p (parameters) and
370-
# k (fixed parameters) so we use a single variable combining both
371-
"PK_SYMS": "".join(
372-
str(strip_pysb(s)) + ", "
373-
for s in chain(self.model.sym("p"), self.model.sym("k"))
374-
),
375-
"PK_IDS": "".join(
376-
f'"{strip_pysb(s)}", '
377-
for s in chain(self.model.sym("p"), self.model.sym("k"))
378-
),
379-
"MODEL_NAME": self.model_name,
380-
# keep track of the API version that the model was generated with so we
381-
# can flag conflicts in the future
382-
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
383-
},
384-
}
385-
os.makedirs(
386-
os.path.join(self.model_path, self.model_name), exist_ok=True
387-
)
388-
389-
apply_template(
390-
os.path.join(amiciModulePath, "jax.template.py"),
391-
os.path.join(self.model_path, self.model_name, "jax.py"),
392-
tpl_data,
393-
)
394-
395273
def _generate_c_code(self) -> None:
396274
"""
397275
Create C++ code files for the model based on
@@ -795,7 +673,7 @@ def _get_function_body(
795673
lines = []
796674

797675
if len(equations) == 0 or (
798-
isinstance(equations, (sp.Matrix, sp.ImmutableDenseMatrix))
676+
isinstance(equations, sp.Matrix | sp.ImmutableDenseMatrix)
799677
and min(equations.shape) == 0
800678
):
801679
# dJydy is a list
@@ -852,7 +730,7 @@ def _get_function_body(
852730
f"reinitialization_state_idxs.cend(), {index}) != "
853731
"reinitialization_state_idxs.cend())",
854732
f" {function}[{index}] = "
855-
f"{self._code_printer_cpp.doprint(formula)};",
733+
f"{self._code_printer.doprint(formula)};",
856734
]
857735
)
858736
cases[ipar] = expressions
@@ -867,12 +745,12 @@ def _get_function_body(
867745
f"reinitialization_state_idxs.cend(), {index}) != "
868746
"reinitialization_state_idxs.cend())\n "
869747
f"{function}[{index}] = "
870-
f"{self._code_printer_cpp.doprint(formula)};"
748+
f"{self._code_printer.doprint(formula)};"
871749
)
872750

873751
elif function in event_functions:
874752
cases = {
875-
ie: self._code_printer_cpp._get_sym_lines_array(
753+
ie: self._code_printer._get_sym_lines_array(
876754
equations[ie], function, 0
877755
)
878756
for ie in range(self.model.num_events())
@@ -885,7 +763,7 @@ def _get_function_body(
885763
for ie, inner_equations in enumerate(equations):
886764
inner_lines = []
887765
inner_cases = {
888-
ipar: self._code_printer_cpp._get_sym_lines_array(
766+
ipar: self._code_printer._get_sym_lines_array(
889767
inner_equations[:, ipar], function, 0
890768
)
891769
for ipar in range(self.model.num_par())
@@ -900,7 +778,7 @@ def _get_function_body(
900778
and equations.shape[1] == self.model.num_par()
901779
):
902780
cases = {
903-
ipar: self._code_printer_cpp._get_sym_lines_array(
781+
ipar: self._code_printer._get_sym_lines_array(
904782
equations[:, ipar], function, 0
905783
)
906784
for ipar in range(self.model.num_par())
@@ -910,15 +788,15 @@ def _get_function_body(
910788
elif function in multiobs_functions:
911789
if function == "dJydy":
912790
cases = {
913-
iobs: self._code_printer_cpp._get_sym_lines_array(
791+
iobs: self._code_printer._get_sym_lines_array(
914792
equations[iobs], function, 0
915793
)
916794
for iobs in range(self.model.num_obs())
917795
if not smart_is_zero_matrix(equations[iobs])
918796
}
919797
else:
920798
cases = {
921-
iobs: self._code_printer_cpp._get_sym_lines_array(
799+
iobs: self._code_printer._get_sym_lines_array(
922800
equations[:, iobs], function, 0
923801
)
924802
for iobs in range(equations.shape[1])
@@ -948,7 +826,7 @@ def _get_function_body(
948826
tmp_equations = sp.Matrix(
949827
[equations[i] for i in static_idxs]
950828
)
951-
tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
829+
tmp_lines = self._code_printer._get_sym_lines_symbols(
952830
tmp_symbols,
953831
tmp_equations,
954832
function,
@@ -974,7 +852,7 @@ def _get_function_body(
974852
[equations[i] for i in dynamic_idxs]
975853
)
976854

977-
tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
855+
tmp_lines = self._code_printer._get_sym_lines_symbols(
978856
tmp_symbols,
979857
tmp_equations,
980858
function,
@@ -986,12 +864,12 @@ def _get_function_body(
986864
lines.extend(tmp_lines)
987865

988866
else:
989-
lines += self._code_printer_cpp._get_sym_lines_symbols(
867+
lines += self._code_printer._get_sym_lines_symbols(
990868
symbols, equations, function, 4
991869
)
992870

993871
else:
994-
lines += self._code_printer_cpp._get_sym_lines_array(
872+
lines += self._code_printer._get_sym_lines_array(
995873
equations, function, 4
996874
)
997875

@@ -1136,8 +1014,7 @@ def _write_model_header_cpp(self) -> None:
11361014
)
11371015
),
11381016
"NDXDOTDX_EXPLICIT": len(self.model.sparsesym("dxdotdx_explicit")),
1139-
"NDJYDY": "std::vector<int>{%s}"
1140-
% ",".join(str(len(x)) for x in self.model.sparsesym("dJydy")),
1017+
"NDJYDY": f"std::vector<int>{{{','.join(str(len(x)) for x in self.model.sparsesym('dJydy'))}}}",
11411018
"NDXRDATADXSOLVER": len(self.model.sparsesym("dx_rdatadx_solver")),
11421019
"NDXRDATADTCL": len(self.model.sparsesym("dx_rdatadtcl")),
11431020
"NDTOTALCLDXRDATA": len(self.model.sparsesym("dtotal_cldx_rdata")),
@@ -1147,10 +1024,10 @@ def _write_model_header_cpp(self) -> None:
11471024
"NK": self.model.num_const(),
11481025
"O2MODE": "amici::SecondOrderMode::none",
11491026
# using code printer ensures proper handling of nan/inf
1150-
"PARAMETERS": self._code_printer_cpp.doprint(self.model.val("p"))[
1027+
"PARAMETERS": self._code_printer.doprint(self.model.val("p"))[
11511028
1:-1
11521029
],
1153-
"FIXED_PARAMETERS": self._code_printer_cpp.doprint(
1030+
"FIXED_PARAMETERS": self._code_printer.doprint(
11541031
self.model.val("k")
11551032
)[1:-1],
11561033
"PARAMETER_NAMES_INITIALIZER_LIST": self._get_symbol_name_initializer_list(
@@ -1344,7 +1221,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
13441221
Template initializer list of ids
13451222
"""
13461223
return "\n".join(
1347-
f'"{self._code_printer_cpp.doprint(symbol)}", // {name}[{idx}]'
1224+
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
13481225
for idx, symbol in enumerate(self.model.sym(name))
13491226
)
13501227

0 commit comments

Comments
 (0)