Skip to content

Commit 30dc99a

Browse files
committed
sde build_function
1 parent 873400a commit 30dc99a

File tree

5 files changed

+33
-27
lines changed

5 files changed

+33
-27
lines changed

src/build_function.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
224224
skipzeros = outputidxs===nothing,
225225
fillzeros = skipzeros && !(typeof(rhss)<:SparseMatrixCSC),
226226
parallel=SerialForm(), kwargs...)
227+
conv = conv rm_calls_with_iv
227228
if multithread isa Bool
228229
@warn("multithraded is deprecated for the parallel argument. See the documentation.")
229230
parallel = multithread ? MultithreadedForm() : SerialForm()
@@ -466,6 +467,10 @@ function vars_to_pairs(name,vs)
466467
[term_to_symbol(value(vs))], [name]
467468
end
468469

470+
function rm_calls_with_iv(expr)
471+
Rewriters.Prewalk(Rewriters.Chain([@rule((~f::(x->x isa Sym))(~t::(x->x isa Sym)) => Sym{symtype((~f)(~t))}((term_to_symbol(~f))))]))(value(expr))
472+
end
473+
469474
get_varnumber(varop::Operation,vars::Vector{Operation}) = findfirst(x->isequal(x,varop),vars)
470475
get_varnumber(varop::Operation,vars::Vector{<:Variable}) = findfirst(x->isequal(x,varop.op),vars)
471476

@@ -486,7 +491,7 @@ end
486491

487492
function numbered_expr(de::ModelingToolkit.Equation,args...;varordering = args[1],
488493
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)],offset=0)
489-
i = findfirst(x->isequal(x isa Variable ? term_to_symbol(x) : term_to_symbol(x.op),term_to_symbol(var_from_nested_derivative(de.lhs)[1])),varordering)
494+
i = findfirst(x->isequal(x isa Sym ? term_to_symbol(x) : term_to_symbol(x.op),term_to_symbol(var_from_nested_derivative(de.lhs)[1])),varordering)
490495
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,args...;offset=offset,
491496
varordering = varordering,
492497
lhsname = lhsname,

src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,4 @@ function (f::AbstractSysToExpr)(O::Operation)
272272
return build_expr(:call, Any[O.op; f.(O.args)])
273273
end
274274
(f::AbstractSysToExpr)(x) = toexpr(x)
275+

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,17 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem)
3939
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
4040
return (prob.f.sys, prob.f.sys.states, prob.f.sys.ps)
4141
@parameters t
42-
vars = reshape([Variable(:x, i)(t) for i in eachindex(prob.u0)],size(prob.u0))
42+
var(x, i) = Sym{FnType{Tuple{symtype(t)}, Number}}(nameof(Variable(:x, i)))
43+
vars = reshape([var(:x, i)(t) for i in eachindex(prob.u0)],size(prob.u0))
4344
params = prob.p isa DiffEqBase.NullParameters ? [] :
44-
reshape([Variable(,i)() for i in eachindex(prob.p)],size(prob.p))
45+
reshape([Variable(,i) for i in eachindex(prob.p)],size(prob.p))
4546
@derivatives D'~t
4647

4748
rhs = [D(var) for var in vars]
4849

4950
if DiffEqBase.isinplace(prob)
5051
lhs = similar(vars, Any)
52+
5153
prob.f(lhs, vars, params, t)
5254

5355
if DiffEqBase.is_diagonal_noise(prob)

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ end
8181
var_from_nested_derivative(x, i=0) = (missing, missing)
8282
var_from_nested_derivative(x::Term,i=0) = x.op isa Differential ? var_from_nested_derivative(x.args[1],i+1) : (x,i)
8383

84-
iv_from_nested_derivative(x) = x.op isa Differential ? iv_from_nested_derivative(x.args[1]) : x.args[1]
85-
iv_from_nested_derivative(x::Constant) = missing
84+
iv_from_nested_derivative(x::Term) = x.op isa Differential ? iv_from_nested_derivative(x.args[1]) : x.args[1]
85+
iv_from_nested_derivative(x) = missing
8686

8787
vars(exprs::Term) = vars([exprs])
8888
vars(exprs) = foldl(vars!, exprs; init = Set())

src/systems/diffeqs/sdesystem.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,35 +30,35 @@ struct SDESystem <: AbstractODESystem
3030
"""The expressions defining the drift term."""
3131
eqs::Vector{Equation}
3232
"""The expressions defining the diffusion term."""
33-
noiseeqs::AbstractArray{Operation}
33+
noiseeqs::AbstractArray
3434
"""Independent variable."""
35-
iv::Variable
35+
iv::Sym
3636
"""Dependent (state) variables."""
37-
states::Vector{Variable}
37+
states::Vector
3838
"""Parameter variables."""
39-
ps::Vector{Variable}
40-
pins::Vector{Variable}
41-
observed::Vector{Equation}
39+
ps::Vector
40+
pins::Vector
41+
observed::Vector
4242
"""
4343
Time-derivative matrix. Note: this field will not be defined until
4444
[`calculate_tgrad`](@ref) is called on the system.
4545
"""
46-
tgrad::RefValue{Vector{Expression}}
46+
tgrad::RefValue
4747
"""
4848
Jacobian matrix. Note: this field will not be defined until
4949
[`calculate_jacobian`](@ref) is called on the system.
5050
"""
51-
jac::RefValue{Matrix{Expression}}
51+
jac::RefValue
5252
"""
5353
`Wfact` matrix. Note: this field will not be defined until
5454
[`generate_factorized_W`](@ref) is called on the system.
5555
"""
56-
Wfact::RefValue{Matrix{Expression}}
56+
Wfact::RefValue
5757
"""
5858
`Wfact_t` matrix. Note: this field will not be defined until
5959
[`generate_factorized_W`](@ref) is called on the system.
6060
"""
61-
Wfact_t::RefValue{Matrix{Expression}}
61+
Wfact_t::RefValue
6262
"""
6363
Name: the name of the system
6464
"""
@@ -70,24 +70,22 @@ struct SDESystem <: AbstractODESystem
7070
end
7171

7272
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
73-
pins = Variable[],
74-
observed = Operation[],
73+
pins = [],
74+
observed = [],
7575
systems = SDESystem[],
7676
name = gensym(:SDESystem))
77-
iv′ = convert(Variable,iv)
78-
dvs′ = convert.(Variable,dvs)
79-
ps′ = convert.(Variable,ps)
80-
tgrad = RefValue(Vector{Expression}(undef, 0))
81-
jac = RefValue(Matrix{Expression}(undef, 0, 0))
82-
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
83-
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
77+
iv′ = value(iv)
78+
dvs′ = value.(dvs)
79+
ps′ = value.(ps)
80+
tgrad = RefValue(Vector{Num}(undef, 0))
81+
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
82+
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
83+
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
8484
SDESystem(deqs, neqs, iv′, dvs′, ps′, pins, observed, tgrad, jac, Wfact, Wfact_t, name, systems)
8585
end
8686

8787
function generate_diffusion_function(sys::SDESystem, dvs = sys.states, ps = sys.ps; kwargs...)
88-
dvs′ = convert.(Variable,dvs)
89-
ps′ = convert.(Variable,ps)
90-
return build_function(sys.noiseeqs, dvs′, ps′, sys.iv;
88+
return build_function(sys.noiseeqs, dvs, ps, sys.iv;
9189
conv = ODEToExpr(sys),kwargs...)
9290
end
9391

0 commit comments

Comments
 (0)