Skip to content

Commit 7cb145a

Browse files
refactor: compile functions in generate_*
1 parent bf46865 commit 7cb145a

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

src/systems/codegen.jl

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ Generate the RHS function for the `equations` of a `System`.
1010
"""
1111
function generate_rhs(sys::System, dvs = unknowns(sys),
1212
ps = parameters(sys; initial_parameters = true); implicit_dae = false,
13-
scalar = false, kwargs...)
13+
scalar = false, expression = Val{true}, eval_expression = false,
14+
eval_module = @__MODULE__, kwargs...)
1415
eqs = equations(sys)
1516
obs = observed(sys)
1617
u = dvs
@@ -65,7 +66,20 @@ function generate_rhs(sys::System, dvs = unknowns(sys),
6566
p_start += 1
6667
end
6768

68-
build_function_wrapper(sys, rhss, args...; p_start, extra_assignments, kwargs...)
69+
res = build_function_wrapper(sys, rhss, args...; p_start, extra_assignments,
70+
expression = Val{true}, expression_module = eval_module, kwargs...)
71+
if expression == Val{true}
72+
return res
73+
end
74+
75+
if res isa Tuple
76+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
77+
else
78+
f_oop = eval_or_rgf(res; eval_expression, eval_module)
79+
f_iip = nothing
80+
end
81+
return GeneratedFunctionWrapper{(p_start, length(args) - length(p) + 1, is_split(sys))}(
82+
f_oop, f_iip)
6983
end
7084

7185
function calculate_tgrad(sys::System; simplify = false)
@@ -111,14 +125,21 @@ end
111125

112126
function generate_jacobian(sys::System, dvs = unknowns(sys),
113127
ps = parameters(sys; initial_parameters = true);
114-
simplify = false, sparse = false, kwargs...)
128+
simplify = false, sparse = false, eval_expression = false,
129+
eval_module = @__MODULE__, expression = Val{true}, kwargs...)
115130
jac = calculate_jacobian(sys; simplify, sparse, dvs)
116131
p = reorder_parameters(sys, ps)
117132
t = get_iv(sys)
118133
if t !== nothing
119134
wrap_code = sparse ? assert_jac_length_header(sys) : (identity, identity)
120135
end
121-
return build_function_wrapper(sys, jac, dvs, p..., t; wrap_code, kwargs...)
136+
res = build_function_wrapper(sys, jac, dvs, p..., t; wrap_code, expression = Val{true},
137+
expression_module = eval_module, kwargs...)
138+
if expression == Val{true}
139+
return res
140+
end
141+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
142+
return GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
122143
end
123144

124145
function assert_jac_length_header(sys)
@@ -134,21 +155,31 @@ end
134155
function generate_tgrad(
135156
sys::System, dvs = unknowns(sys), ps = parameters(
136157
sys; initial_parameters = true);
137-
simplify = false, kwargs...)
158+
simplify = false, eval_expression = false, eval_module = @__MODULE__,
159+
expression = Val{true}, kwargs...)
138160
tgrad = calculate_tgrad(sys, simplify = simplify)
139161
p = reorder_parameters(sys, ps)
140-
return build_function_wrapper(sys, tgrad,
162+
res = build_function_wrapper(sys, tgrad,
141163
dvs,
142164
p...,
143165
get_iv(sys);
166+
expression = Val{true},
167+
expression_module = eval_module,
144168
kwargs...)
169+
170+
if expression == Val{true}
171+
return res
172+
end
173+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
174+
return GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
145175
end
146176

147177
const W_GAMMA = only(@variables ˍ₋gamma)
148178

149179
function generate_W(sys::System, γ = 1.0, dvs = unknowns(sys),
150180
ps = parameters(sys; initial_parameters = true);
151-
simplify = false, sparse = false, kwargs...)
181+
simplify = false, sparse = false, expression = Val{true},
182+
eval_expression = false, eval_module = @__MODULE__, kwargs...)
152183
M = calculate_massmatrix(sys; simplify)
153184
if sparse
154185
M = SparseArrays.sparse(M)
@@ -161,12 +192,18 @@ function generate_W(sys::System, γ = 1.0, dvs = unknowns(sys),
161192
end
162193

163194
p = reorder_parameters(sys, ps)
164-
return build_function_wrapper(sys, W, dvs, p..., W_GAMMA, t; wrap_code,
195+
res = build_function_wrapper(sys, W, dvs, p..., W_GAMMA, t; wrap_code,
165196
p_end = 1 + length(p), kwargs...)
197+
if expression == Val{true}
198+
return res
199+
end
200+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
201+
return GeneratedFunctionWrapper{(2, 4, is_split(sys))}(f_oop, f_iip)
166202
end
167203

168204
function generate_dae_jacobian(sys::System, dvs = unknowns(sys),
169205
ps = parameters(sys; initial_parameters = true); simplify = false, sparse = false,
206+
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__,
170207
kwargs...)
171208
jac_u = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
172209
t = get_iv(sys)
@@ -176,8 +213,13 @@ function generate_dae_jacobian(sys::System, dvs = unknowns(sys),
176213
dvs = unknowns(sys)
177214
jac = W_GAMMA * jac_du + jac_u
178215
p = reorder_parameters(sys, ps)
179-
return build_function_wrapper(sys, jac, derivatives, dvs, p..., W_GAMMA, t;
216+
res = build_function_wrapper(sys, jac, derivatives, dvs, p..., W_GAMMA, t;
180217
p_start = 3, p_end = 2 + length(p), kwargs...)
218+
if expression == Val{true}
219+
return res
220+
end
221+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
222+
return GeneratedFunctionWrapper{(3, 5, is_split(sys))}(f_oop, f_iip)
181223
end
182224

183225
function calculate_massmatrix(sys::System; simplify = false)

0 commit comments

Comments
 (0)