Skip to content

Commit 723a1d0

Browse files
committed
Support preface in more codegen functions
1 parent ff64940 commit 723a1d0

File tree

4 files changed

+24
-13
lines changed

4 files changed

+24
-13
lines changed

src/structural_transformation/codegen.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ function build_torn_function(
201201
s = structure(sys)
202202
states = map(i->s.fullvars[i], diffvars_range(s))
203203
syms = map(Symbol, states)
204+
pre = get_postprocess_fbody(sys)
204205

205206
expr = SymbolicUtils.Code.toexpr(
206207
Func(
@@ -211,10 +212,10 @@ function build_torn_function(
211212
independent_variables(sys)
212213
],
213214
[],
214-
Let(
215+
pre(Let(
215216
collect(Iterators.flatten(get_torn_eqs_vars(sys, checkbounds=checkbounds))),
216217
odefunbody
217-
)
218+
))
218219
)
219220
)
220221
if expression
@@ -307,6 +308,7 @@ function build_observed_function(
307308
obs[observed_idx[sym]].rhs
308309
end
309310
end
311+
pre = get_postprocess_fbody(sys)
310312

311313
ex = Func(
312314
[
@@ -315,13 +317,13 @@ function build_observed_function(
315317
independent_variables(sys)
316318
],
317319
[],
318-
Let(
320+
pre(Let(
319321
[
320322
collect(Iterators.flatten(solves))
321323
map(eq -> eq.lhseq.rhs, obs[1:maxidx])
322324
],
323325
isscalar ? output[1] : MakeArray(output, output_type)
324-
)
326+
))
325327
) |> Code.toexpr
326328

327329
expression ? ex : @RuntimeGeneratedFunction(ex)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,12 @@ function generate_function(
102102
p = map(x->time_varying_as_func(value(x), sys), ps)
103103
t = get_iv(sys)
104104

105-
if has_preface(sys) && (pre = preface(sys); pre !== nothing)
106-
pre_ = ex -> Let(pre, ex)
107-
else
108-
pre_ = ex -> ex
109-
end
105+
pre = get_postprocess_fbody(sys)
110106

111107
if implicit_dae
112-
build_function(rhss, ddvs, u, p, t; postprocess_fbody=pre_, kwargs...)
108+
build_function(rhss, ddvs, u, p, t; postprocess_fbody=pre, kwargs...)
113109
else
114-
build_function(rhss, u, p, t; postprocess_fbody=pre_, kwargs...)
110+
build_function(rhss, u, p, t; postprocess_fbody=pre, kwargs...)
115111
end
116112
end
117113

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,14 @@ function build_explicit_observed_function(
244244
ps = DestructuredArgs(parameters(sys), inbounds=!checkbounds)
245245
ivs = independent_variables(sys)
246246
args = [dvs, ps, ivs...]
247+
pre = get_postprocess_fbody(sys)
247248

248249
ex = Func(
249250
args, [],
250-
Let(
251+
pre(Let(
251252
map(eq -> eq.lhseq.rhs, obs[1:maxidx]),
252253
isscalar ? output[1] : MakeArray(output, output_type)
253-
)
254+
))
254255
) |> toexpr
255256

256257
expression ? ex : @RuntimeGeneratedFunction(ex)

src/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,15 @@ function collect_var!(states, parameters, var, iv)
293293
end
294294
return nothing
295295
end
296+
297+
298+
function get_postprocess_fbody(sys)
299+
if has_preface(sys) && (pre = preface(sys); pre !== nothing)
300+
pre_ = let pre=pre
301+
ex -> Let(pre, ex)
302+
end
303+
else
304+
pre_ = ex -> ex
305+
end
306+
return pre_
307+
end

0 commit comments

Comments
 (0)