2121 TYPE_CHECKING ,
2222 Literal ,
2323)
24- from itertools import chain
2524
2625import sympy as sp
2726
5655 AmiciCxxCodePrinter ,
5756 get_switch_statement ,
5857)
59- from .jaxcodeprinter import AmiciJaxCodePrinter
6058from .de_model import DEModel
6159from .de_model_components import *
6260from .import_utils import (
@@ -146,10 +144,7 @@ class DEExporter:
146144 If the given model uses special functions, this set contains hints for
147145 model building.
148146
149- :ivar _code_printer_jax:
150- Code printer to generate JAX code
151-
152- :ivar _code_printer_cpp:
147+ :ivar _code_printer:
153148 Code printer to generate C++ code
154149
155150 :ivar generate_sensitivity_code:
@@ -218,15 +213,14 @@ def __init__(
218213 self .set_name (model_name )
219214 self .set_paths (outdir )
220215
221- self ._code_printer_cpp = AmiciCxxCodePrinter ()
222- self ._code_printer_jax = AmiciJaxCodePrinter ()
216+ self ._code_printer = AmiciCxxCodePrinter ()
223217 for fun in CUSTOM_FUNCTIONS :
224- self ._code_printer_cpp .known_functions [fun ["sympy" ]] = fun ["c++" ]
218+ self ._code_printer .known_functions [fun ["sympy" ]] = fun ["c++" ]
225219
226220 # Signatures and properties of generated model functions (see
227221 # include/amici/model.h for details)
228222 self .model : DEModel = de_model
229- self ._code_printer_cpp .known_functions .update (
223+ self ._code_printer .known_functions .update (
230224 splines .spline_user_functions (
231225 self .model ._splines , self ._get_index ("p" )
232226 )
@@ -249,7 +243,6 @@ def generate_model_code(self) -> None:
249243 sp .Pow , "_eval_derivative" , _custom_pow_eval_derivative
250244 ):
251245 self ._prepare_model_folder ()
252- self ._generate_jax_code ()
253246 self ._generate_c_code ()
254247 self ._generate_m_code ()
255248
@@ -277,121 +270,6 @@ def _prepare_model_folder(self) -> None:
277270 if os .path .isfile (file_path ):
278271 os .remove (file_path )
279272
280- @log_execution_time ("generating jax code" , logger )
281- def _generate_jax_code (self ) -> None :
282- try :
283- from amici .jax .model import JAXModel
284- except ImportError :
285- logger .warning (
286- "Could not import JAXModel. JAX code will not be generated."
287- )
288- return
289-
290- eq_names = (
291- "xdot" ,
292- "w" ,
293- "x0" ,
294- "y" ,
295- "sigmay" ,
296- "Jy" ,
297- "x_solver" ,
298- "x_rdata" ,
299- "total_cl" ,
300- )
301- sym_names = ("x" , "tcl" , "w" , "my" , "y" , "sigmay" , "x_rdata" )
302-
303- indent = 8
304-
305- def jnp_array_str (array ) -> str :
306- elems = ", " .join (str (s ) for s in array )
307-
308- return f"jnp.array([{ elems } ])"
309-
310- # replaces Heaviside variables with corresponding functions
311- subs_heaviside = dict (
312- zip (
313- self .model .sym ("h" ),
314- [sp .Heaviside (x ) for x in self .model .eq ("root" )],
315- strict = True ,
316- )
317- )
318- # replaces observables with a generic my variable
319- subs_observables = dict (
320- zip (
321- self .model .sym ("my" ),
322- [sp .Symbol ("my" )] * len (self .model .sym ("my" )),
323- strict = True ,
324- )
325- )
326-
327- tpl_data = {
328- # assign named variable using corresponding algebraic formula (function body)
329- ** {
330- f"{ eq_name .upper ()} _EQ" : "\n " .join (
331- self ._code_printer_jax ._get_sym_lines (
332- (str (strip_pysb (s )) for s in self .model .sym (eq_name )),
333- self .model .eq (eq_name ).subs (
334- {** subs_heaviside , ** subs_observables }
335- ),
336- indent ,
337- )
338- )[indent :] # remove indent for first line
339- for eq_name in eq_names
340- },
341- # create jax array from concatenation of named variables
342- ** {
343- f"{ eq_name .upper ()} _RET" : jnp_array_str (
344- strip_pysb (s ) for s in self .model .sym (eq_name )
345- )
346- if self .model .sym (eq_name )
347- else "jnp.array([])"
348- for eq_name in eq_names
349- },
350- # assign named variables from a jax array
351- ** {
352- f"{ sym_name .upper ()} _SYMS" : "" .join (
353- str (strip_pysb (s )) + ", " for s in self .model .sym (sym_name )
354- )
355- if self .model .sym (sym_name )
356- else "_"
357- for sym_name in sym_names
358- },
359- # tuple of variable names (ids as they are unique)
360- ** {
361- f"{ sym_name .upper ()} _IDS" : "" .join (
362- f'"{ strip_pysb (s )} ", ' for s in self .model .sym (sym_name )
363- )
364- if self .model .sym (sym_name )
365- else "tuple()"
366- for sym_name in ("p" , "k" , "y" , "x" )
367- },
368- ** {
369- # in jax model we do not need to distinguish between p (parameters) and
370- # k (fixed parameters) so we use a single variable combining both
371- "PK_SYMS" : "" .join (
372- str (strip_pysb (s )) + ", "
373- for s in chain (self .model .sym ("p" ), self .model .sym ("k" ))
374- ),
375- "PK_IDS" : "" .join (
376- f'"{ strip_pysb (s )} ", '
377- for s in chain (self .model .sym ("p" ), self .model .sym ("k" ))
378- ),
379- "MODEL_NAME" : self .model_name ,
380- # keep track of the API version that the model was generated with so we
381- # can flag conflicts in the future
382- "MODEL_API_VERSION" : f"'{ JAXModel .MODEL_API_VERSION } '" ,
383- },
384- }
385- os .makedirs (
386- os .path .join (self .model_path , self .model_name ), exist_ok = True
387- )
388-
389- apply_template (
390- os .path .join (amiciModulePath , "jax.template.py" ),
391- os .path .join (self .model_path , self .model_name , "jax.py" ),
392- tpl_data ,
393- )
394-
395273 def _generate_c_code (self ) -> None :
396274 """
397275 Create C++ code files for the model based on
@@ -795,7 +673,7 @@ def _get_function_body(
795673 lines = []
796674
797675 if len (equations ) == 0 or (
798- isinstance (equations , ( sp .Matrix , sp .ImmutableDenseMatrix ) )
676+ isinstance (equations , sp .Matrix | sp .ImmutableDenseMatrix )
799677 and min (equations .shape ) == 0
800678 ):
801679 # dJydy is a list
@@ -852,7 +730,7 @@ def _get_function_body(
852730 f"reinitialization_state_idxs.cend(), { index } ) != "
853731 "reinitialization_state_idxs.cend())" ,
854732 f" { function } [{ index } ] = "
855- f"{ self ._code_printer_cpp .doprint (formula )} ;" ,
733+ f"{ self ._code_printer .doprint (formula )} ;" ,
856734 ]
857735 )
858736 cases [ipar ] = expressions
@@ -867,12 +745,12 @@ def _get_function_body(
867745 f"reinitialization_state_idxs.cend(), { index } ) != "
868746 "reinitialization_state_idxs.cend())\n "
869747 f"{ function } [{ index } ] = "
870- f"{ self ._code_printer_cpp .doprint (formula )} ;"
748+ f"{ self ._code_printer .doprint (formula )} ;"
871749 )
872750
873751 elif function in event_functions :
874752 cases = {
875- ie : self ._code_printer_cpp ._get_sym_lines_array (
753+ ie : self ._code_printer ._get_sym_lines_array (
876754 equations [ie ], function , 0
877755 )
878756 for ie in range (self .model .num_events ())
@@ -885,7 +763,7 @@ def _get_function_body(
885763 for ie , inner_equations in enumerate (equations ):
886764 inner_lines = []
887765 inner_cases = {
888- ipar : self ._code_printer_cpp ._get_sym_lines_array (
766+ ipar : self ._code_printer ._get_sym_lines_array (
889767 inner_equations [:, ipar ], function , 0
890768 )
891769 for ipar in range (self .model .num_par ())
@@ -900,7 +778,7 @@ def _get_function_body(
900778 and equations .shape [1 ] == self .model .num_par ()
901779 ):
902780 cases = {
903- ipar : self ._code_printer_cpp ._get_sym_lines_array (
781+ ipar : self ._code_printer ._get_sym_lines_array (
904782 equations [:, ipar ], function , 0
905783 )
906784 for ipar in range (self .model .num_par ())
@@ -910,15 +788,15 @@ def _get_function_body(
910788 elif function in multiobs_functions :
911789 if function == "dJydy" :
912790 cases = {
913- iobs : self ._code_printer_cpp ._get_sym_lines_array (
791+ iobs : self ._code_printer ._get_sym_lines_array (
914792 equations [iobs ], function , 0
915793 )
916794 for iobs in range (self .model .num_obs ())
917795 if not smart_is_zero_matrix (equations [iobs ])
918796 }
919797 else :
920798 cases = {
921- iobs : self ._code_printer_cpp ._get_sym_lines_array (
799+ iobs : self ._code_printer ._get_sym_lines_array (
922800 equations [:, iobs ], function , 0
923801 )
924802 for iobs in range (equations .shape [1 ])
@@ -948,7 +826,7 @@ def _get_function_body(
948826 tmp_equations = sp .Matrix (
949827 [equations [i ] for i in static_idxs ]
950828 )
951- tmp_lines = self ._code_printer_cpp ._get_sym_lines_symbols (
829+ tmp_lines = self ._code_printer ._get_sym_lines_symbols (
952830 tmp_symbols ,
953831 tmp_equations ,
954832 function ,
@@ -974,7 +852,7 @@ def _get_function_body(
974852 [equations [i ] for i in dynamic_idxs ]
975853 )
976854
977- tmp_lines = self ._code_printer_cpp ._get_sym_lines_symbols (
855+ tmp_lines = self ._code_printer ._get_sym_lines_symbols (
978856 tmp_symbols ,
979857 tmp_equations ,
980858 function ,
@@ -986,12 +864,12 @@ def _get_function_body(
986864 lines .extend (tmp_lines )
987865
988866 else :
989- lines += self ._code_printer_cpp ._get_sym_lines_symbols (
867+ lines += self ._code_printer ._get_sym_lines_symbols (
990868 symbols , equations , function , 4
991869 )
992870
993871 else :
994- lines += self ._code_printer_cpp ._get_sym_lines_array (
872+ lines += self ._code_printer ._get_sym_lines_array (
995873 equations , function , 4
996874 )
997875
@@ -1136,8 +1014,7 @@ def _write_model_header_cpp(self) -> None:
11361014 )
11371015 ),
11381016 "NDXDOTDX_EXPLICIT" : len (self .model .sparsesym ("dxdotdx_explicit" )),
1139- "NDJYDY" : "std::vector<int>{%s}"
1140- % "," .join (str (len (x )) for x in self .model .sparsesym ("dJydy" )),
1017+ "NDJYDY" : f"std::vector<int>{{{ ',' .join (str (len (x )) for x in self .model .sparsesym ('dJydy' ))} }}" ,
11411018 "NDXRDATADXSOLVER" : len (self .model .sparsesym ("dx_rdatadx_solver" )),
11421019 "NDXRDATADTCL" : len (self .model .sparsesym ("dx_rdatadtcl" )),
11431020 "NDTOTALCLDXRDATA" : len (self .model .sparsesym ("dtotal_cldx_rdata" )),
@@ -1147,10 +1024,10 @@ def _write_model_header_cpp(self) -> None:
11471024 "NK" : self .model .num_const (),
11481025 "O2MODE" : "amici::SecondOrderMode::none" ,
11491026 # using code printer ensures proper handling of nan/inf
1150- "PARAMETERS" : self ._code_printer_cpp .doprint (self .model .val ("p" ))[
1027+ "PARAMETERS" : self ._code_printer .doprint (self .model .val ("p" ))[
11511028 1 :- 1
11521029 ],
1153- "FIXED_PARAMETERS" : self ._code_printer_cpp .doprint (
1030+ "FIXED_PARAMETERS" : self ._code_printer .doprint (
11541031 self .model .val ("k" )
11551032 )[1 :- 1 ],
11561033 "PARAMETER_NAMES_INITIALIZER_LIST" : self ._get_symbol_name_initializer_list (
@@ -1344,7 +1221,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
13441221 Template initializer list of ids
13451222 """
13461223 return "\n " .join (
1347- f'"{ self ._code_printer_cpp .doprint (symbol )} ", // { name } [{ idx } ]'
1224+ f'"{ self ._code_printer .doprint (symbol )} ", // { name } [{ idx } ]'
13481225 for idx , symbol in enumerate (self .model .sym (name ))
13491226 )
13501227
0 commit comments