Skip to content

Commit c4bcb9d

Browse files
authored
Refactor de_export.py, extract C++ function info to cxx_functions.py (#2321)
Move everything related to information on C++ model functions to a separate module. Related to #2306. No changes in functionality.
1 parent 504f4bc commit c4bcb9d

File tree

2 files changed

+399
-374
lines changed

2 files changed

+399
-374
lines changed
Lines changed: 385 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,385 @@
1+
"""Info about C++ functions in the generated model code."""
2+
from __future__ import annotations
3+
4+
from dataclasses import dataclass
5+
6+
7+
@dataclass
8+
class _FunctionInfo:
9+
"""Information on a model-specific generated C++ function
10+
11+
:ivar ode_arguments: argument list of the ODE function.
12+
input variables should be ``const``.
13+
:ivar dae_arguments: argument list of the DAE function, if different from
14+
ODE function. input variables should be ``const``.
15+
:ivar return_type: the return type of the function
16+
:ivar assume_pow_positivity:
17+
identifies the functions on which ``assume_pow_positivity`` will have
18+
an effect when specified during model generation. generally these are
19+
functions that are used for solving the ODE, where negative values may
20+
negatively affect convergence of the integration algorithm
21+
:ivar sparse:
22+
specifies whether the result of this function will be stored in sparse
23+
format. sparse format means that the function will only return an
24+
array of nonzero values and not a full matrix.
25+
:ivar generate_body:
26+
indicates whether a model-specific implementation is to be generated
27+
:ivar body:
28+
the actual function body. will be filled later
29+
"""
30+
31+
ode_arguments: str = ""
32+
dae_arguments: str = ""
33+
return_type: str = "void"
34+
assume_pow_positivity: bool = False
35+
sparse: bool = False
36+
generate_body: bool = True
37+
body: str = ""
38+
39+
def arguments(self, ode: bool = True) -> str:
40+
"""Get the arguments for the ODE or DAE function"""
41+
if ode or not self.dae_arguments:
42+
return self.ode_arguments
43+
return self.dae_arguments
44+
45+
46+
# Information on a model-specific generated C++ function
47+
# prototype for generated C++ functions, keys are the names of functions
48+
functions = {
49+
"Jy": _FunctionInfo(
50+
"realtype *Jy, const int iy, const realtype *p, "
51+
"const realtype *k, const realtype *y, const realtype *sigmay, "
52+
"const realtype *my"
53+
),
54+
"dJydsigma": _FunctionInfo(
55+
"realtype *dJydsigma, const int iy, const realtype *p, "
56+
"const realtype *k, const realtype *y, const realtype *sigmay, "
57+
"const realtype *my"
58+
),
59+
"dJydy": _FunctionInfo(
60+
"realtype *dJydy, const int iy, const realtype *p, "
61+
"const realtype *k, const realtype *y, "
62+
"const realtype *sigmay, const realtype *my",
63+
sparse=True,
64+
),
65+
"Jz": _FunctionInfo(
66+
"realtype *Jz, const int iz, const realtype *p, const realtype *k, "
67+
"const realtype *z, const realtype *sigmaz, const realtype *mz"
68+
),
69+
"dJzdsigma": _FunctionInfo(
70+
"realtype *dJzdsigma, const int iz, const realtype *p, "
71+
"const realtype *k, const realtype *z, const realtype *sigmaz, "
72+
"const realtype *mz"
73+
),
74+
"dJzdz": _FunctionInfo(
75+
"realtype *dJzdz, const int iz, const realtype *p, "
76+
"const realtype *k, const realtype *z, const realtype *sigmaz, "
77+
"const double *mz",
78+
),
79+
"Jrz": _FunctionInfo(
80+
"realtype *Jrz, const int iz, const realtype *p, "
81+
"const realtype *k, const realtype *rz, const realtype *sigmaz"
82+
),
83+
"dJrzdsigma": _FunctionInfo(
84+
"realtype *dJrzdsigma, const int iz, const realtype *p, "
85+
"const realtype *k, const realtype *rz, const realtype *sigmaz"
86+
),
87+
"dJrzdz": _FunctionInfo(
88+
"realtype *dJrzdz, const int iz, const realtype *p, "
89+
"const realtype *k, const realtype *rz, const realtype *sigmaz",
90+
),
91+
"root": _FunctionInfo(
92+
"realtype *root, const realtype t, const realtype *x, "
93+
"const realtype *p, const realtype *k, const realtype *h, "
94+
"const realtype *tcl"
95+
),
96+
"dwdp": _FunctionInfo(
97+
"realtype *dwdp, const realtype t, const realtype *x, "
98+
"const realtype *p, const realtype *k, const realtype *h, "
99+
"const realtype *w, const realtype *tcl, const realtype *dtcldp, "
100+
"const realtype *spl, const realtype *sspl, bool include_static",
101+
assume_pow_positivity=True,
102+
sparse=True,
103+
),
104+
"dwdx": _FunctionInfo(
105+
"realtype *dwdx, const realtype t, const realtype *x, "
106+
"const realtype *p, const realtype *k, const realtype *h, "
107+
"const realtype *w, const realtype *tcl, const realtype *spl, "
108+
"bool include_static",
109+
assume_pow_positivity=True,
110+
sparse=True,
111+
),
112+
"create_splines": _FunctionInfo(
113+
"const realtype *p, const realtype *k",
114+
return_type="std::vector<HermiteSpline>",
115+
),
116+
"spl": _FunctionInfo(generate_body=False),
117+
"sspl": _FunctionInfo(generate_body=False),
118+
"spline_values": _FunctionInfo(
119+
"const realtype *p, const realtype *k", generate_body=False
120+
),
121+
"spline_slopes": _FunctionInfo(
122+
"const realtype *p, const realtype *k", generate_body=False
123+
),
124+
"dspline_valuesdp": _FunctionInfo(
125+
"realtype *dspline_valuesdp, const realtype *p, const realtype *k, "
126+
"const int ip"
127+
),
128+
"dspline_slopesdp": _FunctionInfo(
129+
"realtype *dspline_slopesdp, const realtype *p, const realtype *k, "
130+
"const int ip"
131+
),
132+
"dwdw": _FunctionInfo(
133+
"realtype *dwdw, const realtype t, const realtype *x, "
134+
"const realtype *p, const realtype *k, const realtype *h, "
135+
"const realtype *w, const realtype *tcl, bool include_static",
136+
assume_pow_positivity=True,
137+
sparse=True,
138+
),
139+
"dxdotdw": _FunctionInfo(
140+
"realtype *dxdotdw, const realtype t, const realtype *x, "
141+
"const realtype *p, const realtype *k, const realtype *h, "
142+
"const realtype *w",
143+
"realtype *dxdotdw, const realtype t, const realtype *x, "
144+
"const realtype *p, const realtype *k, const realtype *h, "
145+
"const realtype *dx, const realtype *w",
146+
assume_pow_positivity=True,
147+
sparse=True,
148+
),
149+
"dxdotdx_explicit": _FunctionInfo(
150+
"realtype *dxdotdx_explicit, const realtype t, "
151+
"const realtype *x, const realtype *p, const realtype *k, "
152+
"const realtype *h, const realtype *w",
153+
"realtype *dxdotdx_explicit, const realtype t, "
154+
"const realtype *x, const realtype *p, const realtype *k, "
155+
"const realtype *h, const realtype *dx, const realtype *w",
156+
assume_pow_positivity=True,
157+
sparse=True,
158+
),
159+
"dxdotdp_explicit": _FunctionInfo(
160+
"realtype *dxdotdp_explicit, const realtype t, "
161+
"const realtype *x, const realtype *p, const realtype *k, "
162+
"const realtype *h, const realtype *w",
163+
"realtype *dxdotdp_explicit, const realtype t, "
164+
"const realtype *x, const realtype *p, const realtype *k, "
165+
"const realtype *h, const realtype *dx, const realtype *w",
166+
assume_pow_positivity=True,
167+
sparse=True,
168+
),
169+
"dydx": _FunctionInfo(
170+
"realtype *dydx, const realtype t, const realtype *x, "
171+
"const realtype *p, const realtype *k, const realtype *h, "
172+
"const realtype *w, const realtype *dwdx",
173+
),
174+
"dydp": _FunctionInfo(
175+
"realtype *dydp, const realtype t, const realtype *x, "
176+
"const realtype *p, const realtype *k, const realtype *h, "
177+
"const int ip, const realtype *w, const realtype *tcl, "
178+
"const realtype *dtcldp, const realtype *spl, const realtype *sspl"
179+
),
180+
"dzdx": _FunctionInfo(
181+
"realtype *dzdx, const int ie, const realtype t, "
182+
"const realtype *x, const realtype *p, const realtype *k, "
183+
"const realtype *h",
184+
),
185+
"dzdp": _FunctionInfo(
186+
"realtype *dzdp, const int ie, const realtype t, "
187+
"const realtype *x, const realtype *p, const realtype *k, "
188+
"const realtype *h, const int ip",
189+
),
190+
"drzdx": _FunctionInfo(
191+
"realtype *drzdx, const int ie, const realtype t, "
192+
"const realtype *x, const realtype *p, const realtype *k, "
193+
"const realtype *h",
194+
),
195+
"drzdp": _FunctionInfo(
196+
"realtype *drzdp, const int ie, const realtype t, "
197+
"const realtype *x, const realtype *p, const realtype *k, "
198+
"const realtype *h, const int ip",
199+
),
200+
"dsigmaydy": _FunctionInfo(
201+
"realtype *dsigmaydy, const realtype t, const realtype *p, "
202+
"const realtype *k, const realtype *y"
203+
),
204+
"dsigmaydp": _FunctionInfo(
205+
"realtype *dsigmaydp, const realtype t, const realtype *p, "
206+
"const realtype *k, const realtype *y, const int ip",
207+
),
208+
"sigmay": _FunctionInfo(
209+
"realtype *sigmay, const realtype t, const realtype *p, "
210+
"const realtype *k, const realtype *y",
211+
),
212+
"dsigmazdp": _FunctionInfo(
213+
"realtype *dsigmazdp, const realtype t, const realtype *p,"
214+
" const realtype *k, const int ip",
215+
),
216+
"sigmaz": _FunctionInfo(
217+
"realtype *sigmaz, const realtype t, const realtype *p, "
218+
"const realtype *k",
219+
),
220+
"sroot": _FunctionInfo(
221+
"realtype *stau, const realtype t, const realtype *x, "
222+
"const realtype *p, const realtype *k, const realtype *h, "
223+
"const realtype *sx, const int ip, const int ie, "
224+
"const realtype *tcl",
225+
generate_body=False,
226+
),
227+
"drootdt": _FunctionInfo(generate_body=False),
228+
"drootdt_total": _FunctionInfo(generate_body=False),
229+
"drootdp": _FunctionInfo(generate_body=False),
230+
"drootdx": _FunctionInfo(generate_body=False),
231+
"stau": _FunctionInfo(
232+
"realtype *stau, const realtype t, const realtype *x, "
233+
"const realtype *p, const realtype *k, const realtype *h, "
234+
"const realtype *tcl, const realtype *sx, const int ip, "
235+
"const int ie"
236+
),
237+
"deltax": _FunctionInfo(
238+
"double *deltax, const realtype t, const realtype *x, "
239+
"const realtype *p, const realtype *k, const realtype *h, "
240+
"const int ie, const realtype *xdot, const realtype *xdot_old"
241+
),
242+
"ddeltaxdx": _FunctionInfo(generate_body=False),
243+
"ddeltaxdt": _FunctionInfo(generate_body=False),
244+
"ddeltaxdp": _FunctionInfo(generate_body=False),
245+
"deltasx": _FunctionInfo(
246+
"realtype *deltasx, const realtype t, const realtype *x, "
247+
"const realtype *p, const realtype *k, const realtype *h, "
248+
"const realtype *w, const int ip, const int ie, "
249+
"const realtype *xdot, const realtype *xdot_old, "
250+
"const realtype *sx, const realtype *stau, const realtype *tcl"
251+
),
252+
"w": _FunctionInfo(
253+
"realtype *w, const realtype t, const realtype *x, "
254+
"const realtype *p, const realtype *k, "
255+
"const realtype *h, const realtype *tcl, const realtype *spl, "
256+
"bool include_static",
257+
assume_pow_positivity=True,
258+
),
259+
"x0": _FunctionInfo(
260+
"realtype *x0, const realtype t, const realtype *p, "
261+
"const realtype *k"
262+
),
263+
"x0_fixedParameters": _FunctionInfo(
264+
"realtype *x0_fixedParameters, const realtype t, "
265+
"const realtype *p, const realtype *k, "
266+
"gsl::span<const int> reinitialization_state_idxs",
267+
),
268+
"sx0": _FunctionInfo(
269+
"realtype *sx0, const realtype t, const realtype *x, "
270+
"const realtype *p, const realtype *k, const int ip",
271+
),
272+
"sx0_fixedParameters": _FunctionInfo(
273+
"realtype *sx0_fixedParameters, const realtype t, "
274+
"const realtype *x0, const realtype *p, const realtype *k, "
275+
"const int ip, gsl::span<const int> reinitialization_state_idxs",
276+
),
277+
"xdot": _FunctionInfo(
278+
"realtype *xdot, const realtype t, const realtype *x, "
279+
"const realtype *p, const realtype *k, const realtype *h, "
280+
"const realtype *w",
281+
"realtype *xdot, const realtype t, const realtype *x, "
282+
"const realtype *p, const realtype *k, const realtype *h, "
283+
"const realtype *dx, const realtype *w",
284+
assume_pow_positivity=True,
285+
),
286+
"xdot_old": _FunctionInfo(generate_body=False),
287+
"y": _FunctionInfo(
288+
"realtype *y, const realtype t, const realtype *x, "
289+
"const realtype *p, const realtype *k, "
290+
"const realtype *h, const realtype *w",
291+
),
292+
"x_rdata": _FunctionInfo(
293+
"realtype *x_rdata, const realtype *x, const realtype *tcl, "
294+
"const realtype *p, const realtype *k"
295+
),
296+
"total_cl": _FunctionInfo(
297+
"realtype *total_cl, const realtype *x_rdata, "
298+
"const realtype *p, const realtype *k"
299+
),
300+
"dtotal_cldp": _FunctionInfo(
301+
"realtype *dtotal_cldp, const realtype *x_rdata, "
302+
"const realtype *p, const realtype *k, const int ip"
303+
),
304+
"dtotal_cldx_rdata": _FunctionInfo(
305+
"realtype *dtotal_cldx_rdata, const realtype *x_rdata, "
306+
"const realtype *p, const realtype *k, const realtype *tcl",
307+
sparse=True,
308+
),
309+
"x_solver": _FunctionInfo("realtype *x_solver, const realtype *x_rdata"),
310+
"dx_rdatadx_solver": _FunctionInfo(
311+
"realtype *dx_rdatadx_solver, const realtype *x, "
312+
"const realtype *tcl, const realtype *p, const realtype *k",
313+
sparse=True,
314+
),
315+
"dx_rdatadp": _FunctionInfo(
316+
"realtype *dx_rdatadp, const realtype *x, "
317+
"const realtype *tcl, const realtype *p, const realtype *k, "
318+
"const int ip"
319+
),
320+
"dx_rdatadtcl": _FunctionInfo(
321+
"realtype *dx_rdatadtcl, const realtype *x, "
322+
"const realtype *tcl, const realtype *p, const realtype *k",
323+
sparse=True,
324+
),
325+
"z": _FunctionInfo(
326+
"realtype *z, const int ie, const realtype t, const realtype *x, "
327+
"const realtype *p, const realtype *k, const realtype *h"
328+
),
329+
"rz": _FunctionInfo(
330+
"realtype *rz, const int ie, const realtype t, const realtype *x, "
331+
"const realtype *p, const realtype *k, const realtype *h"
332+
),
333+
}
334+
335+
#: list of sparse functions
336+
sparse_functions = [
337+
func_name for func_name, func_info in functions.items() if func_info.sparse
338+
]
339+
340+
#: list of nobody functions
341+
nobody_functions = [
342+
func_name
343+
for func_name, func_info in functions.items()
344+
if not func_info.generate_body
345+
]
346+
347+
#: list of sensitivity functions
348+
sensi_functions = [
349+
func_name
350+
for func_name, func_info in functions.items()
351+
if "const int ip" in func_info.arguments()
352+
]
353+
354+
#: list of sparse sensitivity functions
355+
sparse_sensi_functions = [
356+
func_name
357+
for func_name, func_info in functions.items()
358+
if "const int ip" not in func_info.arguments()
359+
and func_name.endswith("dp")
360+
or func_name.endswith("dp_explicit")
361+
]
362+
363+
#: list of event functions
364+
event_functions = [
365+
func_name
366+
for func_name, func_info in functions.items()
367+
if "const int ie" in func_info.arguments()
368+
and "const int ip" not in func_info.arguments()
369+
]
370+
371+
#: list of event sensitivity functions
372+
event_sensi_functions = [
373+
func_name
374+
for func_name, func_info in functions.items()
375+
if "const int ie" in func_info.arguments()
376+
and "const int ip" in func_info.arguments()
377+
]
378+
379+
#: list of multiobs functions
380+
multiobs_functions = [
381+
func_name
382+
for func_name, func_info in functions.items()
383+
if "const int iy" in func_info.arguments()
384+
or "const int iz" in func_info.arguments()
385+
]

0 commit comments

Comments
 (0)