Skip to content

Commit fe15920

Browse files
Fix GPU usage and add downstream testing
1 parent a571acc commit fe15920

File tree

5 files changed

+44
-17
lines changed

5 files changed

+44
-17
lines changed

.travis.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ os:
1010
julia:
1111
- 1
1212
# - nightly
13+
env:
14+
- GROUP=Core
15+
- GROUP=Downstream
1316
notifications:
1417
email: false
1518
jobs:

src/function.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function instantiate_function(f, x, ::AbstractADType, p, num_cons = 0)
3737
cons_j,cons_h)
3838
end
3939

40-
function instantiate_function(f, x, ::AutoForwardDiff{_chunksize}, p, num_cons = 0) where _chunksize
40+
function instantiate_function(f::OptimizationFunction{true}, x, ::AutoForwardDiff{_chunksize}, p, num_cons = 0) where _chunksize
4141

4242
chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize
4343

src/solve.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ function DiffEqBase.solve(prob::OptimizationProblem, opt, args...;kwargs...)
1212
__solve(prob, opt, args...; kwargs...)
1313
end
1414

15+
#=
1516
function update!(x::AbstractArray, x̄::AbstractArray{<:ForwardDiff.Dual})
1617
x .-= x̄
1718
end
@@ -31,6 +32,7 @@ end
3132
function update!(opt, xs::Flux.Zygote.Params, gs)
3233
update!(opt, xs[1], gs)
3334
end
35+
=#
3436

3537
maybe_with_logger(f, logger) = logger === nothing ? f() : Logging.with_logger(f, logger)
3638

@@ -62,7 +64,10 @@ macro withprogress(progress, exprs...)
6264
end |> esc
6365
end
6466

65-
function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args...) -> (false), maxiters::Number = 1000, progress = true, save_best = true, kwargs...)
67+
function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
68+
cb = (args...) -> (false), maxiters::Number = 1000,
69+
progress = true, save_best = true, kwargs...)
70+
6671
if maxiters <= 0.0
6772
error("The number of maxiters has to be a non-negative and non-zero number.")
6873
else
@@ -76,7 +81,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args.
7681

7782
if data != DEFAULT_DATA
7883
maxiters = length(data)
79-
else
84+
else
8085
data = take(data, maxiters)
8186
end
8287

@@ -90,8 +95,10 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args.
9095

9196
@withprogress progress name="Training" begin
9297
for (i,d) in enumerate(data)
93-
gs = prob.f.adtype isa AutoFiniteDiff ? Array{Number}(undef,length(θ)) : DiffResults.GradientResult(θ)
94-
f.grad(gs, θ, d...)
98+
gs = Flux.Zygote.gradient(ps) do
99+
x = prob.f(θ,prob.p)
100+
first(x)
101+
end
95102
x = f.f(θ, prob.p, d...)
96103
cb_call = cb(θ, x...)
97104
if !(typeof(cb_call) <: Bool)
@@ -101,7 +108,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;cb = (args.
101108
end
102109
msg = @sprintf("loss: %.3g", x[1])
103110
progress && ProgressLogging.@logprogress msg i/maxiters
104-
update!(opt, ps, prob.f.adtype isa AutoFiniteDiff ? gs : DiffResults.gradient(gs))
111+
Flux.update!(opt, ps, gs)
105112

106113
if save_best
107114
if first(x) < first(min_err) #found a better solution
@@ -215,7 +222,7 @@ function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN
215222
if !(typeof(cb_call) <: Bool)
216223
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
217224
end
218-
cur, state = iterate(data, state)
225+
cur, state = iterate(data, state)
219226
cb_call
220227
end
221228

test/downstream/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"

test/runtests.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
1-
using SafeTestsets
2-
3-
println("Rosenbrock Tests")
4-
@safetestset "Rosenbrock" begin include("rosenbrock.jl") end
5-
println("AD Tests")
6-
@safetestset "AD Tests" begin include("ADtests.jl") end
7-
println("Mini batching Tests")
8-
@safetestset "Mini batching" begin include("minibatch.jl") end
9-
println("DiffEqFlux Tests")
10-
@safetestset "DiffEqFlux" begin include("diffeqfluxtests.jl") end
1+
using SafeTestsets, Pkg
2+
3+
const GROUP = get(ENV, "GROUP", "All")
4+
const is_APPVEYOR = Sys.iswindows() && haskey(ENV,"APPVEYOR")
5+
const is_TRAVIS = haskey(ENV,"TRAVIS")
6+
7+
function activate_downstream_env()
8+
Pkg.activate("downstream")
9+
Pkg.develop(PackageSpec(path=dirname(@__DIR__)))
10+
Pkg.instantiate()
11+
end
12+
13+
@time begin
14+
if GROUP == "All" || GROUP == "Core"
15+
@safetestset "Rosenbrock" begin include("rosenbrock.jl") end
16+
@safetestset "AD Tests" begin include("ADtests.jl") end
17+
@safetestset "Mini batching" begin include("minibatch.jl") end
18+
@safetestset "DiffEqFlux" begin include("diffeqfluxtests.jl") end
19+
end
20+
21+
if !is_APPVEYOR && GROUP == "Downstream"
22+
activate_downstream_env()
23+
Pkg.test("DiffEqFlux")
24+
end
25+
end

0 commit comments

Comments
 (0)