2 changes: 1 addition & 1 deletion Project.toml
@@ -38,7 +38,7 @@ MetalExt = ["Metal"]
oneAPIExt = ["oneAPI"]

[compat]
AMDGPU = "0.5, 0.6, 0.7, 0.8, 0.9"
AMDGPU = "0.5, 0.6, 0.7, 0.8, 0.9, 1"
Adapt = "3, 4"
CUDA = "4.1.0, 5"
ChainRulesCore = "1"
3 changes: 2 additions & 1 deletion src/algorithms.jl
@@ -146,7 +146,8 @@ prob = ODEProblem{false}(lorenz, u0, tspan, p)
prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)

@time sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()), trajectories = 10_000,
@time sol = solve(
monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()), trajectories = 10_000,
adaptive = false, dt = 0.1f0)
```
"""
107 changes: 68 additions & 39 deletions src/ensemblegpuarray/kernels.jl
@@ -12,12 +12,15 @@ function Adapt.adapt_structure(to, ps::ParamWrapper{P, T}) where {P, T}
adapt(to, ps.data))
end

# The reparameterization is adapted from: https://github.com/rtqichen/torchdiffeq/issues/122#issuecomment-738978844
@kernel function gpu_kernel(f, du, @Const(u),
@Const(params::AbstractArray{ParamWrapper{P, T}}),
@Const(t)) where {P, T}
i = @index(Global, Linear)
@inbounds p = params[i].params
@inbounds tspan = params[i].data
# reparameterization t->(t_0, t_f) from t->(0, 1).
t = (tspan[2] - tspan[1]) * t + tspan[1]
@views @inbounds f(du[:, i], u[:, i], p, t)
@inbounds for j in 1:size(du, 1)
du[j, i] = du[j, i] * (tspan[2] - tspan[1])
@@ -30,7 +33,8 @@ end
i = @index(Global, Linear)
@inbounds p = params[i].params
@inbounds tspan = params[i].data

# reparameterization
t = (tspan[2] - tspan[1]) * t + tspan[1]
@views @inbounds x = f(u[:, i], p, t)
@inbounds for j in 1:size(du, 1)
du[j, i] = x[j] * (tspan[2] - tspan[1])
@@ -66,6 +70,9 @@ end
@inbounds p = params[i + 1].params
@inbounds tspan = params[i + 1].data

# reparameterization
t = (tspan[2] - tspan[1]) * t + tspan[1]

@views @inbounds f(J[section, section], u[:, i + 1], p, t)
@inbounds for j in section, k in section
J[k, j] = J[k, j] * (tspan[2] - tspan[1])
@@ -81,6 +88,9 @@ end
@inbounds p = params[i + 1].params
@inbounds tspan = params[i + 1].data

# reparameterization
t = (tspan[2] - tspan[1]) * t + tspan[1]

@views @inbounds x = f(u[:, i + 1], p, t)

@inbounds for j in section, k in section
@@ -150,6 +160,9 @@ end
@inbounds p = params[i].params
@inbounds tspan = params[i].data

# reparameterization
t = (tspan[2] - tspan[1]) * t + tspan[1]

@views @inbounds jac(_W, u[:, i], p, t)

@inbounds for i in eachindex(_W)
@@ -187,6 +200,9 @@ end

_W = @inbounds @view(W[:, :, i])

# reparameterization
t = (tspan[2] - tspan[1]) * t + tspan[1]

@views @inbounds x = jac(u[:, i], p, t)
@inbounds for j in 1:length(_W)
_W[j] = x[j] * (tspan[2] - tspan[1])
@@ -217,38 +233,51 @@ end
end
end

@kernel function Wt_kernel(f::AbstractArray{T}, W, @Const(u), @Const(p), @Const(gamma),
@Const(t)) where {T}
@kernel function Wt_kernel(
jac, W, @Const(u), @Const(params::AbstractArray{ParamWrapper{P, T}}),
@Const(gamma), @Const(t)) where {P, T}
i = @index(Global, Linear)
len = size(u, 1)
@inbounds p = params[i].params
@inbounds tspan = params[i].data

# reparameterization
t = (tspan[2] - tspan[1]) * t + tspan[1]

_W = @inbounds @view(W[:, :, i])
@inbounds jac = f[i].tgrad
@views @inbounds jac(_W, u[:, i], p[:, i], t)
@views @inbounds jac(_W, u[:, i], p, t)
@inbounds for i in 1:len
_W[i, i] = -inv(gamma) + _W[i, i]
_W[i, i] = -inv(gamma) + _W[i, i] * (tspan[2] - tspan[1])
end
end

@kernel function Wt_kernel(jac, W, @Const(u), @Const(p), @Const(gamma), @Const(t))
@kernel function Wt_kernel_oop(
jac, W, @Const(u), @Const(params::AbstractArray{ParamWrapper{P, T}}),
@Const(gamma), @Const(t)) where {P, T}
i = @index(Global, Linear)
len = size(u, 1)

@inbounds p = params[i].params
@inbounds tspan = params[i].data

# reparameterization
t = (tspan[2] - tspan[1]) * t + tspan[1]

_W = @inbounds @view(W[:, :, i])
@views @inbounds jac(_W, u[:, i], p[:, i], t)
@views @inbounds x = jac(u[:, i], p, t)
@inbounds for j in 1:length(_W)
_W[j] = x[j] * (tspan[2] - tspan[1])
end
@inbounds for i in 1:len
_W[i, i] = -inv(gamma) + _W[i, i]
end
end

@kernel function Wt_kernel_oop(f::AbstractArray{T}, W, @Const(u), @Const(p), @Const(gamma),
@Const(t)) where {T}
@kernel function Wt_kernel(jac, W, @Const(u), @Const(p), @Const(gamma), @Const(t))
i = @index(Global, Linear)
len = size(u, 1)
_W = @inbounds @view(W[:, :, i])
@inbounds jac = f[i].tgrad
@views @inbounds x = jac(u[:, i], p[:, i], t)
@inbounds for j in 1:length(_W)
_W[j] = x[j]
end
@views @inbounds jac(_W, u[:, i], p[:, i], t)
@inbounds for i in 1:len
_W[i, i] = -inv(gamma) + _W[i, i]
end
@@ -267,30 +296,30 @@
end
end

@kernel function gpu_kernel_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p),
@Const(t)) where {T}
i = @index(Global, Linear)
@inbounds f = f[i].tgrad
if eltype(p) <: Number
@views @inbounds f(du[:, i], u[:, i], p[:, i], t)
else
@views @inbounds f(du[:, i], u[:, i], p[i], t)
end
end

@kernel function gpu_kernel_oop_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p),
@Const(t)) where {T}
i = @index(Global, Linear)
@inbounds f = f[i].tgrad
if eltype(p) <: Number
@views @inbounds x = f(u[:, i], p[:, i], t)
else
@views @inbounds x = f(u[:, i], p[i], t)
end
@inbounds for j in 1:size(du, 1)
du[j, i] = x[j]
end
end
# @kernel function gpu_kernel_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p),
# @Const(t)) where {T}
# i = @index(Global, Linear)
# @inbounds f = f[i].tgrad
# if eltype(p) <: Number
# @views @inbounds f(du[:, i], u[:, i], p[:, i], t)
# else
# @views @inbounds f(du[:, i], u[:, i], p[i], t)
# end
# end

# @kernel function gpu_kernel_oop_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p),
# @Const(t)) where {T}
# i = @index(Global, Linear)
# @inbounds f = f[i].tgrad
# if eltype(p) <: Number
# @views @inbounds x = f(u[:, i], p[:, i], t)
# else
# @views @inbounds x = f(u[:, i], p[i], t)
# end
# @inbounds for j in 1:size(du, 1)
# du[j, i] = x[j]
# end
# end

function lufact!(::CPU, W)
len = size(W, 1)
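The hunks above all add the same time reparameterization so that each trajectory can carry its own `tspan` through the `ParamWrapper` fields (`params[i].params` holds the user parameters, `params[i].data` the time span). Below is a minimal standalone sketch of the trick; the names `f_scaled`, `t0`, and `tf` are illustrative only and are not code from this PR:

```julia
# Sketch only: the (0, 1) -> (t0, tf) reparameterization applied per trajectory.
# Solving du/dtau = (tf - t0) * f(u, p, t0 + (tf - t0) * tau) on tau in (0, 1)
# reproduces the solution of du/dt = f(u, p, t) on t in (t0, tf).
using OrdinaryDiffEq

f(u, p, t) = p[1] * u                               # original right-hand side
t0, tf = 0.0f0, 2.0f0                               # per-trajectory time span
f_scaled(u, p, tau) = (tf - t0) * f(u, p, t0 + (tf - t0) * tau)

u0, p = 1.0f0, (0.5f0,)
sol_unit = solve(ODEProblem(f_scaled, u0, (0.0f0, 1.0f0), p), Tsit5())
sol_orig = solve(ODEProblem(f, u0, (t0, tf), p), Tsit5())
sol_unit(1.0f0) ≈ sol_orig(tf)                      # true up to solver tolerance
```

The Jacobian and W-transform kernels multiply their entries by `(tspan[2] - tspan[1])` for the same reason: by the chain rule, the Jacobian of the rescaled right-hand side is `(tf - t0)` times the original Jacobian.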
5 changes: 4 additions & 1 deletion test/Project.toml
@@ -1,12 +1,14 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -15,3 +17,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2 changes: 1 addition & 1 deletion test/gpu_kernel_de/gpu_ode_regression.jl
@@ -78,7 +78,7 @@ for alg in algs
@test norm(asol[1].u[end] - sol[1].u[end]) < 6e-3

@test norm(bench_sol.u - sol[1].u) < 2e-3
@test norm(bench_asol.u - asol[1].u) < 4e-3
@test norm(bench_asol.u - asol[1].u) < 5e-3

@test length(sol[1].u) == length(saveat)
@test length(asol[1].u) == length(saveat)
36 changes: 17 additions & 19 deletions test/reverse_ad_tests.jl
@@ -1,42 +1,40 @@
using OrdinaryDiffEq, Flux, DiffEqGPU, Test
using OrdinaryDiffEq, Optimization, OptimizationOptimisers, DiffEqGPU, Test
import Zygote

include("utils.jl")

function modelf(du, u, p, t)
du[1] = 1.01 * u[1] * p[1] * p[2]
end

function model()
prob = ODEProblem(modelf, u0, (0.0, 1.0), pa)
function model(θ, ensemblealg)
prob = ODEProblem(modelf, [θ[1]], (0.0, 1.0), [θ[2], θ[3]])

function prob_func(prob, i, repeat)
remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0)
end

ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
solve(ensemble_prob, Tsit5(), EnsembleGPUArray(backend), saveat = 0.1,
solve(ensemble_prob, Tsit5(), ensemblealg, saveat = 0.1,
trajectories = 10)
end

# loss function
loss() = sum(abs2, 1.0 .- Array(model()))

data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
@show loss()
callback = function (θ, l) # callback function to observe training
@show l
false
end

pa = [1.0, 2.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")

l1 = loss()
θ = [u0; pa]

for epoch in 1:10
Flux.train!(loss, Flux.params([pa]), data, opt; cb = cb)
end
opt = Adam(0.1)
loss_gpu(θ) = sum(abs2, 1.0 .- Array(model(θ, EnsembleCPUArray())))
l1 = loss_gpu(θ)

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_gpu(x), adtype)
optprob = Optimization.OptimizationProblem(optf, θ)

l2 = loss()
@test 3l2 < l1
res_gpu = Optimization.solve(optprob, opt; callback = callback, maxiters = 100)
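The rewritten test follows the Optimization.jl pattern of `OptimizationFunction` + `AutoZygote()` + an Optimisers rule. A minimal standalone sketch of that pattern on a toy objective follows; the loss and starting point are illustrative only, not part of this PR:

```julia
# Sketch only: minimize a toy objective with the same API the updated test uses.
using Optimization, OptimizationOptimisers
import Zygote                                       # required by AutoZygote()

loss(θ, _p) = sum(abs2, θ .- 1.0)                   # illustrative objective
optf = Optimization.OptimizationFunction((x, p) -> loss(x, p), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, [3.0, 2.0, 1.0])
res = Optimization.solve(optprob, Adam(0.1); maxiters = 200)
res.u                                               # approaches [1.0, 1.0, 1.0]
```

As in the test, `Adam` comes from OptimizationOptimisers, and Zygote must be loaded for `AutoZygote()` to differentiate the loss.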