Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ FastClosures = "0.3.2"
ForwardDiff = "0.10.38, 1"
Hwloc = "3.3"
InteractiveUtils = "<0.0.1, 1"
JET = "0.9.18"
JET = "0.9.18, 0.11.0"
LinearAlgebra = "1.10"
LinearSolve = "3.12"
NonlinearSolveFirstOrder = "1.3"
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqAscher/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ FastClosures = "0.3.2"
ForwardDiff = "0.10.38, 1"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
JET = "0.9.18"
JET = "0.9.18, 0.11.0"
LinearAlgebra = "1.10"
PreallocationTools = "0.4.24"
Random = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ConcreteStructs = "0.2.3"
DiffEqBase = "6.167"
ForwardDiff = "0.10.38, 1"
InteractiveUtils = "<0.0.1, 1"
JET = "0.9.18"
JET = "0.9.18, 0.11.0"
LineSearch = "0.1.4"
LinearAlgebra = "1.10"
Logging = "1.10"
Expand Down
2 changes: 2 additions & 0 deletions lib/BoundaryValueDiffEqCore/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ end
coloring_algorithm = GreedyColoringAlgorithm())
end

@inline __default_coloring_algorithm(_) = GreedyColoringAlgorithm()
@inline __default_coloring_algorithm(diffmode::AutoSparse) = isnothing(diffmode) ?
GreedyColoringAlgorithm() :
diffmode.coloring_algorithm
@inline __default_sparsity_detector(_) = TracerLocalSparsityDetector()
@inline __default_sparsity_detector(diffmode::AutoSparse) = isnothing(diffmode) ?
TracerLocalSparsityDetector() :
diffmode.sparsity_detector
Expand Down
113 changes: 98 additions & 15 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ end
Constructs the internal problem based on the type of the boundary value problem and the
algorithm used. It returns either a `NonlinearProblem` or an `OptimizationProblem`.
"""
function __construct_internal_problem(prob::AbstractBVProblem, alg, loss, jac,
function __construct_internal_problem(prob, pt::StandardBVProblem, alg, loss, jac,
jac_prototype, resid_prototype, y, p, M::Int, N::Int)
T = eltype(y)
iip = SciMLBase.isinplace(prob)
Expand All @@ -658,16 +658,19 @@ function __construct_internal_problem(prob::AbstractBVProblem, alg, loss, jac,
jac_prototype = jac_prototype)
return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p)
else
optf = OptimizationFunction{true}(__default_cost(prob.f), AutoFiniteDiff(), # Need to investigate the ForwardDiff dual problem
optf = OptimizationFunction{true}(__default_cost(prob.f),
AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode),
sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)),
cons = loss,
cons_j = jac, cons_jac_prototype = jac_prototype)
cons_j = jac,
cons_jac_prototype = jac_prototype)
lcons, ucons = __extract_lcons_ucons(prob, T, M, N)
return __internal_optimization_problem(
prob, optf, y, p; lcons = lcons, ucons = ucons)
end
end

function __construct_internal_problem(prob::TwoPointBVProblem, alg, loss, jac,
function __construct_internal_problem(prob, pt::TwoPointBVProblem, alg, loss, jac,
jac_prototype, resid_prototype, y, p, M::Int, N::Int)
T = eltype(y)
iip = SciMLBase.isinplace(prob)
Expand All @@ -676,46 +679,126 @@ function __construct_internal_problem(prob::TwoPointBVProblem, alg, loss, jac,
jac_prototype = jac_prototype)
return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p)
else
optf = OptimizationFunction{true}(
__default_cost(prob.f), get_dense_ad(alg.jac_alg.diffmode),
cons = loss, cons_j = jac, cons_jac_prototype = jac_prototype)
optf = OptimizationFunction{true}(__default_cost(prob.f),
AutoSparse(get_dense_ad(alg.jac_alg.diffmode),
sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)),
cons = loss,
cons_j = jac,
cons_jac_prototype = jac_prototype)
lcons, ucons = __extract_lcons_ucons(prob, T, M, N)

return __internal_optimization_problem(
prob, optf, y, p; lcons = lcons, ucons = ucons)
end
end
# Multiple shooting always use inplace version internal problem constructor

# Single shooting use diffmode for StandardBVProblem and TwoPointBVProblem
function __construct_internal_problem(prob, alg, loss, jac, jac_prototype,
resid_prototype, y, p, M::Int, N::Int, ::Nothing)
T = eltype(y)
iip = SciMLBase.isinplace(prob)
if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize))
nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype,
jac_prototype = jac_prototype)
return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p)
else
optf = OptimizationFunction{iip}(__default_cost(prob.f),
AutoSparse(get_dense_ad(alg.jac_alg.diffmode),
sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)),
cons = loss,
cons_j = jac,
cons_jac_prototype = jac_prototype)
lcons, ucons = __extract_lcons_ucons(prob, T, M, N)

return __internal_optimization_problem(
prob, optf, y, p; lcons = lcons, ucons = ucons)
end
end

# Multiple shooting always use inplace version internal problem constructor
function __construct_internal_problem(
prob, pt::StandardBVProblem, alg, loss, jac, jac_prototype,
resid_prototype, y, p, M::Int, N::Int, ::Nothing)
T = eltype(y)
if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize))
nlf = NonlinearFunction{true}(loss; jac = jac, resid_prototype = resid_prototype,
jac_prototype = jac_prototype)
return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p)
else
optf = OptimizationFunction{true}(
__default_cost(prob.f), get_dense_ad(alg.jac_alg.diffmode),
cons = loss, cons_j = jac, cons_jac_prototype = jac_prototype)
optf = OptimizationFunction{true}(__default_cost(prob.f),
AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode),
sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)),
cons = loss,
cons_j = jac,
cons_jac_prototype = jac_prototype)
lcons, ucons = __extract_lcons_ucons(prob, T, M, N)

return __internal_optimization_problem(
prob, optf, y, p; lcons = lcons, ucons = ucons)
end
end
function __construct_internal_problem(
prob::TwoPointBVProblem, alg, loss, jac, jac_prototype,
prob, pt::TwoPointBVProblem, alg, loss, jac, jac_prototype,
resid_prototype, y, p, M::Int, N::Int, ::Nothing)
T = eltype(y)
if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize))
nlf = NonlinearFunction{true}(loss; jac = jac, resid_prototype = resid_prototype,
jac_prototype = jac_prototype)
return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p)
else
optf = OptimizationFunction{true}(__default_cost(prob.f),
AutoSparse(get_dense_ad(alg.jac_alg.diffmode),
sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)),
cons = loss,
cons_j = jac,
cons_jac_prototype = jac_prototype)
lcons, ucons = __extract_lcons_ucons(prob, T, M, N)

return __internal_optimization_problem(
prob, optf, y, p; lcons = lcons, ucons = ucons)
end
end

# Second order BVProblem
function __construct_internal_problem(
prob, pt::StandardSecondOrderBVProblem, alg, loss, jac,
jac_prototype, resid_prototype, y, p, M::Int, N::Int)
T = eltype(y)
iip = SciMLBase.isinplace(prob)
if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize))
nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype,
jac_prototype = jac_prototype)
return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p)
else
optf = OptimizationFunction{true}(
__default_cost(prob.f), get_dense_ad(alg.jac_alg.nonbc_diffmode),
cons = loss, cons_j = jac, cons_jac_prototype = jac_prototype)
optf = OptimizationFunction{iip}(__default_cost(prob.f.f),
AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode),
sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)),
cons = loss,
cons_j = jac,
cons_jac_prototype = jac_prototype)
lcons, ucons = __extract_lcons_ucons(prob, T, M, N)
return __internal_optimization_problem(
prob, optf, y, p; lcons = lcons, ucons = ucons)
end
end

# Two point BVProblem
function __construct_internal_problem(
prob, pt::TwoPointSecondOrderBVProblem, alg, loss, jac,
jac_prototype, resid_prototype, y, p, M::Int, N::Int)
T = eltype(y)
iip = SciMLBase.isinplace(prob)
if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize))
nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype,
jac_prototype = jac_prototype)
return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p)
else
optf = OptimizationFunction{iip}(__default_cost(prob.f.f),
AutoSparse(get_dense_ad(alg.jac_alg.diffmode),
sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)),
cons = loss,
cons_j = jac,
cons_jac_prototype = jac_prototype)
lcons, ucons = __extract_lcons_ucons(prob, T, M, N)

return __internal_optimization_problem(
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ FastClosures = "0.3.2"
ForwardDiff = "0.10.38, 1"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
JET = "0.9.18"
JET = "0.9.18, 0.11.0"
LinearAlgebra = "1.10"
LinearSolve = "2.36.2, 3"
Mooncake = "0.4.146"
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm,
__construct_internal_problem, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __split_kwargs, _sparse_like,
get_dense_ad, __internal_optimization_problem
get_dense_ad, __internal_optimization_problem,
__internal_solve

using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
Expand Down
42 changes: 35 additions & 7 deletions lib/BoundaryValueDiffEqFIRK/src/firk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ function __perform_firk_iteration(cache::Union{FIRKCacheExpand, FIRKCacheNested}
solve_alg = __concrete_solve_algorithm(nlprob, cache.alg.nlsolve, cache.alg.optimize)
kwargs = __concrete_kwargs(
cache.alg.nlsolve, cache.alg.optimize, cache.nlsolve_kwargs, cache.optimize_kwargs)
sol_nlprob = __solve(nlprob, solve_alg; kwargs...)
sol_nlprob = __internal_solve(nlprob, solve_alg; kwargs...)
recursive_unflatten!(cache.y₀, sol_nlprob.u)

defect_norm = 2 * abstol
Expand Down Expand Up @@ -494,6 +494,13 @@ function __construct_problem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{
u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol, trait)
end

if !isnothing(cache.alg.optimize)
loss = @closure (du,
u,
p) -> __firk_loss!(
du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, trait)
end

return __construct_problem(cache, y, loss_bc, loss_collocation, loss, pt)
end

Expand Down Expand Up @@ -578,7 +585,8 @@ function __construct_problem(
end

resid_prototype = vcat(resid_bc, resid_collocation)
return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype,
return __construct_internal_problem(
cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype,
resid_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1)
end

Expand Down Expand Up @@ -637,7 +645,8 @@ function __construct_problem(
end

resid_prototype = copy(resid)
return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype,
return __construct_internal_problem(
cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype,
resid_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1)
end

Expand Down Expand Up @@ -716,8 +725,9 @@ function __construct_problem(
end

resid_prototype = vcat(resid_bc, resid_collocation)
return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype,
resid_prototype, y, cache.p, cache.M, N)
return __construct_internal_problem(
cache.prob, cache.problem_type, cache.alg, loss, jac,
jac_prototype, resid_prototype, y, cache.p, cache.M, N)
end

function __construct_problem(
Expand Down Expand Up @@ -765,8 +775,9 @@ function __construct_problem(
end

resid_prototype = copy(resid)
return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype,
resid_prototype, y, cache.p, cache.M, N)
return __construct_internal_problem(
cache.prob, cache.problem_type, cache.alg, loss, jac,
jac_prototype, resid_prototype, y, cache.p, cache.M, N)
end

@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual,
Expand All @@ -791,6 +802,16 @@ end
return nothing
end

# loss function for optimization based solvers
@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC,
residual, mesh, cache, trait) where {BC}
bcresid = length(cache.bcresid_prototype)
__firk_loss_bc!(resid[1:bcresid], u, p, pt, bc!, y, mesh, cache, trait)
__firk_loss_collocation!(
resid[(bcresid + 1):end], u, p, y, mesh, residual, cache, trait)
return nothing
end

@views function __firk_loss!(
resid, u, p, y::AbstractVector, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2},
residual, mesh, cache, _, trait::DiffCacheNeeded) where {BC1, BC2}
Expand All @@ -816,6 +837,13 @@ end
return nothing
end

# loss function for optimization based solvers
@views function __firk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2},
residual, mesh, cache, trait) where {BC1, BC2}
__firk_loss!(resid, u, p, y, pt, bc!, residual, mesh, cache, nothing, trait)
return nothing
end

@views function __firk_loss(
u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, eval_sol, trait) where {BC}
y_ = recursive_unflatten!(y, u)
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqMIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ FastClosures = "0.3.2"
ForwardDiff = "0.10.38, 1"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
JET = "0.9.18"
JET = "0.9.18, 0.11.0"
LinearAlgebra = "1.10"
LinearSolve = "2.36.2, 3"
Mooncake = "0.4"
Expand Down
Loading
Loading