Skip to content

Commit bdcf035

Browse files
Merge pull request #155 from JuliaDiffEq/builder
Function-local safe symbolic Jacobians of numerical functions
2 parents 0c833cb + 9ab9df3 commit bdcf035

File tree

6 files changed

+115
-46
lines changed

6 files changed

+115
-46
lines changed

README.md

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,25 +118,24 @@ The `@variables` and `@parameters` macros support this with the following syntax
118118
julia> @variables x[1:3];
119119
julia> x
120120
3-element Array{Operation,1}:
121-
x[1]()
122-
x[2]()
123-
x[3]()
121+
x()
122+
x()
123+
x()
124124

125125
# support for arbitrary ranges and tensors
126126
julia> @variables y[2:3,1:5:6];
127127
julia> y
128128
2×2 Array{Operation,2}:
129-
y[2,1]() y[2,6]()
130-
y[3,1]() y[3,6]()
131-
129+
y₂̒₁() y₂̒₆()
130+
y₃̒₁() y₃̒₆()
132131

133132
# also works for dependent variables
134133
julia> @parameters t; @variables z[1:3](t);
135134
julia> z
136135
3-element Array{Operation,1}:
137-
z[1](t())
138-
z[2](t())
139-
z[3](t())
136+
z(t())
137+
z(t())
138+
z(t())
140139
```
141140
142141
## Core Principles
@@ -201,7 +200,7 @@ aliased to the given call, allowing implicit use of dependents for convenience.
201200
@parameters t α σ(..) β[1:2]
202201
@variables w(..) x(t) y() z(t, α, x)
203202

204-
expr = β[1] * x + y^α + σ(3) * (z - t) - β[2] * w(t - 1)
203+
expr = β* x + y^α + σ(3) * (z - t) - β * w(t - 1)
205204
```
206205
207206
Note that `@parameters` and `@variables` implicitly add `()` to values that

src/ModelingToolkit.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ include("variables.jl")
8484
include("operations.jl")
8585
include("differentials.jl")
8686
include("equations.jl")
87-
include("systems/diffeqs/diffeqsystem.jl")
88-
include("systems/diffeqs/first_order_transform.jl")
89-
include("systems/nonlinear/nonlinear_system.jl")
9087
include("function_registration.jl")
9188
include("simplify.jl")
9289
include("utils.jl")
90+
include("systems/diffeqs/diffeqsystem.jl")
91+
include("systems/diffeqs/first_order_transform.jl")
92+
include("systems/nonlinear/nonlinear_system.jl")
9393

9494
end # module

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,82 @@ Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `p
209209
are used to set the order of the dependent variable and parameter vectors,
210210
respectively.
211211
"""
212-
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps;
212+
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps,
213+
safe = Val{true};
213214
version = nothing,
214-
jac = false, Wfact = false) where iip
215-
expr = eval(generate_function(sys, dvs, ps))
216-
jac_expr = jac ? nothing : eval(generate_jacobian(sys, dvs, ps))
217-
Wfact_expr,Wfact_t_expr = Wfact ? (nothing,nothing) : eval.(generate_factorized_W(sys, dvs, ps))
218-
ODEFunction{iip}(eval(expr),jac=jac_expr,
219-
Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
215+
jac = false, Wfact = false) where {iip}
216+
_f = eval(generate_function(sys, dvs, ps))
217+
out_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u,p,t)
218+
out_f_safe(du,u,p,t) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u,p,t)
219+
out_f(u,p,t) = _f(u,p,t)
220+
out_f(du,u,p,t) = _f(du,u,p,t)
221+
222+
if jac
223+
@show generate_jacobian(sys, dvs, ps)
224+
_jac = eval(generate_jacobian(sys, dvs, ps))
225+
jac_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_jac,Matrix{eltype(u)},u,p,t)
226+
jac_f_safe(J,u,p,t) = ModelingToolkit.fast_invokelatest(_jac,Nothing,J,u,p,t)
227+
jac_f(u,p,t) = _jac(u,p,t)
228+
jac_f(J,u,p,t) = _jac(J,u,p,t)
229+
else
230+
jac_f_safe = nothing
231+
jac_f = nothing
232+
end
233+
234+
if Wfact
235+
_Wfact,_Wfact_t = eval.(generate_factorized_W(sys, dvs, ps))
236+
Wfact_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_Wfact,Matrix{eltype(u)},u,p,t)
237+
Wfact_f_safe(J,u,p,t) = ModelingToolkit.fast_invokelatest(_Wfact,Nothing,J,u,p,t)
238+
Wfact_f_t_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_Wfact,Matrix{eltype(u)},u,p,t)
239+
Wfact_f_t_safe(J,u,p,t) = ModelingToolkit.fast_invokelatest(_Wfact,Nothing,J,u,p,t)
240+
Wfact_f(u,p,t) = _Wfact(u,p,t)
241+
Wfact_f(J,u,p,t) = _Wfact(J,u,p,t)
242+
Wfact_f_t(u,p,t) = _Wfact_t(u,p,t)
243+
Wfact_f_t(J,u,p,t) = _Wfact_t(J,u,p,t)
244+
else
245+
Wfact_f_safe = nothing
246+
Wfact_f_t_safe = nothing
247+
Wfact_f = nothing
248+
Wfact_t_f = nothing
249+
end
250+
251+
if safe === Val{true}
252+
ODEFunction{iip}(out_f_safe,jac=jac_f_safe,
253+
Wfact = Wfact_f_safe,
254+
Wfact_t = Wfact_f_t_safe)
255+
else
256+
ODEFunction{iip}(out_f,jac=jac_f,
257+
Wfact = Wfact_f,
258+
Wfact_t = Wfact_t_f)
259+
end
220260
end
261+
221262
function DiffEqBase.ODEFunction(sys::ODESystem, args...; kwargs...)
222263
ODEFunction{true}(sys, args...; kwargs...)
223264
end
265+
266+
"""
267+
$(SIGNATURES)
268+
269+
Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
270+
"""
271+
function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
272+
@parameters t
273+
vars = [Variable(:x, i)(t) for i in eachindex(prob.u0)]
274+
params = [Variable(,i; known = true)() for i in eachindex(prob.p)]
275+
@derivatives D'~t
276+
277+
rhs = [D(var) for var in vars]
278+
279+
if DiffEqBase.isinplace(prob)
280+
lhs = similar(vars, Any)
281+
prob.f(lhs, vars, params, t)
282+
else
283+
lhs = prob.f(vars, params, t)
284+
end
285+
286+
eqs = vcat([rhs[i] ~ lhs[i] for i in eachindex(prob.u0)]...)
287+
de = ODESystem(eqs)
288+
289+
de, vars, params
290+
end

src/utils.jl

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@ function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs
5151
quote
5252
function $fname($X,u,p,$(args...))
5353
$ip_let_expr
54+
nothing
5455
end
5556
function $fname(u,p,$(args...))
5657
X = $let_expr
57-
T = $(constructor === nothing ? :(u isa StaticArrays.StaticArray ? StaticArrays.similar_type(typeof(u), eltype(X)) : x->(du=similar(u, eltype(X)); du .= x)) : constructor)
58+
T = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->(du=similar(u, eltype(X)); du .= x)) : constructor)
5859
T(X)
5960
end
6061
end
@@ -91,26 +92,10 @@ function vars!(vars, O)
9192
return vars
9293
end
9394

94-
"""
95-
$(SIGNATURES)
96-
97-
Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
98-
"""
99-
function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
100-
t, = @parameters t; vars = [Variable(Symbol(:x, i))(t) for i in eachindex(prob.u0)]; params = [Variable(Symbol(, i); known = true)() for i in eachindex(prob.p)];
101-
D, = @derivatives D'~t
102-
103-
rhs = [D(var) for var in vars]
104-
105-
if DiffEqBase.isinplace(prob)
106-
lhs = similar(vars, Any)
107-
prob.f(lhs, vars, params, t)
108-
else
109-
lhs = prob.f(vars, params, t)
110-
end
111-
112-
eqs = vcat([rhs[i] ~ lhs[i] for i in eachindex(prob.u0)]...)
113-
de = ODESystem(eqs)
114-
115-
de, vars, params
95+
@inline @generated function fast_invokelatest(f, ::Type{rt}, args...) where rt
96+
tupargs = Expr(:tuple,(a==Nothing ? Int : a for a in args)...)
97+
quote
98+
_f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($((a==Nothing ? Int : a for a in args)...))), :(:ccall)))
99+
return ccall(_f.ptr,rt,$tupargs,$((:(getindex(args,$i) === nothing ? 0 : getindex(args,$i)) for i in 1:length(args))...))
100+
end
116101
end

src/variables.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
export Variable, @variables, @parameters
22

3+
const IndexMap = Dict{Char,Char}(
4+
'0' => '',
5+
'1' => '',
6+
'2' => '',
7+
'3' => '',
8+
'4' => '',
9+
'5' => '',
10+
'6' => '',
11+
'7' => '',
12+
'8' => '',
13+
'9' => '')
14+
function map_subscripts(indices)
15+
str = string(indices)
16+
join(IndexMap[c] for c in str)
17+
end
318

419
"""
520
$(TYPEDEF)
@@ -20,7 +35,7 @@ struct Variable <: Function
2035
Variable(name; known = false) = new(name, known)
2136
end
2237
function Variable(name, indices...; known = false)
23-
var_name = Symbol("$name[$(join(indices, ","))]")
38+
var_name = Symbol("$(name)$(join(map_subscripts.(indices), "̒"))")
2439
Variable(var_name; known=known)
2540
end
2641

@@ -113,7 +128,7 @@ end
113128
function _construct_var(var_name, known, call_args)
114129
if call_args === nothing
115130
:(Variable($(Meta.quot(var_name)); known = $known)())
116-
elseif call_args[end] == :..
131+
elseif !isempty(call_args) && call_args[end] == :..
117132
:(Variable($(Meta.quot(var_name)); known = $known))
118133
else
119134
:(Variable($(Meta.quot(var_name)); known = $known)($(call_args...)))
@@ -123,7 +138,7 @@ end
123138
function _construct_var(var_name, known, call_args, ind)
124139
if call_args === nothing
125140
:(Variable($(Meta.quot(var_name)), $ind...; known = $known)())
126-
elseif call_args[end] == :..
141+
elseif !isempty(call_args) && call_args[end] == :..
127142
:(Variable($(Meta.quot(var_name)), $ind...; known = $known))
128143
else
129144
:(Variable($(Meta.quot(var_name)), $ind...; known = $known)($(call_args...)))

test/variable_parsing.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,6 @@ s1 = [Variable(:s, 1, 1; known = true)() Variable(:s, 1, 2; known = true)()
5656
@variables x[1:2](t)
5757
x1 = [Variable(:x, 1)(t), Variable(:x, 2)(t)]
5858
@test isequal(x1, x)
59+
60+
@variables a[1:11,1:2]
61+
@variables a()

0 commit comments

Comments
 (0)