4444 splines ,
4545)
4646from .constants import SymbolId
47- from .cxxcodeprinter import AmiciCxxCodePrinter , get_switch_statement
47+ from .cxxcodeprinter import (
48+ AmiciCxxCodePrinter ,
49+ get_switch_statement ,
50+ csc_matrix ,
51+ )
4852from .de_model import *
4953from .import_utils import (
5054 ObservableTransformation ,
@@ -725,9 +729,6 @@ class DEModel:
725729 whether all observables have a gaussian noise model, i.e. whether
726730 res and FIM make sense.
727731
728- :ivar _code_printer:
729- Code printer to generate C++ code
730-
731732 :ivar _z2event:
732733 list of event indices for each event observable
733734 """
@@ -869,10 +870,6 @@ def cached_simplify(
869870 self ._has_quadratic_nllh : bool = True
870871 set_log_level (logger , verbose )
871872
872- self ._code_printer = AmiciCxxCodePrinter ()
873- for fun in CUSTOM_FUNCTIONS :
874- self ._code_printer .known_functions [fun ["sympy" ]] = fun ["c++" ]
875-
876873 def differential_states (self ) -> list [DifferentialState ]:
877874 """Get all differential states."""
878875 return self ._differential_states
@@ -1882,7 +1879,7 @@ def _generate_sparse_symbol(self, name: str) -> None:
18821879 sparse_list ,
18831880 symbol_list ,
18841881 sparse_matrix ,
1885- ) = self . _code_printer . csc_matrix (
1882+ ) = csc_matrix (
18861883 matrix [iy , :],
18871884 rownames = rownames ,
18881885 colnames = colnames ,
@@ -1900,7 +1897,7 @@ def _generate_sparse_symbol(self, name: str) -> None:
19001897 sparse_list ,
19011898 symbol_list ,
19021899 sparse_matrix ,
1903- ) = self . _code_printer . csc_matrix (
1900+ ) = csc_matrix (
19041901 matrix ,
19051902 rownames = rownames ,
19061903 colnames = colnames ,
@@ -2884,6 +2881,9 @@ class DEExporter:
28842881 If the given model uses special functions, this set contains hints for
28852882 model building.
28862883
2884+ :ivar _code_printer:
2885+ Code printer to generate C++ code
2886+
28872887 :ivar generate_sensitivity_code:
28882888 Specifies whether code for sensitivity computation is to be generated
28892889
@@ -2950,10 +2950,14 @@ def __init__(
29502950 self .set_name (model_name )
29512951 self .set_paths (outdir )
29522952
2953+ self ._code_printer = AmiciCxxCodePrinter ()
2954+ for fun in CUSTOM_FUNCTIONS :
2955+ self ._code_printer .known_functions [fun ["sympy" ]] = fun ["c++" ]
2956+
29532957 # Signatures and properties of generated model functions (see
29542958 # include/amici/model.h for details)
29552959 self .model : DEModel = de_model
2956- self .model . _code_printer .known_functions .update (
2960+ self ._code_printer .known_functions .update (
29572961 splines .spline_user_functions (
29582962 self .model ._splines , self ._get_index ("p" )
29592963 )
@@ -3519,7 +3523,7 @@ def _get_function_body(
35193523 f"reinitialization_state_idxs.cend(), { index } ) != "
35203524 "reinitialization_state_idxs.cend())" ,
35213525 f" { function } [{ index } ] = "
3522- f"{ self .model . _code_printer .doprint (formula )} ;" ,
3526+ f"{ self ._code_printer .doprint (formula )} ;" ,
35233527 ]
35243528 )
35253529 cases [ipar ] = expressions
@@ -3534,12 +3538,12 @@ def _get_function_body(
35343538 f"reinitialization_state_idxs.cend(), { index } ) != "
35353539 "reinitialization_state_idxs.cend())\n "
35363540 f"{ function } [{ index } ] = "
3537- f"{ self .model . _code_printer .doprint (formula )} ;"
3541+ f"{ self ._code_printer .doprint (formula )} ;"
35383542 )
35393543
35403544 elif function in event_functions :
35413545 cases = {
3542- ie : self .model . _code_printer ._get_sym_lines_array (
3546+ ie : self ._code_printer ._get_sym_lines_array (
35433547 equations [ie ], function , 0
35443548 )
35453549 for ie in range (self .model .num_events ())
@@ -3552,7 +3556,7 @@ def _get_function_body(
35523556 for ie , inner_equations in enumerate (equations ):
35533557 inner_lines = []
35543558 inner_cases = {
3555- ipar : self .model . _code_printer ._get_sym_lines_array (
3559+ ipar : self ._code_printer ._get_sym_lines_array (
35563560 inner_equations [:, ipar ], function , 0
35573561 )
35583562 for ipar in range (self .model .num_par ())
@@ -3567,7 +3571,7 @@ def _get_function_body(
35673571 and equations .shape [1 ] == self .model .num_par ()
35683572 ):
35693573 cases = {
3570- ipar : self .model . _code_printer ._get_sym_lines_array (
3574+ ipar : self ._code_printer ._get_sym_lines_array (
35713575 equations [:, ipar ], function , 0
35723576 )
35733577 for ipar in range (self .model .num_par ())
@@ -3577,15 +3581,15 @@ def _get_function_body(
35773581 elif function in multiobs_functions :
35783582 if function == "dJydy" :
35793583 cases = {
3580- iobs : self .model . _code_printer ._get_sym_lines_array (
3584+ iobs : self ._code_printer ._get_sym_lines_array (
35813585 equations [iobs ], function , 0
35823586 )
35833587 for iobs in range (self .model .num_obs ())
35843588 if not smart_is_zero_matrix (equations [iobs ])
35853589 }
35863590 else :
35873591 cases = {
3588- iobs : self .model . _code_printer ._get_sym_lines_array (
3592+ iobs : self ._code_printer ._get_sym_lines_array (
35893593 equations [:, iobs ], function , 0
35903594 )
35913595 for iobs in range (equations .shape [1 ])
@@ -3605,12 +3609,12 @@ def _get_function_body(
36053609 symbols = list (map (sp .Symbol , self .model .sparsesym (function )))
36063610 else :
36073611 symbols = self .model .sym (function )
3608- lines += self .model . _code_printer ._get_sym_lines_symbols (
3612+ lines += self ._code_printer ._get_sym_lines_symbols (
36093613 symbols , equations , function , 4
36103614 )
36113615
36123616 else :
3613- lines += self .model . _code_printer ._get_sym_lines_array (
3617+ lines += self ._code_printer ._get_sym_lines_array (
36143618 equations , function , 4
36153619 )
36163620
@@ -3766,10 +3770,10 @@ def _write_model_header_cpp(self) -> None:
37663770 "NK" : self .model .num_const (),
37673771 "O2MODE" : "amici::SecondOrderMode::none" ,
37683772 # using code printer ensures proper handling of nan/inf
3769- "PARAMETERS" : self .model . _code_printer .doprint (
3770- self . model . val ( "p" )
3771- )[ 1 : - 1 ],
3772- "FIXED_PARAMETERS" : self .model . _code_printer .doprint (
3773+ "PARAMETERS" : self ._code_printer .doprint (self . model . val ( "p" ))[
3774+ 1 : - 1
3775+ ],
3776+ "FIXED_PARAMETERS" : self ._code_printer .doprint (
37733777 self .model .val ("k" )
37743778 )[1 :- 1 ],
37753779 "PARAMETER_NAMES_INITIALIZER_LIST" : self ._get_symbol_name_initializer_list (
@@ -3961,7 +3965,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
39613965 Template initializer list of ids
39623966 """
39633967 return "\n " .join (
3964- f'"{ self .model . _code_printer .doprint (symbol )} ", // { name } [{ idx } ]'
3968+ f'"{ self ._code_printer .doprint (symbol )} ", // { name } [{ idx } ]'
39653969 for idx , symbol in enumerate (self .model .sym (name ))
39663970 )
39673971
0 commit comments