Skip to content

Commit 189a132

Browse files
Merge pull request #643 from SciML/s/fix-buildfn
fix build_function when args are terms e.g. x(t)
2 parents 1f35567 + 8f6e8a8 commit 189a132

File tree

5 files changed

+49
-23
lines changed

5 files changed

+49
-23
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ tosymbolic(a) = a
110110
@num_method Base.:(<=) (tosymbolic(a) <= tosymbolic(b)) (Real,)
111111
@num_method Base.:(>) (tosymbolic(a) > tosymbolic(b)) (Real,)
112112
@num_method Base.:(>=) (tosymbolic(a) >= tosymbolic(b)) (Real,)
113-
@num_method Base.isequal isequal(tosymbolic(a), tosymbolic(b)) (Number, Symbolic)
114-
@num_method Base.:(==) tosymbolic(a) == tosymbolic(b) (Number,)
113+
@num_method Base.isequal isequal(tosymbolic(a), tosymbolic(b)) (AbstractFloat, Number, Symbolic)
114+
@num_method Base.:(==) tosymbolic(a) == tosymbolic(b) (AbstractFloat,Number)
115115

116116
Base.hash(x::Num, h::UInt) = hash(value(x), h)
117117

src/build_function.jl

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,16 @@ function _build_function(target::JuliaTarget, op, args...;
104104
linenumbers = true, headerfun=addheader)
105105

106106
argnames = [gensym(:MTKArg) for i in 1:length(args)]
107-
arg_pairs = map(vars_to_pairs,zip(argnames,args))
107+
symsdict = Dict()
108+
arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args)
109+
process = unflatten_long_ops(x->substitute(x, symsdict, fold=false))
108110
ls = reduce(vcat,first.(arg_pairs))
109111
rs = reduce(vcat,last.(arg_pairs))
110-
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, unflatten_long_ops.(rs)))
112+
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, process.(rs)))
111113

112114
fname = gensym(:ModelingToolkitFunction)
113-
op = unflatten_long_ops(op)
114-
out_expr = conv(op)
115+
op = process(op)
116+
out_expr = conv(substitute(op, symsdict, fold=false))
115117
let_expr = Expr(:let, var_eqs, Expr(:block, out_expr))
116118
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
117119

@@ -229,7 +231,8 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
229231
end
230232

231233
argnames = [gensym(:MTKArg) for i in 1:length(args)]
232-
arg_pairs = map(vars_to_pairs,zip(argnames,args))
234+
symsdict = Dict()
235+
arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args)
233236
ls = reduce(vcat,first.(arg_pairs))
234237
rs = reduce(vcat,last.(arg_pairs))
235238
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, rs))
@@ -241,12 +244,14 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
241244
oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i])
242245
X = gensym(:MTIIPVar)
243246

247+
process = unflatten_long_ops(x->substitute(x, symsdict, fold=false))
248+
244249
if rhss isa SparseMatrixCSC
245250
rhs_length = length(rhss.nzval)
246-
rhss = SparseMatrixCSC(rhss.m, rhss.m, rhss.colptr, rhss.rowval, map(unflatten_long_ops, rhss.nzval))
251+
rhss = SparseMatrixCSC(rhss.m, rhss.m, rhss.colptr, rhss.rowval, map(process, rhss.nzval))
247252
else
248253
rhs_length = length(rhss)
249-
rhss = [unflatten_long_ops(r) for r in rhss]
254+
rhss = [process(r) for r in rhss]
250255
end
251256

252257
if parallel isa DistributedForm
@@ -388,9 +393,9 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
388393
end
389394

390395
if rhss isa SparseMatrixCSC
391-
rhss′ = map(convunflatten_long_ops, rhss.nzval)
396+
rhss′ = map(convprocess, rhss.nzval)
392397
else
393-
rhss′ = [conv(unflatten_long_ops(r)) for r in rhss]
398+
rhss′ = [conv(process(r)) for r in rhss]
394399
end
395400

396401
tuple_sys_expr = build_expr(:tuple, rhss′)
@@ -456,13 +461,16 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
456461
end
457462
end
458463

459-
vars_to_pairs(args) = vars_to_pairs(args[1],args[2])
460-
function vars_to_pairs(name,vs::AbstractArray)
461-
vs_names = [tosymbol(u) for u vs]
464+
function vars_to_pairs(name,vs::Union{Tuple, AbstractArray}, symsdict=Dict())
465+
vs_names = tosymbol.(vs)
466+
for (v,k) in zip(vs_names, vs)
467+
symsdict[k] = v
468+
end
462469
exs = [:($name[$i]) for (i, u) enumerate(vs)]
463470
vs_names,exs
464471
end
465-
function vars_to_pairs(name,vs)
472+
function vars_to_pairs(name,vs, symsdict)
473+
symsdict[vs] = tosymbol(vs)
466474
[tosymbol(vs)], [name]
467475
end
468476

test/build_function.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, Test
1+
using ModelingToolkit, SparseArrays, Test
22
@variables a b c1 c2 c3 d e g
33

44
# Multiple argument matrix
@@ -71,3 +71,21 @@ h_oop_scalar = eval(h_str_scalar)
7171

7272
@test isequal(simplify(ModelingToolkit.unflatten_long_ops(prod(z))),
7373
simplify(prod(z)))
74+
75+
@variables t x(t) y(t) k
76+
f = eval(build_function((x+y)/k, [x,y,k]))
77+
@test f([1,1,2]) == 1
78+
79+
f = eval(build_function([(x+y)/k], [x,y,k])[1])
80+
@test f([1,1,2]) == [1]
81+
82+
f = eval(build_function([(x+y)/k], [x,y,k])[2])
83+
z = [0.0]
84+
f(z, [1,1,2])
85+
@test z == [1]
86+
87+
f = eval(build_function(sparse([1],[1], [(x+y)/k], 10,10), [x,y,k])[1])
88+
89+
@test size(f([1.,1.,2])) == (10,10)
90+
@test f([1.,1.,2])[1,1] == 1.0
91+
@test sum(f([1.,1.,2])) == 1.0

test/latexify.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ eqs = [D(u[1]) ~ p[3]*(u[2]-u[1]),
4242

4343
@test latexify(eqs) ==
4444
raw"\begin{align}
45-
\frac{du_{1(t)}}{dt} =& p_{3} \left( \mathrm{u_2}\left( t \right) - \mathrm{u_1}\left( t \right) \right) \\
46-
0 =& \frac{p_{2} p_{3} \mathrm{u_1}\left( t \right) \left( p_{1} - \mathrm{u_1}\left( t \right) \right)}{10} - \mathrm{u_2}\left( t \right) \\
47-
\frac{du_{3(t)}}{dt} =& \mathrm{u_1}\left( t \right) \left( \mathrm{u_2}\left( t \right) \right)^{\frac{2}{3}} - p_{3} \mathrm{u_3}\left( t \right)
45+
\frac{du{_1}(t)}{dt} =& p{_3} \left( \mathrm{u{_2}}\left( t \right) - \mathrm{u{_1}}\left( t \right) \right) \\
46+
0 =& \frac{p{_2} p{_3} \mathrm{u{_1}}\left( t \right) \left( p{_1} - \mathrm{u{_1}}\left( t \right) \right)}{10} - \mathrm{u{_2}}\left( t \right) \\
47+
\frac{du{_3}(t)}{dt} =& \mathrm{u{_1}}\left( t \right) \left( \mathrm{u{_2}}\left( t \right) \right)^{\frac{2}{3}} - p{_3} \mathrm{u{_3}}\left( t \right)
4848
\end{align}
4949
"
5050

@@ -56,8 +56,8 @@ sys = ODESystem(eqs)
5656

5757
@test latexify(eqs) ==
5858
raw"\begin{align}
59-
\frac{du_{1(t)}}{dt} =& p_{3} \left( \mathrm{u_2}\left( t \right) - \mathrm{u_1}\left( t \right) \right) \\
60-
\frac{du_{2(t)}}{dt} =& \frac{p_{2} p_{3} \mathrm{u_1}\left( t \right) \left( p_{1} - \mathrm{u_1}\left( t \right) \right)}{10} - \mathrm{u_2}\left( t \right) \\
61-
\frac{du_{3(t)}}{dt} =& \mathrm{u_1}\left( t \right) \left( \mathrm{u_2}\left( t \right) \right)^{\frac{2}{3}} - p_{3} \mathrm{u_3}\left( t \right)
59+
\frac{du{_1}(t)}{dt} =& p{_3} \left( \mathrm{u{_2}}\left( t \right) - \mathrm{u{_1}}\left( t \right) \right) \\
60+
\frac{du{_2}(t)}{dt} =& \frac{p{_2} p{_3} \mathrm{u{_1}}\left( t \right) \left( p{_1} - \mathrm{u{_1}}\left( t \right) \right)}{10} - \mathrm{u{_2}}\left( t \right) \\
61+
\frac{du{_3}(t)}{dt} =& \mathrm{u{_1}}\left( t \right) \left( \mathrm{u{_2}}\left( t \right) \right)^{\frac{2}{3}} - p{_3} \mathrm{u{_3}}\left( t \right)
6262
\end{align}
6363
"

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ using SafeTestsets, Test
2828
@safetestset "Depdendency Graph Test" begin include("dep_graphs.jl") end
2929
@safetestset "Function Registration Test" begin include("function_registration.jl") end
3030
@safetestset "Array of Array Test" begin include("build_function_arrayofarray.jl") end
31-
@safetestset "Latexify recipes Test" begin include("latexify.jl") end
3231
@testset "Distributed Test" begin include("distributed.jl") end
3332
@safetestset "Variable Utils Test" begin include("variable_utils.jl") end
3433
println("Last test requires gcc available in the path!")
3534
@safetestset "C Compilation Test" begin include("ccompile.jl") end
35+
@safetestset "Latexify recipes Test" begin include("latexify.jl") end

0 commit comments

Comments
 (0)