Skip to content

Commit 60536ee

Browse files
committed
Fix types
1 parent 19ce0c6 commit 60536ee

File tree

2 files changed

+15
-21
lines changed

2 files changed

+15
-21
lines changed

src/solve/mirk.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {
212212
end
213213

214214
function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh,
215-
cache) where {BC <: Function}
215+
cache) where {BC}
216216
y_ = recursive_unflatten!(y, u)
217217
resids = [get_tmp(r, u) for r in residual]
218218
eval_bc_residual!(resids[1], pt, bc!, y_, p, mesh)
@@ -222,7 +222,7 @@ function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual,
222222
end
223223

224224
function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual,
225-
mesh, cache) where {BC1 <: Function, BC2 <: Function}
225+
mesh, cache) where {BC1, BC2}
226226
y_ = recursive_unflatten!(y, u)
227227
resids = [get_tmp(r, u) for r in residual]
228228
resida = @view resids[1][1:prod(cache.resid_size[1])]
@@ -233,29 +233,28 @@ function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2
233233
return nothing
234234
end
235235

236-
function __mirk_loss(u, p, y, pt::StandardBVProblem, bc::BC, mesh,
237-
cache) where {BC <: Function}
236+
function __mirk_loss(u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache) where {BC}
238237
y_ = recursive_unflatten!(y, u)
239238
resid_bc = eval_bc_residual(pt, bc, y_, p, mesh)
240239
resid_co = Φ(cache, y_, u, p)
241240
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
242241
end
243242

244243
function __mirk_loss(u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2}, mesh,
245-
cache) where {BC1 <: Function, BC2 <: Function}
244+
cache) where {BC1, BC2}
246245
y_ = recursive_unflatten!(y, u)
247246
resid_bca, resid_bcb = eval_bc_residual(pt, bc, y_, p, mesh)
248247
resid_co = Φ(cache, y_, u, p)
249248
return vcat(resid_bca, mapreduce(vec, vcat, resid_co), resid_bcb)
250249
end
251250

252-
function __mirk_loss_bc!(resid, u, p, pt, bc!::BC, y, mesh) where {BC <: Function}
251+
function __mirk_loss_bc!(resid, u, p, pt, bc!::BC, y, mesh) where {BC}
253252
y_ = recursive_unflatten!(y, u)
254253
eval_bc_residual!(resid, pt, bc!, y_, p, mesh)
255254
return nothing
256255
end
257256

258-
function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh) where {BC <: Function}
257+
function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh) where {BC}
259258
y_ = recursive_unflatten!(y, u)
260259
return eval_bc_residual(pt, bc!, y_, p, mesh)
261260
end
@@ -275,7 +274,7 @@ function __mirk_loss_collocation(u, p, y, mesh, residual, cache)
275274
end
276275

277276
function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C,
278-
loss::L, ::StandardBVProblem) where {iip, BC <: Function, C <: Function, L <: Function}
277+
loss::L, ::StandardBVProblem) where {iip, BC, C, L}
279278
@unpack nlsolve, jac_alg = cache.alg
280279
N = length(cache.mesh)
281280

@@ -317,24 +316,23 @@ end
317316

318317
function __mirk_mpoint_jacobian!(J, x, p, bc_diffmode, nonbc_diffmode, bc_diffcache,
319318
nonbc_diffcache, loss_bc::BC, loss_collocation::C, resid_bc, resid_collocation,
320-
M::Int) where {BC <: Function, C <: Function}
319+
M::Int) where {BC, C}
321320
sparse_jacobian!(@view(J[1:M, :]), bc_diffmode, bc_diffcache, loss_bc, resid_bc, x)
322321
sparse_jacobian!(@view(J[(M + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
323322
loss_collocation, resid_collocation, x)
324323
return nothing
325324
end
326325

327326
function __mirk_mpoint_jacobian(x, p, J, bc_diffmode, nonbc_diffmode, bc_diffcache,
328-
nonbc_diffcache, loss_bc::BC, loss_collocation::C,
329-
M::Int) where {BC <: Function, C <: Function}
327+
nonbc_diffcache, loss_bc::BC, loss_collocation::C, M::Int) where {BC, C}
330328
sparse_jacobian!(@view(J[1:M, :]), bc_diffmode, bc_diffcache, loss_bc, x)
331329
sparse_jacobian!(@view(J[(M + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
332330
loss_collocation, x)
333331
return J
334332
end
335333

336334
function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C,
337-
loss::L, ::TwoPointBVProblem) where {iip, BC <: Function, C <: Function, L <: Function}
335+
loss::L, ::TwoPointBVProblem) where {iip, BC, C, L}
338336
@unpack nlsolve, jac_alg = cache.alg
339337
N = length(cache.mesh)
340338

@@ -366,14 +364,12 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
366364
return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p)
367365
end
368366

369-
function __mirk_2point_jacobian!(J, x, p, diffmode, diffcache, loss_fn::L,
370-
resid) where {L <: Function}
367+
function __mirk_2point_jacobian!(J, x, p, diffmode, diffcache, loss_fn::L, resid) where {L}
371368
sparse_jacobian!(J, diffmode, diffcache, loss_fn, resid, x)
372369
return J
373370
end
374371

375-
function __mirk_2point_jacobian(x, p, J, diffmode, diffcache,
376-
loss_fn::L) where {L <: Function}
372+
function __mirk_2point_jacobian(x, p, J, diffmode, diffcache, loss_fn::L) where {L}
377373
sparse_jacobian!(J, diffmode, diffcache, loss_fn, x)
378374
return J
379375
end

src/solve/single_shooting.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ function __solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;),
2929
end
3030

3131
function __single_shooting_loss!(resid_, u0_, p, f::F, bc::BC, u0_size, tspan,
32-
pt::TwoPointBVProblem, (resida_size, residb_size), alg::Shooting,
33-
kwargs) where {F <: Function, BC <: Function}
32+
pt::TwoPointBVProblem, (resida_size, residb_size), alg::Shooting, kwargs) where {F, BC}
3433
resida = @view resid_[1:prod(resida_size)]
3534
residb = @view resid_[(prod(resida_size) + 1):end]
3635
resid = (reshape(resida, resida_size), reshape(residb, residb_size))
@@ -43,8 +42,7 @@ function __single_shooting_loss!(resid_, u0_, p, f::F, bc::BC, u0_size, tspan,
4342
end
4443

4544
function __single_shooting_loss!(resid_, u0_, p, f::F, bc::BC, u0_size, tspan,
46-
pt::StandardBVProblem, resid_size, alg::Shooting,
47-
kwargs) where {F <: Function, BC <: Function}
45+
pt::StandardBVProblem, resid_size, alg::Shooting, kwargs) where {F, BC}
4846
resid = reshape(resid_, resid_size)
4947

5048
odeprob = ODEProblem{true}(f, reshape(u0_, u0_size), tspan, p)
@@ -55,7 +53,7 @@ function __single_shooting_loss!(resid_, u0_, p, f::F, bc::BC, u0_size, tspan,
5553
end
5654

5755
function __single_shooting_loss(u0_, p, f::F, bc::BC, u0_size, tspan, pt, alg::Shooting,
58-
kwargs) where {F <: Function, BC <: Function}
56+
kwargs) where {F, BC}
5957
odeprob = ODEProblem{false}(f, reshape(u0_, u0_size), tspan, p)
6058
odesol = __solve(odeprob, alg.ode_alg; kwargs...)
6159
return __safe_vec(eval_bc_residual(pt, bc, odesol, p))

0 commit comments

Comments
 (0)