|
1 |
| -function __solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;), |
| 1 | +function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;), |
2 | 2 | nlsolve_kwargs = (;), verbose = true, kwargs...)
|
3 | 3 | ig, T, N, _, u0 = __extract_problem_details(prob; dt = 0.1)
|
4 | 4 | _unwrap_val(ig) && verbose &&
|
5 | 5 | @warn "Initial guess provided, but will be ignored for Shooting!"
|
6 | 6 |
|
| 7 | + alg = concretize_jacobian_algorithm(alg_, prob) |
| 8 | + |
7 | 9 | bcresid_prototype, resid_size = __get_bcresid_prototype(prob, u0)
|
8 | 10 | iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0)
|
9 | 11 | resid_prototype = __vec(bcresid_prototype)
|
10 | 12 |
|
| 13 | + # Construct the residual function |
11 | 14 | ode_kwargs = (; kwargs..., verbose, odesolve_kwargs...)
|
| 15 | + internal_prob = ODEProblem{iip}(prob.f, u0, prob.tspan, prob.p) |
| 16 | + ode_cache_loss_fn = SciMLBase.__init(internal_prob, alg.ode_alg; ode_kwargs...) |
| 17 | + |
12 | 18 | loss_fn = if iip
|
13 |
| - (du, u, p) -> __single_shooting_loss!(du, u, p, prob.f, bc, u0_size, prob.tspan, |
14 |
| - prob.problem_type, resid_size, alg, ode_kwargs) |
| 19 | + (du, u, p) -> __single_shooting_loss!(du, u, p, ode_cache_loss_fn, bc, u0_size, |
| 20 | + prob.problem_type, resid_size) |
| 21 | + else |
| 22 | + (u, p) -> __single_shooting_loss(u, p, ode_cache_loss_fn, bc, u0_size, |
| 23 | + prob.problem_type) |
| 24 | + end |
| 25 | + |
| 26 | + # Construct the jacobian function |
| 27 | + # NOTE: We pass in a separate Jacobian Function because that allows us to cache the |
| 28 | + # the internal ode solve cache. This cache needs to be distinct from the regular |
| 29 | + # residual function cache |
| 30 | + loss_fnₚ = ifelse(iip, (du, u) -> loss_fn(du, u, prob.p), (u) -> loss_fn(u, prob.p)) |
| 31 | + |
| 32 | + # TODO: We probably won't be able to support Symbolics through ODE Solver but we should |
| 33 | + # be able to allow prespecified coloring. |
| 34 | + if alg.jac_alg.diffmode isa AbstractSparseADType |
| 35 | + error("Single Shooting doesn't support sparse AD yet!") |
| 36 | + end |
| 37 | + sd = NoSparsityDetection() |
| 38 | + y_ = similar(resid_prototype) |
| 39 | + |
| 40 | + jac_cache = if iip |
| 41 | + sparse_jacobian_cache(alg.jac_alg.diffmode, sd, loss_fnₚ, y_, vec(u0)) |
| 42 | + else |
| 43 | + sparse_jacobian_cache(alg.jac_alg.diffmode, sd, loss_fnₚ, vec(u0); fx = y_) |
| 44 | + end |
| 45 | + |
| 46 | + ode_cache_jac_fn = __single_shooting_jacobian_ode_cache(internal_prob, jac_cache, |
| 47 | + alg.jac_alg.diffmode, vec(u0), alg.ode_alg; ode_kwargs...) |
| 48 | + |
| 49 | + jac_prototype = init_jacobian(jac_cache) |
| 50 | + |
| 51 | + loss_fn2ₚ = if iip |
| 52 | + (du, u) -> __single_shooting_loss!(du, u, prob.p, ode_cache_jac_fn, bc, u0_size, |
| 53 | + prob.problem_type, resid_size) |
| 54 | + else |
| 55 | + (u) -> __single_shooting_loss(u, prob.p, ode_cache_jac_fn, bc, u0_size, |
| 56 | + prob.problem_type) |
| 57 | + end |
| 58 | + |
| 59 | + jac_fn = if iip |
| 60 | + (J, u, p) -> __single_shooting_jacobian!(J, u, jac_cache, alg.jac_alg.diffmode, |
| 61 | + loss_fn2ₚ, y_) |
15 | 62 | else
|
16 |
| - (u, p) -> __single_shooting_loss(u, p, prob.f, bc, u0_size, prob.tspan, |
17 |
| - prob.problem_type, alg, ode_kwargs) |
| 63 | + (u, p) -> __single_shooting_jacobian(jac_prototype, u, jac_cache, |
| 64 | + alg.jac_alg.diffmode, loss_fn2ₚ) |
18 | 65 | end
|
19 | 66 |
|
20 |
| - nlf = NonlinearFunction{iip}(loss_fn; prob.f.jac_prototype, resid_prototype) |
| 67 | + nlf = NonlinearFunction{iip}(loss_fn; prob.f.jac_prototype, resid_prototype, |
| 68 | + jac = jac_fn) |
21 | 69 | nlprob = if length(resid_prototype) == length(u0)
|
22 | 70 | NonlinearProblem(nlf, vec(u0), prob.p)
|
23 | 71 | else
|
24 | 72 | NonlinearLeastSquaresProblem(nlf, vec(u0), prob.p)
|
25 | 73 | end
|
26 | 74 | opt = __solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...)
|
27 | 75 |
|
28 |
| - newprob = ODEProblem{iip}(prob.f, reshape(opt.u, u0_size), prob.tspan, prob.p) |
29 |
| - sol = __solve(newprob, alg.ode_alg; odesolve_kwargs..., verbose, kwargs...) |
| 76 | + SciMLBase.reinit!(ode_cache_loss_fn, reshape(opt.u, u0_size)) |
| 77 | + sol = solve!(ode_cache_loss_fn) |
30 | 78 |
|
31 | 79 | !SciMLBase.successful_retcode(opt) &&
|
32 | 80 | return SciMLBase.solution_new_retcode(sol, ReturnCode.Failure)
|
33 | 81 | return sol
|
34 | 82 | end
|
35 | 83 |
|
36 |
| -function __single_shooting_loss!(resid_, u0_, p, f::F, bc::BC, u0_size, tspan, |
37 |
| - pt::TwoPointBVProblem, (resida_size, residb_size), alg::Shooting, |
38 |
| - kwargs) where {F, BC} |
| 84 | +function __single_shooting_loss!(resid_, u0_, p, cache, bc::BC, u0_size, |
| 85 | + pt::TwoPointBVProblem, (resida_size, residb_size)) where {BC} |
39 | 86 | resida = @view resid_[1:prod(resida_size)]
|
40 | 87 | residb = @view resid_[(prod(resida_size) + 1):end]
|
41 | 88 | resid = (reshape(resida, resida_size), reshape(residb, residb_size))
|
42 | 89 |
|
43 |
| - odeprob = ODEProblem{true}(f, reshape(u0_, u0_size), tspan, p) |
44 |
| - odesol = __solve(odeprob, alg.ode_alg; kwargs...) |
| 90 | + SciMLBase.reinit!(cache, reshape(u0_, u0_size)) |
| 91 | + odesol = solve!(cache) |
| 92 | + |
45 | 93 | eval_bc_residual!(resid, pt, bc, odesol, p)
|
46 | 94 |
|
47 | 95 | return nothing
|
48 | 96 | end
|
49 | 97 |
|
50 |
| -function __single_shooting_loss!(resid_, u0_, p, f::F, bc::BC, u0_size, tspan, |
51 |
| - pt::StandardBVProblem, resid_size, alg::Shooting, kwargs) where {F, BC} |
| 98 | +function __single_shooting_loss!(resid_, u0_, p, cache, bc::BC, u0_size, |
| 99 | + pt::StandardBVProblem, resid_size) where {BC} |
52 | 100 | resid = reshape(resid_, resid_size)
|
53 | 101 |
|
54 |
| - odeprob = ODEProblem{true}(f, reshape(u0_, u0_size), tspan, p) |
55 |
| - odesol = __solve(odeprob, alg.ode_alg; kwargs...) |
| 102 | + SciMLBase.reinit!(cache, reshape(u0_, u0_size)) |
| 103 | + odesol = solve!(cache) |
| 104 | + |
56 | 105 | eval_bc_residual!(resid, pt, bc, odesol, p)
|
57 | 106 |
|
58 | 107 | return nothing
|
59 | 108 | end
|
60 | 109 |
|
61 |
| -function __single_shooting_loss(u0_, p, f::F, bc::BC, u0_size, tspan, pt, alg::Shooting, |
62 |
| - kwargs) where {F, BC} |
63 |
| - odeprob = ODEProblem{false}(f, reshape(u0_, u0_size), tspan, p) |
64 |
| - odesol = __solve(odeprob, alg.ode_alg; kwargs...) |
| 110 | +function __single_shooting_loss(u, p, cache, bc::BC, u0_size, pt) where {BC} |
| 111 | + SciMLBase.reinit!(cache, reshape(u, u0_size)) |
| 112 | + odesol = solve!(cache) |
65 | 113 | return __safe_vec(eval_bc_residual(pt, bc, odesol, p))
|
66 | 114 | end
|
| 115 | + |
| 116 | +function __single_shooting_jacobian!(J, u, jac_cache, diffmode, loss_fn::L, fu) where {L} |
| 117 | + sparse_jacobian!(J, diffmode, jac_cache, loss_fn, fu, vec(u)) |
| 118 | + return J |
| 119 | +end |
| 120 | + |
| 121 | +function __single_shooting_jacobian(J, u, jac_cache, diffmode, loss_fn::L) where {L} |
| 122 | + sparse_jacobian!(J, diffmode, jac_cache, loss_fn, vec(u)) |
| 123 | + return J |
| 124 | +end |
| 125 | + |
| 126 | +function __single_shooting_jacobian_ode_cache(prob, jac_cache, alg, u0, ode_alg; kwargs...) |
| 127 | + prob_ = remake(prob; u0) |
| 128 | + return SciMLBase.__init(prob_, ode_alg; kwargs...) |
| 129 | +end |
| 130 | + |
| 131 | +function __single_shooting_jacobian_ode_cache(prob, jac_cache, |
| 132 | + ::Union{AutoForwardDiff, AutoSparseForwardDiff}, u0, ode_alg; kwargs...) |
| 133 | + cache = jac_cache.cache |
| 134 | + if cache isa ForwardDiff.JacobianConfig |
| 135 | + xduals = cache.duals isa Tuple ? cache.duals[2] : cache.duals |
| 136 | + prob_ = remake(prob; u0 = xduals) |
| 137 | + return SciMLBase.__init(prob_, ode_alg; kwargs...) |
| 138 | + else |
| 139 | + error("Single Shooting doesn't support sparse AD yet!") |
| 140 | + end |
| 141 | +end |
0 commit comments