Skip to content

Commit 04ce0b7

Browse files
authored
Merge pull request #648 from SciML/s/sym-fix
fix symbol replacement in build_function
2 parents 06d15fd + 926d0b3 commit 04ce0b7

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/build_function.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ function _build_function(target::JuliaTarget, op, args...;
107107
symsdict = Dict()
108108
arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args)
109109
process = unflatten_long_ops(x->substitute(x, symsdict, fold=false))
110-
ls = reduce(vcat,first.(arg_pairs))
110+
ls = reduce(vcat,conv.(first.(arg_pairs)))
111111
rs = reduce(vcat,last.(arg_pairs))
112-
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, process.(rs)))
112+
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, conv.(process.(rs))))
113113

114114
fname = gensym(:ModelingToolkitFunction)
115115
op = process(op)
@@ -233,9 +233,11 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
233233
argnames = [gensym(:MTKArg) for i in 1:length(args)]
234234
symsdict = Dict()
235235
arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args)
236-
ls = reduce(vcat,first.(arg_pairs))
236+
process = unflatten_long_ops(x->substitute(x, symsdict, fold=false))
237+
238+
ls = reduce(vcat,conv.(first.(arg_pairs)))
237239
rs = reduce(vcat,last.(arg_pairs))
238-
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, rs))
240+
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, conv.(process.(rs))))
239241

240242
fname = gensym(:ModelingToolkitFunction)
241243
fargs = Expr(:tuple,argnames...)
@@ -244,8 +246,6 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
244246
oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i])
245247
X = gensym(:MTIIPVar)
246248

247-
process = unflatten_long_ops(x->substitute(x, symsdict, fold=false))
248-
249249
if rhss isa SparseMatrixCSC
250250
rhs_length = length(rhss.nzval)
251251
rhss = SparseMatrixCSC(rhss.m, rhss.m, rhss.colptr, rhss.rowval, map(process, rhss.nzval))
@@ -464,13 +464,13 @@ end
464464
function vars_to_pairs(name,vs::Union{Tuple, AbstractArray}, symsdict=Dict())
465465
vs_names = tosymbol.(vs)
466466
for (v,k) in zip(vs_names, vs)
467-
symsdict[k] = v
467+
symsdict[k] = Sym{symtype(k)}(v)
468468
end
469469
exs = [:($name[$i]) for (i, u) enumerate(vs)]
470470
vs_names,exs
471471
end
472472
function vars_to_pairs(name,vs, symsdict)
473-
symsdict[vs] = tosymbol(vs)
473+
symsdict[vs] = Sym{symtype(vs)}(tosymbol(vs))
474474
[tosymbol(vs)], [name]
475475
end
476476

0 commit comments

Comments
 (0)