Skip to content
161 changes: 96 additions & 65 deletions src/ensemblegpukernel/lowerlevel_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ function vectorized_solve(probs, prob::ODEProblem, alg;

prob = convert(ImmutableODEProblem, prob)
dt = convert(eltype(prob.tspan), dt)
saveat_converted = nothing

if saveat === nothing
if save_everystep
Expand All @@ -51,34 +52,30 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = if saveat isa AbstractRange
_saveat = range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
convert(
StepRangeLen{
eltype(_saveat),
eltype(_saveat),
eltype(_saveat),
eltype(_saveat) === Float32 ? Int32 : Int64
},
_saveat)
elseif saveat isa AbstractVector
adapt(backend, convert.(eltype(prob.tspan), saveat))
# Get the time type from the problem
Tt = eltype(prob.tspan)

# FIX for Issue #379: Convert saveat to proper type
saveat_converted = if saveat isa AbstractRange || saveat isa AbstractArray
Tt.(collect(saveat))
else
_saveat = prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
convert(
StepRangeLen{
eltype(_saveat),
eltype(_saveat),
eltype(_saveat),
eltype(_saveat) === Float32 ? Int32 : Int64
},
_saveat)
# saveat is a Number (step size)
t0, tf = Tt.(prob.tspan)
if Tt(saveat) == Tt(0.0)
Tt.([t0, tf])
else
num_points = Int(ceil(abs(tf - t0) / abs(Tt(saveat)))) + 1
Tt.(collect(range(t0, tf, length = num_points)))
end
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))

ts = allocate(backend, typeof(dt), (length(saveat_converted), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
us = allocate(backend, typeof(prob.u0), (length(saveat_converted), length(probs)))
end

if saveat_converted !== nothing
saveat_converted = adapt(backend, saveat_converted)
end

tstops = adapt(backend, tstops)
Expand All @@ -89,7 +86,7 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
@warn "Running the kernel on CPU"
end

kernel(probs, alg, us, ts, dt, callback, tstops, nsteps, saveat,
kernel(probs, alg, us, ts, dt, callback, tstops, nsteps, saveat_converted,
Val(save_everystep);
ndrange = length(probs))

Expand All @@ -111,7 +108,7 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
backend = maybe_prefer_blocks(backend)

dt = convert(eltype(prob.tspan), dt)

saveat_converted = nothing
if saveat === nothing
if save_everystep
len = length(prob.tspan[1]:dt:prob.tspan[2])
Expand All @@ -122,20 +119,30 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
elseif saveat isa AbstractVector
convert.(eltype(prob.tspan), adapt(backend, saveat))
# Get the time type from the problem
Tt = eltype(prob.tspan)

# FIX for Issue #379: Convert saveat to proper type
saveat_converted = if saveat isa AbstractRange || saveat isa AbstractArray
Tt.(collect(saveat))
else
prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
# saveat is a Number (step size)
t0, tf = Tt.(prob.tspan)
if Tt(saveat) == Tt(0.0)
Tt.([t0, tf])
else
num_points = Int(ceil(abs(tf - t0) / abs(Tt(saveat)))) + 1
Tt.(collect(range(t0, tf, length = num_points)))
end
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))

ts = allocate(backend, typeof(dt), (length(saveat_converted), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
us = allocate(backend, typeof(prob.u0), (length(saveat_converted), length(probs)))
end
if saveat_converted !== nothing
saveat_converted = adapt(backend, saveat_converted)
end

if alg isa GPUEM
kernel = em_kernel(backend)
elseif alg isa Union{GPUSIEA}
Expand All @@ -148,7 +155,7 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
@warn "Running the kernel on CPU"
end

kernel(probs, us, ts, dt, saveat, Val(save_everystep);
kernel(probs, us, ts, dt, saveat_converted, Val(save_everystep);
ndrange = length(probs))
ts, us
end
Expand Down Expand Up @@ -184,62 +191,86 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
abstol = 1.0f-6, reltol = 1.0f-3,
debug = false, callback = CallbackSet(nothing), tstops = nothing,
kwargs...)

backend = get_backend(probs)
backend = maybe_prefer_blocks(backend)

prob = convert(ImmutableODEProblem, prob)
# Get the time type from the problem
Tt = eltype(prob.tspan)

# FIX for Issue #379: Convert saveat to eliminate
# StepRangeLen's internal Float64 fields which crash Metal

if saveat !== nothing
if saveat isa Number
# Handle edge case: saveat = 0.0 means only save endpoints
if Tt(saveat) == Tt(0.0)
saveat_converted = Tt.([prob.tspan[1], prob.tspan[2]])
else
# Create proper range with correct type
t0, tf = Tt.(prob.tspan)

# Handle both forward and reverse time integration
num_points = Int(ceil(abs(tf - t0) / abs(Tt(saveat)))) + 1

# Safety check: prevent massive arrays
max_saveat_length = 100_000
if num_points > max_saveat_length
error("saveat would create too many save points ($num_points). " *
"Consider using a larger saveat value.")
end

# Create range and convert to pure Vector{Tt}
saveat_range = range(t0, tf, length = num_points)
saveat_converted = Tt.(collect(saveat_range))
end
elseif saveat isa AbstractRange || saveat isa AbstractArray
# Range or array - convert all elements to Tt
# This eliminates StepRangeLen's Float64 internals
saveat_converted = Tt.(collect(saveat))
else
# Already in correct form
saveat_converted = saveat
end
else
saveat_converted = nothing
end

prob = convert(ImmutableODEProblem, prob)
dt = convert(eltype(prob.tspan), dt)
abstol = convert(eltype(prob.tspan), abstol)
reltol = convert(eltype(prob.tspan), reltol)
# if saveat is specified, we'll use a vector of timestamps.
# otherwise it's a matrix that may be different for each ODE.
if saveat === nothing

if saveat_converted === nothing
if save_everystep
error("Don't use adaptive version with saveat == nothing and save_everystep = true")
len = ceil(Int, (prob.tspan[2] - prob.tspan[1]) / dt) + 1
else
len = 2
end
# if tstops !== nothing
# len += length(tstops)
# end
ts = allocate(backend, typeof(dt), (len, length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
elseif saveat isa AbstractVector
adapt(backend, convert.(eltype(prob.tspan), saveat))
else
prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
ts = allocate(backend, typeof(dt), (length(saveat_converted), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
us = allocate(backend, typeof(prob.u0), (length(saveat_converted), length(probs)))
end

us = adapt(backend, us)
ts = adapt(backend, ts)
tstops = adapt(backend, tstops)

if saveat_converted !== nothing
saveat_converted = adapt(backend, saveat_converted)
end
kernel = ode_asolve_kernel(backend)

if backend isa CPU
@warn "Running the kernel on CPU"
end

kernel(probs, alg, us, ts, dt, callback, tstops,
abstol, reltol, saveat, Val(save_everystep);
abstol, reltol, saveat_converted, Val(save_everystep);
ndrange = length(probs))

# we build the actual solution object on the CPU because the GPU would create one
# containing CuDeviceArrays, which we cannot use on the host (not GC tracked,
# no useful operations, etc). That's unfortunate though, since this loop is
# generally slower than the entire GPU execution, and necessitates synchronization
#EDIT: Done when using with DiffEqGPU
ts, us
end

Expand Down
2 changes: 1 addition & 1 deletion src/ensemblegpukernel/nlsolve/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Δz = linear_solve(W_eval, f_rhs)
z_i = z_i - Δz

if norm(dt * integrator.f(tmp + γ * z_i, p, t + c * dt) - z_i) < abstol
if diffeqgpunorm(dt * integrator.f(tmp + γ * z_i, p, t + c * dt) - z_i, t) < abstol
break
end
end
Expand Down
38 changes: 17 additions & 21 deletions src/ensemblegpukernel/perform_step/gpu_em_perform_step.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
@kernel function em_kernel(@Const(probs), _us, _ts, dt,
saveat, ::Val{save_everystep}) where {save_everystep}
i = @index(Global, Linear)

# get the actual problem for this thread
prob = @inbounds probs[i]

Random.seed!(prob.seed)

# get the input/output arrays for this thread
ts = @inbounds view(_ts, :, i)
us = @inbounds view(_us, :, i)

_saveat = get(prob.kwargs, :saveat, nothing)

saveat = _saveat === nothing ? saveat : _saveat

f = prob.f
g = prob.g
u0 = prob.u0
tspan = prob.tspan
p = prob.p


# FIX for Issue #379: Get time type from tspan
Tt = typeof(tspan[1])

is_diagonal_noise = SciMLBase.is_diagonal_noise(prob)

cur_t = 0
if saveat !== nothing
cur_t = 1
Expand All @@ -34,42 +30,42 @@
@inbounds ts[1] = tspan[1]
@inbounds us[1] = u0
end

sqdt = sqrt(dt)

# FIX: Use Tt for sqrt to ensure proper type
sqdt = sqrt(Tt(dt))
u = copy(u0)
t = copy(tspan[1])
n = length(tspan[1]:dt:tspan[2])


# FIX: Ensure n calculation uses proper types
t0, tf = tspan[1], tspan[2]
n = floor(Int, abs(tf - t0) / abs(Tt(dt))) + 1

for j in 2:n
uprev = u

if is_diagonal_noise
u = uprev + f(uprev, p, t) * dt +
u = uprev + f(uprev, p, t) * Tt(dt) +
sqdt * g(uprev, p, t) .* randn(typeof(u0))
else
u = uprev + f(uprev, p, t) * dt +
u = uprev + f(uprev, p, t) * Tt(dt) +
sqdt * g(uprev, p, t) * randn(typeof(prob.noise_rate_prototype[1, :]))
end

t += dt

t += Tt(dt)
if saveat === nothing && save_everystep
@inbounds us[j] = u
@inbounds ts[j] = t
elseif saveat !== nothing
while cur_t <= length(saveat) && saveat[cur_t] <= t
savet = saveat[cur_t]
Θ = (savet - (t - dt)) / dt
Θ = (savet - (t - Tt(dt))) / Tt(dt)
# Linear Interpolation
@inbounds us[cur_t] = uprev + (u - uprev) * Θ
@inbounds ts[cur_t] = savet
cur_t += 1
end
end
end

if saveat === nothing && !save_everystep
@inbounds us[2] = u
@inbounds ts[2] = t
end
end
end
Loading