Skip to content

Commit 7a5b376

Browse files
generalize numbered exprs
1 parent 979df0d commit 7a5b376

File tree

2 files changed

+34
-44
lines changed

2 files changed

+34
-44
lines changed

src/build_function.jl

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -430,55 +430,44 @@ function numbered_expr(O::Equation,args...;kwargs...)
430430
:($(numbered_expr(O.lhs,args...;kwargs...)) = $(numbered_expr(O.rhs,args...;kwargs...)))
431431
end
432432

433-
function numbered_expr(O::Operation,vars,parameters;offset = 0,
434-
derivname=:du,
435-
varname=:u,paramname=:p)
436-
if isa(O.op, ModelingToolkit.Differential)
437-
varop = O.args[1]
438-
i = get_varnumber(varop,vars)
439-
return :($derivname[$(i+offset)])
440-
elseif isa(O.op, ModelingToolkit.Variable)
441-
i = get_varnumber(O,vars)
442-
if i == nothing
443-
i = get_varnumber(O,parameters)
444-
return :($paramname[$(i+offset)])
445-
else
446-
return :($varname[$(i+offset)])
447-
end
433+
function numbered_expr(O::Operation,args...;varordering = args[1],offset = 0,
434+
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)])
435+
if isa(O.op, ModelingToolkit.Variable)
436+
for j in 1:length(args)
437+
i = get_varnumber(O,args[j])
438+
if i !== nothing
439+
return :($(rhsnames[j])[$(i+offset)])
440+
end
441+
end
448442
end
449443
return Expr(:call, Symbol(O.op),
450-
[numbered_expr(x,vars,parameters;offset=offset,derivname=derivname,
451-
varname=varname,paramname=paramname) for x in O.args]...)
444+
[numbered_expr(x,args...;offset=offset,lhsname=lhsname,
445+
rhsnames=rhsnames,varordering=varordering) for x in O.args]...)
452446
end
453447

454-
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{<:Variable},parameters;
455-
derivname=:du,varname=:u,paramname=:p,offset=0)
456-
i = findfirst(x->isequal(x.name,var_from_nested_derivative(de.lhs)[1].name),vars)
457-
:($derivname[$(i+offset)] = $(numbered_expr(de.rhs,vars,parameters;offset=offset,
458-
derivname=derivname,
459-
varname=varname,paramname=paramname)))
460-
end
461-
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{Operation},parameters;
462-
derivname=:du,varname=:u,paramname=:p,offset=0)
463-
i = findfirst(x->isequal(x.op.name,var_from_nested_derivative(de.lhs)[1].name),vars)
464-
:($derivname[$(i+offset)] = $(numbered_expr(de.rhs,vars,parameters;offset=offset,
465-
derivname=derivname,
466-
varname=varname,paramname=paramname)))
448+
function numbered_expr(de::ModelingToolkit.Equation,args...;varordering = args[1],
449+
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)],offset=0)
450+
i = findfirst(x->isequal(x isa Variable ? x.name : x.op.name,var_from_nested_derivative(de.lhs)[1].name),varordering)
451+
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,args...;offset=offset,
452+
varordering = varordering,
453+
lhsname = lhsname,
454+
rhsnames = rhsnames)))
467455
end
468456
numbered_expr(c::ModelingToolkit.Constant,args...;kwargs...) = c.value
469457

470458
function _build_function(target::StanTarget, eqs, vs, ps, iv,
471459
conv = simplified_expr, expression = Val{true};
472-
fname = :diffeqf, derivname=:internal_var___du,
460+
fname = :diffeqf, lhsname=:internal_var___du,
473461
varname=:internal_var___u,paramname=:internal_var___p)
474-
differential_equation = string(join([numbered_expr(eq,vs,ps,derivname=derivname,
475-
varname=varname,paramname=paramname) for
462+
rhsnames=[varname,paramname]
463+
differential_equation = string(join([numbered_expr(eq,vs,ps,lhsname=lhsname,
464+
rhsnames=rhsnames) for
476465
(i, eq) enumerate(eqs)],";\n "),";")
477466
"""
478-
real[] $fname(real $iv,real[] $varname,real[] $paramname,real[] x_r,int[] x_i) {
479-
real $derivname[$(length(eqs))];
467+
real[] $fname(real $iv,real[] $(rhsnames[1]),real[] $(rhsnames[2]),real[] x_r,int[] x_i) {
468+
real $lhsname[$(length(eqs))];
480469
$differential_equation
481-
return $derivname;
470+
return $lhsname;
482471
}
483472
"""
484473
end
@@ -487,8 +476,8 @@ function _build_function(target::CTarget, eqs, vs, ps, iv;
487476
conv = simplified_expr, expression = Val{true},
488477
fname = :diffeqf, derivname=:internal_var___du,
489478
varname=:internal_var___u,paramname=:internal_var___p)
490-
differential_equation = string(join([numbered_expr(eq,vs,ps,derivname=derivname,
491-
varname=varname,paramname=paramname,offset=-1) for
479+
differential_equation = string(join([numbered_expr(eq,vs,ps,lhsname=derivname,
480+
rhsnames=[varname,paramname],offset=-1) for
492481
(i, eq) enumerate(eqs)],";\n "),";")
493482
"""
494483
void $fname(double* $derivname, double* $varname, double* $paramname, double $iv) {
@@ -501,13 +490,14 @@ function _build_function(target::MATLABTarget, eqs, vs, ps, iv;
501490
conv = simplified_expr, expression = Val{true},
502491
fname = :diffeqf, derivname=:internal_var___du,
503492
varname=:internal_var___u,paramname=:internal_var___p)
504-
matstr = join([numbered_expr(eq.rhs,vs,ps,derivname=derivname,
505-
varname=varname,paramname=paramname) for
493+
rhsnames=[varname,paramname]
494+
matstr = join([numbered_expr(eq.rhs,vs,ps,lhsname=derivname,
495+
rhsnames=rhsnames) for
506496
(i, eq) enumerate(eqs)],"; ")
507497

508498
matstr = replace(matstr,"["=>"(")
509499
matstr = replace(matstr,"]"=>")")
510-
matstr = "$fname = @(t,$varname) ["*matstr*"];"
500+
matstr = "$fname = @(t,$(rhsnames[1])) ["*matstr*"];"
511501
matstr
512502
end
513503

test/build_targets.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ eqs = [D(x) ~ a*x - x*y,
1616

1717
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.CTarget()) ==
1818
"""
19-
void diffeqf(double* internal_var___du, double* internal_var___u, double* internal_var___p, t) {
20-
internal_var___du[1] = internal_var___p[1] * internal_var___u[1] - internal_var___u[1] * internal_var___u[2];
21-
internal_var___du[2] = -3 * internal_var___u[2] + internal_var___u[1] * internal_var___u[2];
19+
void diffeqf(double* internal_var___du, double* internal_var___u, double* internal_var___p, double t) {
20+
internal_var___du[0] = internal_var___p[0] * internal_var___u[0] - internal_var___u[0] * internal_var___u[1];
21+
internal_var___du[1] = -3 * internal_var___u[1] + internal_var___u[0] * internal_var___u[1];
2222
}
2323
"""
2424

0 commit comments

Comments
 (0)