Skip to content

Commit 7050289

Browse files
Merge #126
126: Make save_func not a function of the integrator r=charleskawczynski a=charleskawczynski This PR makes the `save_func` not a function of the integrator, which alleviates a circular dependency upon initializing the integrator. A peel off from #115. Co-authored-by: Charles Kawczynski <[email protected]>
2 parents 23f113d + 7f695c9 commit 7050289

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/integrators.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ function DiffEqBase.__init(
7373
tstops = (),
7474
saveat = nothing,
7575
save_everystep = false,
76-
save_func = (u, t, integrator) -> copy(u),
7776
callback = nothing,
7877
advance_to_tstop = false,
79-
dtchangeable = true, # custom kwarg
80-
stepstop = -1, # custom kwarg
78+
save_func = (u, t) -> copy(u), # custom kwarg
79+
dtchangeable = true, # custom kwarg
80+
stepstop = -1, # custom kwarg
8181
kwargs...,
8282
)
8383
(; u0, p) = prob
@@ -91,7 +91,7 @@ function DiffEqBase.__init(
9191
_saveat = saveat
9292
tstops, saveat = tstops_and_saveat_heaps(t0, tf, tstops, saveat)
9393

94-
sol = DiffEqBase.build_solution(prob, alg, typeof(t0)[], typeof(u0)[])
94+
sol = DiffEqBase.build_solution(prob, alg, typeof(t0)[], typeof(save_func(u0, t0))[])
9595
saving_callback =
9696
NonInterpolatingSavingCallback(save_func, DiffEqCallbacks.SavedValues(sol.t, sol.u), save_everystep)
9797
callback = DiffEqBase.CallbackSet(callback, saving_callback)
@@ -270,7 +270,7 @@ function NonInterpolatingSavingCallback(save_func, saved_values, save_everystep)
270270
end
271271
function affect!(integrator)
272272
push!(saved_values.t, integrator.t)
273-
push!(saved_values.saveval, save_func(integrator.u, integrator.t, integrator))
273+
push!(saved_values.saveval, save_func(integrator.u, integrator.t))
274274
end
275275
initialize(cb, u, t, integrator) = condition(u, t, integrator) && affect!(integrator)
276276
finalize(cb, u, t, integrator) = !save_everystep && !isempty(integrator.saveat) && affect!(integrator)

0 commit comments

Comments
 (0)