Skip to content

Commit 659e4ef

Browse files
ParameterizedFunction type
1 parent 60d176e commit 659e4ef

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

src/ParameterizedFunctions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,14 @@ module ParameterizedFunctions
2929
include("dict_build.jl")
3030
include("fem.jl")
3131
include("macros.jl")
32+
include("parameterized_function_type.jl")
3233

3334
export @ode_def, @fem_def, ode_def_opts,
3435
@ode_def_bare, @ode_def_nohes, @ode_def_noinvjac, @ode_def_noinvhes,
3536
@ode_def_mm, @ode_def_nohes_mm, @ode_def_noinvjac, @ode_def_noinvhes_mm
3637

38+
export ParameterizedFunction
39+
3740
end # module
3841

3942

src/parameterized_function_type.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
type ParameterizedFunction{isinplace,F,P} <: AbstractParameterizedFunction{isinplace}
2+
f::F
3+
p::P
4+
end
5+
6+
function ParameterizedFunction(f,p)
7+
isinplace = numparameters(f)>=4
8+
ParameterizedFunction{isinplace,typeof(f),typeof(p)}(f,p)
9+
end
10+
11+
(pf::ParameterizedFunction{true,F,P}){F,P}(t,u,du) = pf.f(t,u,pf.p,du)
12+
(pf::ParameterizedFunction{true,F,P}){F,P}(t,u,params,du) = pf.f(t,u,params,du)
13+
(pf::ParameterizedFunction{false,F,P}){F,P}(t,u) = pf.f(t,u,pf.p)
14+
(pf::ParameterizedFunction{false,F,P}){F,P}(t,u,params) = pf.f(t,u,params)

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@ end
3636
du[2] = -3 * u[2] + u[1]*u[2]
3737
end
3838

39+
pf_func = function (t,u,p,du)
40+
du[1] = p[1] * u[1] - p[2] * u[1]*u[2]
41+
du[2] = -3 * u[2] + u[1]*u[2]
42+
end
43+
44+
pf = ParameterizedFunction(pf_func,[1.5,1.0])
45+
46+
pf_func2 = function (t,u,p)
47+
[p[1] * u[1] - p[2] * u[1]*u[2];-3 * u[2] + u[1]*u[2]]
48+
end
49+
50+
pf2 = ParameterizedFunction(pf_func2,[1.5,1.0])
51+
3952
println("Test Values")
4053
t = 1.0
4154
u = [2.0,3.0]
@@ -46,6 +59,9 @@ iJ= zeros(2,2)
4659
iW= zeros(2,2)
4760
f(t,u,du)
4861
@test du == [-3.0,-3.0]
62+
pf(t,u,du)
63+
@test du == [-3.0,-3.0]
64+
@test pf2(t,u) == [-3.0,-3.0]
4965

5066
println("Test t-gradient")
5167
f(Val{:tgrad},t,u,grad)

0 commit comments

Comments
 (0)