Skip to content

Commit ddc68d9

Browse files
add an ensemble inference test
1 parent dc2bd47 commit ddc68d9

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

test/downstream/inference.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,38 @@ tspan = (0.0,1.0)
99
prob = ODEProblem(lorenz,u0,tspan)
1010
sol = solve(prob,Tsit5())
1111
@inferred solve(prob,Tsit5())
12+
13+
function f(du,u,p,t)
14+
du[1] = p.a
15+
du[2] = p.b
16+
end
17+
18+
const alg = Tsit5()
19+
20+
function solve_ode(f::F, p::P) where {F,P}
21+
22+
tspan = (0., 1.0)
23+
Δt = tspan[2] - tspan[1]
24+
dt = 1/252
25+
nodes = Int(ceil(Δt / dt) + 1)
26+
t = T = [tspan[1] + (i - 1) * dt for i = 1:nodes]
27+
28+
# if I do not set {true}, prob type Any...
29+
prob = ODEProblem{true}(f, [0., 0.], tspan, p)
30+
# prob = ODEProblem(f, [0., 0.], tspan, p)
31+
32+
prob_func = (prob, i, repeat) -> begin
33+
remake(prob, tspan = (T[i + 1], t[1]))
34+
end
35+
36+
# ensemble problem
37+
odes = EnsembleProblem(prob, prob_func = prob_func)
38+
39+
sol = OrdinaryDiffEq.solve(
40+
odes, OrdinaryDiffEq.Tsit5(), OrdinaryDiffEq.EnsembleThreads(),
41+
trajectories = nodes - 1, saveat = -dt
42+
)
43+
44+
return sol
45+
end
46+
@inferred solve_ode(f, (a = 1, b = 1))

0 commit comments

Comments
 (0)