Skip to content

Commit 70d4ee1

Browse files
Merge pull request #19 from gabrielgellner/master
Add 2 argument method. Fixes #18
2 parents 432ea07 + 71e6c64 commit 70d4ee1

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

src/ode_def_opts.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,10 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;M=
238238
overloadex = :(((p::$name))(t::Number,u,params,du) = $pex) |> esc
239239
push!(exprs,overloadex)
240240

241+
# Add a method which allocates the `du` and returns it instead of being inplace
242+
overloadex = :(((p::$name))(t::Number,u) = (du=similar(u); p(t,u,du); du)) |> esc
243+
push!(exprs,overloadex)
244+
241245
# Value Dispatches for the Parameter Derivatives
242246
if pderiv_exists
243247
for i in 1:length(params)

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ iJ= zeros(2,2)
6262
iW= zeros(2,2)
6363
f(t,u,du)
6464
@test du == [-3.0,-3.0]
65+
@test du == f(t,u)
6566
f_t(t,u,du)
6667
@test du == [2.0,-3.0]
6768
f_t2(t,u,du)
@@ -119,6 +120,7 @@ g = LotkaVolterra(a=1.0,b=2.0)
119120
@test g.a * u[1] - g.b * u[1]*u[2] == -10.0
120121
g(t,u,du)
121122
@test du == [-10.0,-3.0]
123+
@test du == g(t,u)
122124
h = LotkaVolterra2(1.0,2.0)
123125
h(t,u,du)
124126
@test du == [-10.0,-3.0]
@@ -143,6 +145,7 @@ NJ = @ode_def NoJacTest begin
143145
end a=>1.5 b=>1 c=3 d=4
144146
NJ(t,u,du)
145147
@test du == [-3.0;-3*3.0 + erf(2.0*3.0/4)]
148+
@test du == NJ(t,u)
146149
@test has_jac(NJ) == true
147150
println(NJ.Jex)
148151
#NJ(Val{:jac},t,u,J)

0 commit comments

Comments
 (0)