@@ -104,14 +104,16 @@ function _build_function(target::JuliaTarget, op, args...;
104
104
linenumbers = true , headerfun= addheader)
105
105
106
106
argnames = [gensym (:MTKArg ) for i in 1 : length (args)]
107
- arg_pairs = map (vars_to_pairs,zip (argnames,args))
107
+ symsdict = Dict ()
108
+ arg_pairs = map ((x,y)-> vars_to_pairs (x,y, symsdict), argnames, args)
109
+ process = unflatten_long_ops∘ (x-> substitute (x, symsdict, fold= false ))
108
110
ls = reduce (vcat,first .(arg_pairs))
109
111
rs = reduce (vcat,last .(arg_pairs))
110
- var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , unflatten_long_ops .(rs)))
112
+ var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , process .(rs)))
111
113
112
114
fname = gensym (:ModelingToolkitFunction )
113
- op = unflatten_long_ops (op)
114
- out_expr = conv (op )
115
+ op = process (op)
116
+ out_expr = conv (substitute (op, symsdict, fold = false ) )
115
117
let_expr = Expr (:let , var_eqs, Expr (:block , out_expr))
116
118
bounds_block = checkbounds ? let_expr : :(@inbounds begin $ let_expr end )
117
119
@@ -229,7 +231,8 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
229
231
end
230
232
231
233
argnames = [gensym (:MTKArg ) for i in 1 : length (args)]
232
- arg_pairs = map (vars_to_pairs,zip (argnames,args))
234
+ symsdict = Dict ()
235
+ arg_pairs = map ((x,y)-> vars_to_pairs (x,y, symsdict), argnames, args)
233
236
ls = reduce (vcat,first .(arg_pairs))
234
237
rs = reduce (vcat,last .(arg_pairs))
235
238
var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , rs))
@@ -241,12 +244,14 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
241
244
oidx = isnothing (outputidxs) ? (i -> i) : (i -> outputidxs[i])
242
245
X = gensym (:MTIIPVar )
243
246
247
+ process = unflatten_long_ops∘ (x-> substitute (x, symsdict, fold= false ))
248
+
244
249
if rhss isa SparseMatrixCSC
245
250
rhs_length = length (rhss. nzval)
246
- rhss = SparseMatrixCSC (rhss. m, rhss. m, rhss. colptr, rhss. rowval, map (unflatten_long_ops , rhss. nzval))
251
+ rhss = SparseMatrixCSC (rhss. m, rhss. m, rhss. colptr, rhss. rowval, map (process , rhss. nzval))
247
252
else
248
253
rhs_length = length (rhss)
249
- rhss = [unflatten_long_ops (r) for r in rhss]
254
+ rhss = [process (r) for r in rhss]
250
255
end
251
256
252
257
if parallel isa DistributedForm
@@ -388,9 +393,9 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
388
393
end
389
394
390
395
if rhss isa SparseMatrixCSC
391
- rhss′ = map (conv∘ unflatten_long_ops , rhss. nzval)
396
+ rhss′ = map (conv∘ process , rhss. nzval)
392
397
else
393
- rhss′ = [conv (unflatten_long_ops (r)) for r in rhss]
398
+ rhss′ = [conv (process (r)) for r in rhss]
394
399
end
395
400
396
401
tuple_sys_expr = build_expr (:tuple , rhss′)
@@ -456,13 +461,16 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
456
461
end
457
462
end
458
463
459
- vars_to_pairs (args) = vars_to_pairs (args[1 ],args[2 ])
460
- function vars_to_pairs (name,vs:: AbstractArray )
461
- vs_names = [tosymbol (u) for u ∈ vs]
464
+ function vars_to_pairs (name,vs:: AbstractArray , symsdict= Dict ())
465
+ vs_names = tosymbol .(vs)
466
+ for (v,k) in zip (vs_names, vs)
467
+ symsdict[k] = v
468
+ end
462
469
exs = [:($ name[$ i]) for (i, u) ∈ enumerate (vs)]
463
470
vs_names,exs
464
471
end
465
- function vars_to_pairs (name,vs)
472
+ function vars_to_pairs (name,vs, symsdict)
473
+ symsdict[vs] = tosymbol (vs)
466
474
[tosymbol (vs)], [name]
467
475
end
468
476
0 commit comments