Skip to content

Commit af4fe17

Browse files
committed
build_targets
1 parent b8b0941 commit af4fe17

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

src/build_function.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -470,33 +470,36 @@ function rm_calls_with_iv(expr)
470470
Rewriters.Prewalk(Rewriters.Chain([@rule((~f::(x->x isa Sym))(~t::(x->x isa Sym)) => Sym{symtype((~f)(~t))}((term_to_symbol(~f))))]))(value(expr))
471471
end
472472

473-
get_varnumber(varop::Operation,vars::Vector{Operation}) = findfirst(x->isequal(x,varop),vars)
474-
get_varnumber(varop::Operation,vars::Vector{<:Variable}) = findfirst(x->isequal(x,varop.op),vars)
473+
get_varnumber(varop, vars::Vector) = findfirst(x->isequal(x,varop),vars)
475474

476-
function numbered_expr(O::Operation,args...;varordering = args[1],offset = 0,
475+
function numbered_expr(O::Union{Term,Sym},args...;varordering = args[1],offset = 0,
477476
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)])
478-
if isa(O.op, ModelingToolkit.Variable)
477+
O = value(O)
478+
if O isa Sym || isa(O.op, Sym)
479479
for j in 1:length(args)
480480
i = get_varnumber(O,args[j])
481481
if i !== nothing
482482
return :($(rhsnames[j])[$(i+offset)])
483483
end
484484
end
485485
end
486-
return Expr(:call, Symbol(O.op),
486+
return Expr(:call, O isa Sym ? nameof(O) : Symbol(O.op),
487487
[numbered_expr(x,args...;offset=offset,lhsname=lhsname,
488488
rhsnames=rhsnames,varordering=varordering) for x in O.args]...)
489489
end
490490

491491
function numbered_expr(de::ModelingToolkit.Equation,args...;varordering = args[1],
492492
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)],offset=0)
493+
494+
varordering = value.(args[1])
493495
i = findfirst(x->isequal(x isa Sym ? term_to_symbol(x) : term_to_symbol(x.op),term_to_symbol(var_from_nested_derivative(de.lhs)[1])),varordering)
494496
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,args...;offset=offset,
495497
varordering = varordering,
496498
lhsname = lhsname,
497499
rhsnames = rhsnames)))
498500
end
499-
numbered_expr(c::ModelingToolkit.Constant,args...;kwargs...) = c.value
501+
numbered_expr(c,args...;kwargs...) = c
502+
numbered_expr(c::Num,args...;kwargs...) = error("Num found")
500503

501504
"""
502505
Build function target: CTarget

test/build_targets.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit, Test
44
@derivatives D'~t
55
eqs = [D(x) ~ a*x - x*y,
66
D(y) ~ -3y + x*y]
7-
@test ModelingToolkit.build_function(eqs,convert.(Variable,[x,y]),convert.(Variable,[a]),t,target = ModelingToolkit.StanTarget()) ==
7+
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.StanTarget()) ==
88
"""
99
real[] diffeqf(real t,real[] internal_var___u,real[] internal_var___p,real[] x_r,int[] x_i) {
1010
real internal_var___du[2];

0 commit comments

Comments
 (0)