Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/integrators.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import DataStructures
import Base.Cartesian: @nexprs

"""
DistributedODEIntegrator <: AbstractODEIntegrator
Expand Down Expand Up @@ -226,14 +227,13 @@ reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) =
integrator.t == tstop || (!stop_at_tstop && is_past_t(integrator, tstop))


@inline unrolled_foreach(::Tuple{}, integrator) = nothing
@inline unrolled_foreach(callback, integrator) =
callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing
@inline unrolled_foreach(discrete_callbacks::Tuple{Any}, integrator) =
unrolled_foreach(first(discrete_callbacks), integrator)
@inline function unrolled_foreach(discrete_callbacks::Tuple, integrator)
unrolled_foreach(first(discrete_callbacks), integrator)
unrolled_foreach(Base.tail(discrete_callbacks), integrator)
@generated function unrolled_foreach(::Val{N}, callbacks, integrator) where {N}
return quote
@nexprs $N i -> begin
callback = callbacks[i]
callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing
end
end
end

function __step!(integrator)
Expand All @@ -257,7 +257,7 @@ function __step!(integrator)

# apply callbacks
discrete_callbacks = integrator.callback.discrete_callbacks
unrolled_foreach(discrete_callbacks, integrator)
unrolled_foreach(Val(length(discrete_callbacks)), discrete_callbacks, integrator)

# remove tstops that were just reached
while !isempty(tstops) && reached_tstop(integrator, first(tstops))
Expand Down
12 changes: 5 additions & 7 deletions src/solvers/imex_ark.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import NVTX
import Base.Cartesian: @nexprs

has_jac(T_imp!) =
hasfield(typeof(T_imp!), :Wfact) &&
Expand Down Expand Up @@ -68,7 +69,7 @@ function step_u!(integrator, cache::IMEXARKCache)
end
end

update_stage!(integrator, cache, ntuple(i -> i, Val(s)))
update_stage!(Val(s), integrator, cache)

t_final = t + dt

Expand All @@ -88,13 +89,10 @@ function step_u!(integrator, cache::IMEXARKCache)
return u
end


@inline update_stage!(integrator, cache, ::Tuple{}) = nothing
@inline update_stage!(integrator, cache, is::Tuple{Int}) = update_stage!(integrator, cache, first(is))
@inline function update_stage!(integrator, cache, is::Tuple)
update_stage!(integrator, cache, first(is))
update_stage!(integrator, cache, Base.tail(is))
@generated update_stage!(::Val{s}, integrator, cache::IMEXARKCache) where {s} = quote
@nexprs $s i -> update_stage!(integrator, cache, i)
end

@inline function update_stage!(integrator, cache::IMEXARKCache, i::Int)
(; u, p, t, dt, alg) = integrator
(; f) = integrator.sol.prob
Expand Down
Loading