Skip to content

Commit 9dd26d7

Browse files
Merge pull request #74 from SciML/gpu
Add downstream tests and fix GPU
2 parents a571acc + b366a6b commit 9dd26d7

File tree

7 files changed

+108
-19
lines changed

7 files changed

+108
-19
lines changed

.travis.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ os:
1010
julia:
1111
- 1
1212
# - nightly
13+
env:
14+
- GROUP=Core
15+
- GROUP=Downstream
1316
notifications:
1417
email: false
1518
jobs:

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,12 @@ CMAEvolutionStrategy = "8d3b24bd-414e-49e0-94fb-163cc3a3e411"
4848
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
4949
Evolutionary = "86b6b26d-c046-49b6-aa0b-5f0f74682bd6"
5050
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
51-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5251
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
52+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5353
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
54+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
5455
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5556
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5657

5758
[targets]
58-
test = ["BlackBoxOptim", "Evolutionary", "DiffEqFlux", "IterTools", "OrdinaryDiffEq", "NLopt", "CMAEvolutionStrategy", "Plots" ,"SafeTestsets", "Test"]
59+
test = ["BlackBoxOptim", "Evolutionary", "DiffEqFlux", "IterTools", "OrdinaryDiffEq", "NLopt", "CMAEvolutionStrategy", "Plots", "Pkg", "SafeTestsets", "Test"]

src/function.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function instantiate_function(f, x, ::AbstractADType, p, num_cons = 0)
3737
cons_j,cons_h)
3838
end
3939

40-
function instantiate_function(f, x, ::AutoForwardDiff{_chunksize}, p, num_cons = 0) where _chunksize
40+
function instantiate_function(f::OptimizationFunction{true}, x, ::AutoForwardDiff{_chunksize}, p, num_cons = 0) where _chunksize
4141

4242
chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize
4343

src/solve.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ function DiffEqBase.solve(prob::OptimizationProblem, opt, args...;kwargs...)
1212
__solve(prob, opt, args...; kwargs...)
1313
end
1414

15+
#=
1516
function update!(x::AbstractArray, x̄::AbstractArray{<:ForwardDiff.Dual})
1617
x .-= x̄
1718
end
@@ -31,6 +32,7 @@ end
3132
function update!(opt, xs::Flux.Zygote.Params, gs)
3233
update!(opt, xs[1], gs)
3334
end
35+
=#
3436

3537
maybe_with_logger(f, logger) = logger === nothing ? f() : Logging.with_logger(f, logger)
3638

@@ -62,7 +64,10 @@ macro withprogress(progress, exprs...)
6264
end |> esc
6365
end
6466

65-
function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args...) -> (false), maxiters::Number = 1000, progress = true, save_best = true, kwargs...)
67+
function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
68+
cb = (args...) -> (false), maxiters::Number = 1000,
69+
progress = true, save_best = true, kwargs...)
70+
6671
if maxiters <= 0.0
6772
error("The number of maxiters has to be a non-negative and non-zero number.")
6873
else
@@ -76,7 +81,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args.
7681

7782
if data != DEFAULT_DATA
7883
maxiters = length(data)
79-
else
84+
else
8085
data = take(data, maxiters)
8186
end
8287

@@ -90,8 +95,10 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args.
9095

9196
@withprogress progress name="Training" begin
9297
for (i,d) in enumerate(data)
93-
gs = prob.f.adtype isa AutoFiniteDiff ? Array{Number}(undef,length(θ)) : DiffResults.GradientResult(θ)
94-
f.grad(gs, θ, d...)
98+
gs = Flux.Zygote.gradient(ps) do
99+
x = prob.f(θ,prob.p, d...)
100+
first(x)
101+
end
95102
x = f.f(θ, prob.p, d...)
96103
cb_call = cb(θ, x...)
97104
if !(typeof(cb_call) <: Bool)
@@ -101,7 +108,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args.
101108
end
102109
msg = @sprintf("loss: %.3g", x[1])
103110
progress && ProgressLogging.@logprogress msg i/maxiters
104-
update!(opt, ps, prob.f.adtype isa AutoFiniteDiff ? gs : DiffResults.gradient(gs))
111+
Flux.update!(opt, ps, gs)
105112

106113
if save_best
107114
if first(x) < first(min_err) #found a better solution
@@ -215,7 +222,7 @@ function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN
215222
if !(typeof(cb_call) <: Bool)
216223
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
217224
end
218-
cur, state = iterate(data, state)
225+
cur, state = iterate(data, state)
219226
cb_call
220227
end
221228

test/downstream/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
4+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

test/downstream/gpu_neural_ode.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using DiffEqFlux, OrdinaryDiffEq, Flux, CUDA
2+
CUDA.allowscalar(false) # Makes sure no slow operations are occurring
3+
4+
# Generate Data
5+
u0 = Float32[2.0; 0.0]
6+
datasize = 30
7+
tspan = (0.0f0, 1.5f0)
8+
tsteps = range(tspan[1], tspan[2], length = datasize)
9+
function trueODEfunc(du, u, p, t)
10+
true_A = [-0.1 2.0; -2.0 -0.1]
11+
du .= ((u.^3)'true_A)'
12+
end
13+
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
14+
# Make the data into a GPU-based array if the user has a GPU
15+
ode_data = gpu(solve(prob_trueode, Tsit5(), saveat = tsteps))
16+
17+
18+
dudt2 = FastChain((x, p) -> x.^3,
19+
FastDense(2, 50, tanh),
20+
FastDense(50, 2))
21+
u0 = Float32[2.0; 0.0] |> gpu
22+
p = initial_params(dudt2) |> gpu
23+
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
24+
25+
function predict_neuralode(p)
26+
gpu(prob_neuralode(u0,p))
27+
end
28+
function loss_neuralode(p)
29+
pred = predict_neuralode(p)
30+
loss = sum(abs2, ode_data .- pred)
31+
return loss, pred
32+
end
33+
# Callback function to observe training
34+
list_plots = []
35+
iter = 0
36+
callback = function (p, l, pred; doplot = false)
37+
global list_plots, iter
38+
if iter == 0
39+
list_plots = []
40+
end
41+
iter += 1
42+
display(l)
43+
# plot current prediction against data
44+
plt = scatter(tsteps, Array(ode_data[1,:]), label = "data")
45+
scatter!(plt, tsteps, Array(pred[1,:]), label = "prediction")
46+
push!(list_plots, plt)
47+
if doplot
48+
display(plot(plt))
49+
end
50+
return false
51+
end
52+
result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, p,
53+
ADAM(0.05), cb = callback,
54+
maxiters = 300)

test/runtests.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,30 @@
1-
using SafeTestsets
2-
3-
println("Rosenbrock Tests")
4-
@safetestset "Rosenbrock" begin include("rosenbrock.jl") end
5-
println("AD Tests")
6-
@safetestset "AD Tests" begin include("ADtests.jl") end
7-
println("Mini batching Tests")
8-
@safetestset "Mini batching" begin include("minibatch.jl") end
9-
println("DiffEqFlux Tests")
10-
@safetestset "DiffEqFlux" begin include("diffeqfluxtests.jl") end
1+
using SafeTestsets, Pkg
2+
3+
const GROUP = get(ENV, "GROUP", "All")
4+
const is_APPVEYOR = Sys.iswindows() && haskey(ENV,"APPVEYOR")
5+
const is_TRAVIS = haskey(ENV,"TRAVIS")
6+
7+
function activate_downstream_env()
8+
Pkg.activate("downstream")
9+
Pkg.develop(PackageSpec(path=dirname(@__DIR__)))
10+
Pkg.instantiate()
11+
end
12+
13+
@time begin
14+
if GROUP == "All" || GROUP == "Core"
15+
@safetestset "Rosenbrock" begin include("rosenbrock.jl") end
16+
@safetestset "AD Tests" begin include("ADtests.jl") end
17+
@safetestset "Mini batching" begin include("minibatch.jl") end
18+
@safetestset "DiffEqFlux" begin include("diffeqfluxtests.jl") end
19+
end
20+
21+
if !is_APPVEYOR && GROUP == "Downstream"
22+
activate_downstream_env()
23+
Pkg.test("DiffEqFlux")
24+
end
25+
26+
if !is_APPVEYOR && GROUP == "GPU"
27+
activate_downstream_env()
28+
@safetestset "DiffEqFlux GPU" begin include("downstream/gpu_neural_ode.jl") end
29+
end
30+
end

0 commit comments

Comments
 (0)