Skip to content

Commit d5db8a1

Browse files
committed
Propagate multiple shooting jac_alg
1 parent 710ce0b commit d5db8a1

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

src/solve/multiple_shooting.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,17 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
6666
end
6767
end
6868

69+
if prob.problem_type isa TwoPointBVProblem
70+
diffmode_shooting = __get_non_sparse_ad(alg.jac_alg.diffmode)
71+
else
72+
diffmode_shooting = __get_non_sparse_ad(alg.jac_alg.bc_diffmode)
73+
end
74+
shooting_alg = Shooting(alg.ode_alg, alg.nlsolve,
75+
BVPJacobianAlgorithm(diffmode_shooting))
76+
6977
single_shooting_prob = remake(prob; u0 = reshape(u_at_nodes[1:N], u0_size))
70-
return __solve(single_shooting_prob, Shooting(alg.ode_alg; alg.nlsolve);
71-
odesolve_kwargs, nlsolve_kwargs, verbose, kwargs...)
78+
return __solve(single_shooting_prob, shooting_alg; odesolve_kwargs, nlsolve_kwargs,
79+
verbose, kwargs...)
7280
end
7381

7482
# TODO: We can save even more memory by hoisting the preallocated caches for the ODEs
@@ -150,10 +158,8 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
150158
internal_ode_kwargs...)
151159

152160
# BC Part
153-
if alg.jac_alg.bc_diffmode isa AbstractSparseADType
154-
error("Multiple Shooting doesn't support sparse AD for Boundary Conditions yet!")
155-
end
156-
sd_bc = NoSparsityDetection()
161+
sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ?
162+
SymbolicsSparsityDetection() : NoSparsityDetection()
157163
bc_jac_cache = sparse_jacobian_cache(alg.jac_alg.bc_diffmode,
158164
sd_bc, nothing, similar(bcresid_prototype), u_at_nodes)
159165
ode_cache_bc_jac_fn = __multiple_shooting_init_jacobian_odecache(ensemblealg, prob,

src/utils.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,25 @@ end
203203
__vec_bc(sol, p, t, bc, u_size) = vec(bc(__restructure_sol(sol, u_size), p, t))
204204
__vec_bc(sol, p, bc, u_size) = vec(bc(reshape(sol, u_size), p))
205205

206+
__get_non_sparse_ad(ad::AbstractADType) = ad
207+
function __get_non_sparse_ad(ad::AbstractSparseADType)
208+
if ad isa AutoSparseForwardDiff
209+
return AutoForwardDiff{__get_chunksize(ad), typeof(ad.tag)}(ad.tag)
210+
elseif ad isa AutoSparseEnzyme
211+
return AutoEnzyme()
212+
elseif ad isa AutoSparseFiniteDiff
213+
return AutoFiniteDiff()
214+
elseif ad isa AutoSparseReverseDiff
215+
return AutoReverseDiff(ad.compile)
216+
elseif ad isa AutoSparseZygote
217+
return AutoZygote()
218+
else
219+
throw(ArgumentError("Unknown AD Type"))
220+
end
221+
end
222+
223+
__get_chunksize(::AutoSparseForwardDiff{CK}) where {CK} = CK
224+
206225
# Restructure Solution
207226
function __restructure_sol(sol::Vector{<:AbstractArray}, u_size)
208227
return map(Base.Fix2(reshape, u_size), sol)

0 commit comments

Comments
 (0)