Skip to content

Commit a160f55

Browse files
fix: HACK: run parameter initialization in null solutions, handle InitialFailure
1 parent c43eaf8 commit a160f55

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

src/solve.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,18 @@ function step!(integ::NullODEIntegrator, dt = nothing, stop_at_tdt = false)
676676
return nothing
677677
end
678678

679+
function hack_null_solution_init(prob)
680+
if SciMLBase.has_initializeprob(prob.f) && SciMLBase.has_initializeprobpmap(prob.f)
681+
initializeprob = prob.f.initializeprob
682+
nlsol = solve(initializeprob)
683+
success = SciMLBase.successful_retcode(nlsol)
684+
@set! prob.p = prob.f.initializeprobpmap(prob, nlsol)
685+
else
686+
success = true
687+
end
688+
return prob, success
689+
end
690+
679691
function build_null_solution(prob::AbstractDEProblem, args...;
680692
saveat = (),
681693
save_everystep = true,
@@ -702,12 +714,9 @@ function build_null_solution(prob::AbstractDEProblem, args...;
702714

703715
timeseries = [Float64[] for i in 1:length(ts)]
704716

705-
if SciMLBase.has_initializeprob(prob.f) && SciMLBase.has_initializeprobpmap(prob.f)
706-
initializeprob = prob.f.initializeprob
707-
nlsol = solve(initializeprob)
708-
@set! prob.p = prob.f.initializeprobpmap(prob, nlsol)
709-
end
710-
build_solution(prob, nothing, ts, timeseries, retcode = ReturnCode.Success)
717+
prob, success = hack_null_solution_init(prob)
718+
retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure
719+
build_solution(prob, nothing, ts, timeseries, retcode)
711720
end
712721

713722
function build_null_solution(
@@ -720,21 +729,28 @@ function build_null_solution(
720729
saveat isa Number || prob.tspan[1] in saveat,
721730
save_end = true,
722731
kwargs...)
732+
prob, success = hack_null_solution_init(prob)
733+
retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure
723734
SciMLBase.build_solution(prob, nothing, Float64[], nothing;
724-
retcode = ReturnCode.Success)
735+
retcode)
725736
end
726737

727738
function build_null_solution(
728739
prob::NonlinearLeastSquaresProblem,
729740
args...; abstol = 1e-6, kwargs...)
741+
prob, success = hack_null_solution_init(prob)
742+
retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure
743+
730744
if isinplace(prob)
731745
resid = isnothing(prob.f.resid_prototype) ? Float64[] : copy(prob.f.resid_prototype)
732746
prob.f(resid, prob.u0, prob.p)
733747
else
734748
resid = prob.f(prob.u0, prob.p)
735749
end
736750

737-
retcode = norm(resid) < abstol ? ReturnCode.Success : ReturnCode.Failure
751+
if success
752+
retcode = norm(resid) < abstol ? ReturnCode.Success : ReturnCode.Failure
753+
end
738754

739755
SciMLBase.build_solution(prob, nothing, Float64[], resid;
740756
retcode)

0 commit comments

Comments
 (0)