Skip to content

Commit 84cd9ad

Browse files
committed
Handle failure in finish!
1 parent a0ea689 commit 84cd9ad

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

src/independentlylinearizedutils.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ num_us(ils::IndependentlyLinearizedSolution) = length(ils.us)
199199
Base.size(ils::IndependentlyLinearizedSolution) = size(ils.time_mask)
200200
Base.length(ils::IndependentlyLinearizedSolution) = length(ils.ts)
201201

202-
function finish!(ils::IndependentlyLinearizedSolution)
202+
function finish!(ils::IndependentlyLinearizedSolution{T, S}, return_code) where {T,S}
203203
function trim_chunk(chunks::Vector, offset)
204204
chunks = [chunk for chunk in chunks]
205205
if eltype(chunks) <: AbstractVector
@@ -216,18 +216,25 @@ function finish!(ils::IndependentlyLinearizedSolution)
216216
end
217217

218218
ilsc = ils.ilsc::IndependentlyLinearizedSolutionChunks
219-
ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...)
220-
time_mask = hcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...)
221-
us = [hcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...)
222-
for u_idx in 1:length(ilsc.u_chunks)]
219+
if return_code == ReturnCode.InitialFailure
220+
# then no (consistent) data to put in, so just put in empty values
221+
ts = Vector{T}()
222+
us = Vector{Matrix{S}}()
223+
time_mask = BitMatrix(undef, 0, 0)
224+
else
225+
ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...)
226+
time_mask = hcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...)
227+
us = [hcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...)
228+
for u_idx in 1:length(ilsc.u_chunks)]
229+
end
223230

224231
# Sanity-check lengths
225232
if length(ts) != size(time_mask, 2)
226233
throw(ArgumentError("`length(ts)` must equal `size(time_mask, 2)`!"))
227234
end
228235

229236
# All time masks must start and end with `1`:
230-
if !all(@view time_mask[:, 1]) || !all(@view time_mask[:, end])
237+
if !isempty(time_mask) && (!all(@view time_mask[:, 1]) || !all(@view time_mask[:, end]))
231238
throw(ArgumentError("Time mask must start and end with 1s!"))
232239
end
233240

src/saving.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,10 +433,7 @@ function LinearizingSavingCallback(ils::IndependentlyLinearizedSolution{T,S};
433433
end,
434434
# We need to finalize the ils and free our caches
435435
finalize = (c, u, t, integ) -> begin
436-
# Don't run the `finish!` if ils is in an inconsistent state
437-
if check_error(integ) != ReturnCode.InitialFailure
438-
finish!(ils)
439-
end
436+
finish!(ils, check_error(integ))
440437
caches = nothing
441438
end,
442439
# Don't add tstops to the left and right.

test/saving_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ if VERSION >= v"1.9" # stack
226226
u0 = [1.0];
227227
du0 = [1.0];
228228
prob = DAEProblem(f_error2, u0, du0, (0.0, 1.0); differential_vars = [true])
229-
ils = IndependentlyLinearizedSolution(unstable_prob, 0)
229+
ils = IndependentlyLinearizedSolution(prob, 0)
230230
lsc = LinearizingSavingCallback(ils)
231231
sol = solve(prob, DFBDF(); callback = lsc) # this would if we were not failing with grace
232232
@test sol.retcode == ReturnCode.InitialFailure

0 commit comments

Comments
 (0)