Skip to content

Commit 4c271cc

Browse files
committed
fix build_function when args are terms e.g. x(t)
1 parent 1f35567 commit 4c271cc

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

src/build_function.jl

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,16 @@ function _build_function(target::JuliaTarget, op, args...;
104104
linenumbers = true, headerfun=addheader)
105105

106106
argnames = [gensym(:MTKArg) for i in 1:length(args)]
107-
arg_pairs = map(vars_to_pairs,zip(argnames,args))
107+
symsdict = Dict()
108+
arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args)
109+
process = unflatten_long_ops(x->substitute(x, symsdict, fold=false))
108110
ls = reduce(vcat,first.(arg_pairs))
109111
rs = reduce(vcat,last.(arg_pairs))
110-
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, unflatten_long_ops.(rs)))
112+
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, process.(rs)))
111113

112114
fname = gensym(:ModelingToolkitFunction)
113-
op = unflatten_long_ops(op)
114-
out_expr = conv(op)
115+
op = process(op)
116+
out_expr = conv(substitute(op, symsdict, fold=false))
115117
let_expr = Expr(:let, var_eqs, Expr(:block, out_expr))
116118
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
117119

@@ -229,7 +231,8 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
229231
end
230232

231233
argnames = [gensym(:MTKArg) for i in 1:length(args)]
232-
arg_pairs = map(vars_to_pairs,zip(argnames,args))
234+
symsdict = Dict()
235+
arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args)
233236
ls = reduce(vcat,first.(arg_pairs))
234237
rs = reduce(vcat,last.(arg_pairs))
235238
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, rs))
@@ -241,12 +244,14 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
241244
oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i])
242245
X = gensym(:MTIIPVar)
243246

247+
process = unflatten_long_ops(x->substitute(x, symsdict, fold=false))
248+
244249
if rhss isa SparseMatrixCSC
245250
rhs_length = length(rhss.nzval)
246-
rhss = SparseMatrixCSC(rhss.m, rhss.m, rhss.colptr, rhss.rowval, map(unflatten_long_ops, rhss.nzval))
251+
rhss = SparseMatrixCSC(rhss.m, rhss.m, rhss.colptr, rhss.rowval, map(process, rhss.nzval))
247252
else
248253
rhs_length = length(rhss)
249-
rhss = [unflatten_long_ops(r) for r in rhss]
254+
rhss = [process(r) for r in rhss]
250255
end
251256

252257
if parallel isa DistributedForm
@@ -388,9 +393,9 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
388393
end
389394

390395
if rhss isa SparseMatrixCSC
391-
rhss′ = map(convunflatten_long_ops, rhss.nzval)
396+
rhss′ = map(convprocess, rhss.nzval)
392397
else
393-
rhss′ = [conv(unflatten_long_ops(r)) for r in rhss]
398+
rhss′ = [conv(process(r)) for r in rhss]
394399
end
395400

396401
tuple_sys_expr = build_expr(:tuple, rhss′)
@@ -456,13 +461,16 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
456461
end
457462
end
458463

459-
vars_to_pairs(args) = vars_to_pairs(args[1],args[2])
460-
function vars_to_pairs(name,vs::AbstractArray)
461-
vs_names = [tosymbol(u) for u vs]
464+
function vars_to_pairs(name,vs::AbstractArray, symsdict=Dict())
465+
vs_names = tosymbol.(vs)
466+
for (v,k) in zip(vs_names, vs)
467+
symsdict[k] = v
468+
end
462469
exs = [:($name[$i]) for (i, u) enumerate(vs)]
463470
vs_names,exs
464471
end
465-
function vars_to_pairs(name,vs)
472+
function vars_to_pairs(name,vs, symsdict)
473+
symsdict[vs] = tosymbol(vs)
466474
[tosymbol(vs)], [name]
467475
end
468476

0 commit comments

Comments
 (0)