Skip to content

Commit 73728bc

Browse files
add GPU neural ODE test
1 parent fe15920 commit 73728bc

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

test/downstream/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
23
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
4+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

test/downstream/gpu_neural_ode.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using DiffEqFlux, OrdinaryDiffEq, Flux, CUDA
2+
CUDA.allowscalar(false) # Makes sure no slow operations are occuring
3+
4+
# Generate Data
5+
u0 = Float32[2.0; 0.0]
6+
datasize = 30
7+
tspan = (0.0f0, 1.5f0)
8+
tsteps = range(tspan[1], tspan[2], length = datasize)
9+
function trueODEfunc(du, u, p, t)
10+
true_A = [-0.1 2.0; -2.0 -0.1]
11+
du .= ((u.^3)'true_A)'
12+
end
13+
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
14+
# Make the data into a GPU-based array if the user has a GPU
15+
ode_data = gpu(solve(prob_trueode, Tsit5(), saveat = tsteps))
16+
17+
18+
dudt2 = FastChain((x, p) -> x.^3,
19+
FastDense(2, 50, tanh),
20+
FastDense(50, 2))
21+
u0 = Float32[2.0; 0.0] |> gpu
22+
p = initial_params(dudt2) |> gpu
23+
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
24+
25+
function predict_neuralode(p)
26+
gpu(prob_neuralode(u0,p))
27+
end
28+
function loss_neuralode(p)
29+
pred = predict_neuralode(p)
30+
loss = sum(abs2, ode_data .- pred)
31+
return loss, pred
32+
end
33+
# Callback function to observe training
34+
list_plots = []
35+
iter = 0
36+
callback = function (p, l, pred; doplot = false)
37+
global list_plots, iter
38+
if iter == 0
39+
list_plots = []
40+
end
41+
iter += 1
42+
display(l)
43+
# plot current prediction against data
44+
plt = scatter(tsteps, Array(ode_data[1,:]), label = "data")
45+
scatter!(plt, tsteps, Array(pred[1,:]), label = "prediction")
46+
push!(list_plots, plt)
47+
if doplot
48+
display(plot(plt))
49+
end
50+
return false
51+
end
52+
result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, p,
53+
ADAM(0.05), cb = callback,
54+
maxiters = 300)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ end
2121
if !is_APPVEYOR && GROUP == "Downstream"
2222
activate_downstream_env()
2323
Pkg.test("DiffEqFlux")
24+
@safetestset "DiffEqFlux GPU" begin include("gpu_neural_ode.jl") end
2425
end
2526
end

0 commit comments

Comments
 (0)