
Commit 7c52843

utkarsh530 authored and ChrisRackauckas committed
Fixes a few correctness issues
1 parent 68052a5 commit 7c52843

File tree

3 files changed (+89, -59 lines)


src/ensemblegpuarray/kernels.jl

Lines changed: 68 additions & 39 deletions

@@ -12,12 +12,15 @@ function Adapt.adapt_structure(to, ps::ParamWrapper{P, T}) where {P, T}
         adapt(to, ps.data))
 end
 
+# The reparameterization is adapted from:https://github.com/rtqichen/torchdiffeq/issues/122#issuecomment-738978844
 @kernel function gpu_kernel(f, du, @Const(u),
         @Const(params::AbstractArray{ParamWrapper{P, T}}),
         @Const(t)) where {P, T}
     i = @index(Global, Linear)
     @inbounds p = params[i].params
     @inbounds tspan = params[i].data
+    # reparameterization t->(t_0, t_f) from t->(0, 1).
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
     @views @inbounds f(du[:, i], u[:, i], p, t)
     @inbounds for j in 1:size(du, 1)
         du[j, i] = du[j, i] * (tspan[2] - tspan[1])
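
The added lines implement the substitution t = (t_f - t_0) * s + t_0 together with its chain-rule factor: if v(s) = u(t_0 + s * (t_f - t_0)), then dv/ds = (t_f - t_0) * f(v, p, t), which is why every right-hand-side (and, in the hunks below, Jacobian/W) evaluation is multiplied by (tspan[2] - tspan[1]). A minimal standalone sketch of that identity, independent of these kernels (the toy problem, the names f and g, and the tolerances are illustrative, not taken from this commit):

using OrdinaryDiffEq, Test

# Original problem: du/dt = p * u on tspan = (t0, tf).
f(u, p, t) = p * u
tspan = (1.0, 3.0)

# Reparameterized RHS on s in (0, 1): the kernel's substitution
# t = (tspan[2] - tspan[1]) * s + tspan[1], scaled by the chain-rule factor.
g(v, p, s) = (tspan[2] - tspan[1]) * f(v, p, (tspan[2] - tspan[1]) * s + tspan[1])

sol = solve(ODEProblem(f, 0.5, tspan, 1.01), Tsit5(), abstol = 1e-10, reltol = 1e-10)
sol_r = solve(ODEProblem(g, 0.5, (0.0, 1.0), 1.01), Tsit5(), abstol = 1e-10, reltol = 1e-10)

# The endpoint values coincide: u(t_f) == v(1).
@test sol.u[end] ≈ sol_r.u[end]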
@@ -30,7 +33,8 @@ end
     i = @index(Global, Linear)
     @inbounds p = params[i].params
     @inbounds tspan = params[i].data
-
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
     @views @inbounds x = f(u[:, i], p, t)
     @inbounds for j in 1:size(du, 1)
         du[j, i] = x[j] * (tspan[2] - tspan[1])
@@ -66,6 +70,9 @@ end
     @inbounds p = params[i + 1].params
     @inbounds tspan = params[i + 1].data
 
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     @views @inbounds f(J[section, section], u[:, i + 1], p, t)
     @inbounds for j in section, k in section
         J[k, j] = J[k, j] * (tspan[2] - tspan[1])
@@ -81,6 +88,9 @@ end
     @inbounds p = params[i + 1].params
     @inbounds tspan = params[i + 1].data
 
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     @views @inbounds x = f(u[:, i + 1], p, t)
 
     @inbounds for j in section, k in section
@@ -150,6 +160,9 @@ end
     @inbounds p = params[i].params
     @inbounds tspan = params[i].data
 
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     @views @inbounds jac(_W, u[:, i], p, t)
 
     @inbounds for i in eachindex(_W)
@@ -187,6 +200,9 @@ end
 
     _W = @inbounds @view(W[:, :, i])
 
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     @views @inbounds x = jac(u[:, i], p, t)
     @inbounds for j in 1:length(_W)
         _W[j] = x[j] * (tspan[2] - tspan[1])
@@ -217,38 +233,51 @@ end
     end
 end
 
-@kernel function Wt_kernel(f::AbstractArray{T}, W, @Const(u), @Const(p), @Const(gamma),
-        @Const(t)) where {T}
+@kernel function Wt_kernel(
+        jac, W, @Const(u), @Const(params::AbstractArray{ParamWrapper{P, T}}),
+        @Const(gamma), @Const(t)) where {P, T}
     i = @index(Global, Linear)
     len = size(u, 1)
+    @inbounds p = params[i].params
+    @inbounds tspan = params[i].data
+
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     _W = @inbounds @view(W[:, :, i])
-    @inbounds jac = f[i].tgrad
-    @views @inbounds jac(_W, u[:, i], p[:, i], t)
+    @views @inbounds jac(_W, u[:, i], p, t)
     @inbounds for i in 1:len
-        _W[i, i] = -inv(gamma) + _W[i, i]
+        _W[i, i] = -inv(gamma) + _W[i, i] * (tspan[2] - tspan[1])
     end
 end
 
-@kernel function Wt_kernel(jac, W, @Const(u), @Const(p), @Const(gamma), @Const(t))
+@kernel function Wt_kernel_oop(
+        jac, W, @Const(u), @Const(params::AbstractArray{ParamWrapper{P, T}}),
+        @Const(gamma), @Const(t)) where {P, T}
     i = @index(Global, Linear)
     len = size(u, 1)
+
+    @inbounds p = params[i].params
+    @inbounds tspan = params[i].data
+
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     _W = @inbounds @view(W[:, :, i])
-    @views @inbounds jac(_W, u[:, i], p[:, i], t)
+    @views @inbounds x = jac(u[:, i], p, t)
+    @inbounds for j in 1:length(_W)
+        _W[j] = x[j] * (tspan[2] - tspan[1])
+    end
     @inbounds for i in 1:len
         _W[i, i] = -inv(gamma) + _W[i, i]
     end
 end
 
-@kernel function Wt_kernel_oop(f::AbstractArray{T}, W, @Const(u), @Const(p), @Const(gamma),
-        @Const(t)) where {T}
+@kernel function Wt_kernel(jac, W, @Const(u), @Const(p), @Const(gamma), @Const(t))
     i = @index(Global, Linear)
     len = size(u, 1)
     _W = @inbounds @view(W[:, :, i])
-    @inbounds jac = f[i].tgrad
-    @views @inbounds x = jac(u[:, i], p[:, i], t)
-    @inbounds for j in 1:length(_W)
-        _W[j] = x[j]
-    end
+    @views @inbounds jac(_W, u[:, i], p[:, i], t)
     @inbounds for i in 1:len
         _W[i, i] = -inv(gamma) + _W[i, i]
     end
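
The same scaling reaches the W operator used by the implicit solvers: the Wt kernels now pull p and tspan out of the ParamWrapper and multiply the Jacobian entries by (tspan[2] - tspan[1]) before the -inv(gamma) shift is added on the diagonal. A dense-matrix sketch of the construction mirroring Wt_kernel_oop above (gamma, tspan, and J here are made-up illustrative values, not taken from this commit):

using LinearAlgebra

gamma = 0.25                 # illustrative solver-supplied gamma
tspan = (1.0, 3.0)           # illustrative integration span
J = [0.0 1.0; -2.0 -0.5]     # illustrative Jacobian of the original RHS

# Scale the whole Jacobian by (t_f - t_0), then shift the diagonal by -1/gamma,
# as Wt_kernel_oop does entrywise before its diagonal loop.
Wt = (tspan[2] - tspan[1]) * J - inv(gamma) * I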
@@ -267,30 +296,30 @@ end
     end
 end
 
-@kernel function gpu_kernel_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p),
-        @Const(t)) where {T}
-    i = @index(Global, Linear)
-    @inbounds f = f[i].tgrad
-    if eltype(p) <: Number
-        @views @inbounds f(du[:, i], u[:, i], p[:, i], t)
-    else
-        @views @inbounds f(du[:, i], u[:, i], p[i], t)
-    end
-end
-
-@kernel function gpu_kernel_oop_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p),
-        @Const(t)) where {T}
-    i = @index(Global, Linear)
-    @inbounds f = f[i].tgrad
-    if eltype(p) <: Number
-        @views @inbounds x = f(u[:, i], p[:, i], t)
-    else
-        @views @inbounds x = f(u[:, i], p[i], t)
-    end
-    @inbounds for j in 1:size(du, 1)
-        du[j, i] = x[j]
-    end
-end
+# @kernel function gpu_kernel_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p),
+#         @Const(t)) where {T}
+#     i = @index(Global, Linear)
+#     @inbounds f = f[i].tgrad
+#     if eltype(p) <: Number
+#         @views @inbounds f(du[:, i], u[:, i], p[:, i], t)
+#     else
+#         @views @inbounds f(du[:, i], u[:, i], p[i], t)
+#     end
+# end
+
+# @kernel function gpu_kernel_oop_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p),
+#         @Const(t)) where {T}
+#     i = @index(Global, Linear)
+#     @inbounds f = f[i].tgrad
+#     if eltype(p) <: Number
+#         @views @inbounds x = f(u[:, i], p[:, i], t)
+#     else
+#         @views @inbounds x = f(u[:, i], p[i], t)
+#     end
+#     @inbounds for j in 1:size(du, 1)
+#         du[j, i] = x[j]
+#     end
+# end
 
 function lufact!(::CPU, W)
     len = size(W, 1)

test/Project.toml

Lines changed: 4 additions & 1 deletion

@@ -1,12 +1,14 @@
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -15,3 +17,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
 TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/reverse_ad_tests.jl

Lines changed: 17 additions & 19 deletions

@@ -1,42 +1,40 @@
-using OrdinaryDiffEq, Flux, DiffEqGPU, Test
+using OrdinaryDiffEq, Optimization, OptimizationOptimisers, DiffEqGPU, Test
+import Zygote
 
 include("utils.jl")
 
 function modelf(du, u, p, t)
     du[1] = 1.01 * u[1] * p[1] * p[2]
 end
 
-function model()
-    prob = ODEProblem(modelf, u0, (0.0, 1.0), pa)
+function model(θ, ensemblealg)
+    prob = ODEProblem(modelf, [θ[1]], (0.0, 1.0), [θ[2], θ[3]])
 
     function prob_func(prob, i, repeat)
         remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0)
     end
 
     ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
-    solve(ensemble_prob, Tsit5(), EnsembleGPUArray(backend), saveat = 0.1,
+    solve(ensemble_prob, Tsit5(), ensemblealg, saveat = 0.1,
         trajectories = 10)
 end
 
-# loss function
-loss() = sum(abs2, 1.0 .- Array(model()))
-
-data = Iterators.repeated((), 10)
-
-cb = function () # callback function to observe training
-    @show loss()
+callback = function (θ, l) # callback function to observe training
+    @show l
+    false
 end
 
 pa = [1.0, 2.0]
 u0 = [3.0]
-opt = ADAM(0.1)
-println("Starting to train")
 
-l1 = loss()
+θ = [u0; pa]
 
-for epoch in 1:10
-    Flux.train!(loss, Flux.params([pa]), data, opt; cb = cb)
-end
+opt = Adam(0.1)
+loss_gpu(θ) = sum(abs2, 1.0 .- Array(model(θ, EnsembleCPUArray())))
+l1 = loss_gpu(θ)
+
+adtype = Optimization.AutoZygote()
+optf = Optimization.OptimizationFunction((x, p) -> loss_gpu(x), adtype)
+optprob = Optimization.OptimizationProblem(optf, θ)
 
-l2 = loss()
-@test 3l2 < l1
+res_gpu = Optimization.solve(optprob, opt; callback = callback, maxiters = 100)
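
With the move from Flux.train! to Optimization.jl, training now runs through Optimization.solve with an AutoZygote adjoint. One way the test could be extended to keep the old loss-improvement assertion is sketched below (not part of this commit; it assumes res_gpu.u holds the optimized θ and reuses loss_gpu and l1 from the file above):

# Hypothetical follow-up check mirroring the removed `@test 3l2 < l1`:
l2 = loss_gpu(res_gpu.u)   # loss at the optimized parameters
@test 3l2 < l1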
