Skip to content

Commit 36ffe5b

Browse files
Fix function generation
1 parent 4b6e441 commit 36ffe5b

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,6 @@ function calculate_jacobian(sys::ODESystem)
9494
return jac
9595
end
9696

97-
function generate_jacobian(sys::ODESystem; version::FunctionVersion = ArrayFunction)
98-
jac = calculate_jacobian(sys)
99-
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,); version = version)
100-
end
101-
10297
struct ODEToExpr
10398
sys::ODESystem
10499
end
@@ -113,6 +108,11 @@ function (f::ODEToExpr)(O::Operation)
113108
end
114109
(f::ODEToExpr)(x) = convert(Expr, x)
115110

111+
function generate_jacobian(sys::ODESystem; version::FunctionVersion = ArrayFunction)
112+
jac = calculate_jacobian(sys)
113+
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,), ODEToExpr(sys); version = version)
114+
end
115+
116116
function generate_function(sys::ODESystem, dvs, ps; version::FunctionVersion = ArrayFunction)
117117
rhss = [deq.rhs for deq sys.eqs]
118118
dvs′ = [clean(dv) for dv dvs]

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,26 @@ end
4545

4646
function generate_jacobian(sys::NonlinearSystem; version::FunctionVersion = ArrayFunction)
4747
jac = calculate_jacobian(sys)
48-
return build_function(jac, clean.(sys.vs), sys.ps; version = version)
48+
return build_function(jac, clean.(sys.vs), sys.ps, (), NLSysToExpr(sys); version = version)
4949
end
5050

51+
struct NLSysToExpr
52+
sys::NonlinearSystem
53+
end
54+
function (f::NLSysToExpr)(O::Operation)
55+
any(isequal(O), f.sys.vs) && return O.op.name # variables
56+
if isa(O.op, Variable)
57+
isempty(O.args) && return O.op.name # 0-ary parameters
58+
return build_expr(:call, Any[O.op.name; f.(O.args)])
59+
end
60+
return build_expr(:call, Any[O.op; f.(O.args)])
61+
end
62+
(f::NLSysToExpr)(x) = convert(Expr, x)
63+
64+
5165
function generate_function(sys::NonlinearSystem, vs, ps; version::FunctionVersion = ArrayFunction)
5266
rhss = [eq.rhs for eq sys.eqs]
5367
vs′ = [clean(v) for v vs]
5468
ps′ = [clean(p) for p ps]
55-
return build_function(rhss, vs′, ps′; version = version)
69+
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys); version = version)
5670
end

test/system_construction.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,12 @@ eqs = [0 ~ σ*(y-x),
119119
0 ~ x*y - β*z]
120120
ns = NonlinearSystem(eqs, [x,y,z])
121121
test_nlsys_inference("standard", ns, (x, y, z), (σ, ρ, β))
122-
123-
generate_function(ns, [x,y,z], [σ,ρ,β])
122+
@test begin
123+
f = eval(generate_function(ns, [x,y,z], [σ,ρ,β]))
124+
du = [0.0, 0.0, 0.0]
125+
f(du, [1,2,3], [1,2,3])
126+
du [1, -3, -7]
127+
end
124128

125129
@derivatives D'~t
126130
@parameters A() B() C()

0 commit comments

Comments
 (0)