Skip to content

Commit 91bac8f

Browse files
Merge pull request #200 from oxinabox/ox/do_fail_with_grace
Fail with grace with the ILS after initialization failure
2 parents 63efb70 + eff1abf commit 91bac8f

File tree

4 files changed

+34
-11
lines changed

4 files changed

+34
-11
lines changed

src/independentlylinearizedutils.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ num_us(ils::IndependentlyLinearizedSolution) = length(ils.us)
199199
Base.size(ils::IndependentlyLinearizedSolution) = size(ils.time_mask)
200200
Base.length(ils::IndependentlyLinearizedSolution) = length(ils.ts)
201201

202-
function finish!(ils::IndependentlyLinearizedSolution)
202+
function finish!(ils::IndependentlyLinearizedSolution{T, S}, return_code) where {T,S}
203203
function trim_chunk(chunks::Vector, offset)
204204
chunks = [chunk for chunk in chunks]
205205
if eltype(chunks) <: AbstractVector
@@ -216,18 +216,25 @@ function finish!(ils::IndependentlyLinearizedSolution)
216216
end
217217

218218
ilsc = ils.ilsc::IndependentlyLinearizedSolutionChunks
219-
ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...)
220-
time_mask = hcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...)
221-
us = [hcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...)
222-
for u_idx in 1:length(ilsc.u_chunks)]
219+
if return_code == ReturnCode.InitialFailure
220+
# then no (consistent) data to put in, so just put in empty values
221+
ts = Vector{T}()
222+
us = Vector{Matrix{S}}()
223+
time_mask = BitMatrix(undef, 0, 0)
224+
else
225+
ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...)
226+
time_mask = hcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...)
227+
us = [hcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...)
228+
for u_idx in 1:length(ilsc.u_chunks)]
229+
end
223230

224231
# Sanity-check lengths
225232
if length(ts) != size(time_mask, 2)
226233
throw(ArgumentError("`length(ts)` must equal `size(time_mask, 2)`!"))
227234
end
228235

229236
# All time masks must start and end with `1`:
230-
if !all(@view time_mask[:, 1]) || !all(@view time_mask[:, end])
237+
if !isempty(time_mask) && (!all(@view time_mask[:, 1]) || !all(@view time_mask[:, end]))
231238
throw(ArgumentError("Time mask must start and end with 1s!"))
232239
end
233240

src/saving.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ function LinearizingSavingCallback(ils::IndependentlyLinearizedSolution{T,S};
433433
end,
434434
# We need to finalize the ils and free our caches
435435
finalize = (c, u, t, integ) -> begin
436-
finish!(ils)
436+
finish!(ils, check_error(integ))
437437
caches = nothing
438438
end,
439439
# Don't add tstops to the left and right.

test/independentlylinearizedtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,20 +99,20 @@ display(@benchmark sample(ils, many_ts))
9999
# Muck with the `ilsc` here to explore some failure options
100100
# Error: timepoints longer than time_matrix
101101
ilsc.t_offset += 1
102-
@test_throws ArgumentError finish!(ils)
102+
@test_throws ArgumentError finish!(ils, ReturnCode.Default)
103103
ilsc.t_offset -= 1
104104

105105
# Error: time matrix row two has N elements, but `us` has N+1
106106
ilsc.u_offsets[2] += 1
107-
@test_throws ArgumentError finish!(ils)
107+
@test_throws ArgumentError finish!(ils, ReturnCode.Default)
108108
ilsc.u_offsets[2] -= 1
109109

110110
# Error: one of our time masks doesn't start with `1`
111111
ilsc.time_masks[1][1, 1] = 0
112-
@test_throws ArgumentError finish!(ils)
112+
@test_throws ArgumentError finish!(ils, ReturnCode.Default)
113113
ilsc.time_masks[1][1, 1] = 1
114114

115-
finish!(ils)
115+
finish!(ils, ReturnCode.Default)
116116
@test sample(ils, ils.ts) == repeat(1:num_timepoints, 1, num_us)
117117
@test sample(ils, ils.ts, 1) == 2*repeat(1:num_timepoints, 1, num_us)
118118
end

test/saving_tests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,22 @@ if VERSION >= v"1.9" # stack
220220
end
221221
end
222222

223+
224+
@testset "fail gracefully" begin
225+
f_error2(du, u, p, t) = du .= u ./ t .- 1
226+
u0 = [1.0];
227+
du0 = [1.0];
228+
prob = DAEProblem(f_error2, u0, du0, (0.0, 1.0); differential_vars = [true])
229+
ils = IndependentlyLinearizedSolution(prob, 0)
230+
lsc = LinearizingSavingCallback(ils)
231+
sol = solve(prob, DFBDF(); callback = lsc) # this would if we were not failing with grace
232+
@test sol.retcode == ReturnCode.InitialFailure
233+
@test isempty(ils.ts)
234+
@test isempty(ils.us)
235+
@test isempty(ils.time_mask)
236+
end
237+
238+
223239
# We do not support 2d states yet.
224240
#test_linearization(prob_ode_2Dlinear, Tsit5())
225241
end

0 commit comments

Comments
 (0)