diff --git a/Project.toml b/Project.toml index 9ecff6938..e37c6e017 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.38, 1" Hwloc = "3.3" InteractiveUtils = "<0.0.1, 1" -JET = "0.9.18" +JET = "0.9.18, 0.11.0" LinearAlgebra = "1.10" LinearSolve = "3.12" NonlinearSolveFirstOrder = "1.3" diff --git a/lib/BoundaryValueDiffEqAscher/Project.toml b/lib/BoundaryValueDiffEqAscher/Project.toml index 8bc348ffc..36cafb16b 100644 --- a/lib/BoundaryValueDiffEqAscher/Project.toml +++ b/lib/BoundaryValueDiffEqAscher/Project.toml @@ -32,7 +32,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.38, 1" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" -JET = "0.9.18" +JET = "0.9.18, 0.11.0" LinearAlgebra = "1.10" PreallocationTools = "0.4.24" Random = "1.10" diff --git a/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl b/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl index 387527271..be869f944 100644 --- a/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl +++ b/lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl @@ -10,7 +10,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm, __internal_nlsolve_problem, __vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!, __extract_mesh, get_dense_ad, __get_bcresid_prototype, __split_kwargs, __concrete_kwargs, - __default_nonsparse_ad, __construct_internal_problem + __default_nonsparse_ad, __construct_internal_problem, + __internal_solve using ConcreteStructs: @concrete using DiffEqBase: DiffEqBase diff --git a/lib/BoundaryValueDiffEqAscher/src/ascher.jl b/lib/BoundaryValueDiffEqAscher/src/ascher.jl index 41804f463..8fabb129e 100644 --- a/lib/BoundaryValueDiffEqAscher/src/ascher.jl +++ b/lib/BoundaryValueDiffEqAscher/src/ascher.jl @@ -181,7 +181,7 @@ function __perform_ascher_iteration(cache::AscherCache{iip, T}, abstol, adaptive solve_alg = __concrete_solve_algorithm(nlprob, cache.alg.nlsolve, cache.alg.optimize) kwargs = __concrete_kwargs( cache.alg.nlsolve, cache.alg.optimize, cache.nlsolve_kwargs, cache.optimize_kwargs) - nlsol = __solve(nlprob, solve_alg; kwargs...) + nlsol = __internal_solve(nlprob, solve_alg; kwargs...) error_norm = 2 * abstol info = nlsol.retcode @@ -351,8 +351,8 @@ function __construct_nlproblem(cache::AscherCache{iip, T}) where {iip, T} end return __construct_internal_problem( - cache.prob, alg, loss, jac, jac_prototype, resid_prototype, - lz, cache.p, cache.ncomp, length(cache.mesh)) + cache.prob, cache.prob.problem_type, alg, loss, jac, jac_prototype, + resid_prototype, lz, cache.p, cache.ncomp, length(cache.mesh)) end function __ascher_mpoint_jacobian!(J, x, diffmode, diffcache, loss, resid, p) diff --git a/lib/BoundaryValueDiffEqCore/Project.toml b/lib/BoundaryValueDiffEqCore/Project.toml index d2e409963..4ca21da56 100644 --- a/lib/BoundaryValueDiffEqCore/Project.toml +++ b/lib/BoundaryValueDiffEqCore/Project.toml @@ -33,7 +33,7 @@ ConcreteStructs = "0.2.3" DiffEqBase = "6.167" ForwardDiff = "0.10.38, 1" InteractiveUtils = "<0.0.1, 1" -JET = "0.9.18" +JET = "0.9.18, 0.11.0" LineSearch = "0.1.4" LinearAlgebra = "1.10" Logging = "1.10" diff --git a/lib/BoundaryValueDiffEqCore/src/types.jl b/lib/BoundaryValueDiffEqCore/src/types.jl index b559a2a6c..365013618 100644 --- a/lib/BoundaryValueDiffEqCore/src/types.jl +++ b/lib/BoundaryValueDiffEqCore/src/types.jl @@ -126,9 +126,11 @@ end coloring_algorithm = GreedyColoringAlgorithm()) end +@inline __default_coloring_algorithm(_) = GreedyColoringAlgorithm() @inline __default_coloring_algorithm(diffmode::AutoSparse) = isnothing(diffmode) ? GreedyColoringAlgorithm() : diffmode.coloring_algorithm +@inline __default_sparsity_detector(_) = TracerLocalSparsityDetector() @inline __default_sparsity_detector(diffmode::AutoSparse) = isnothing(diffmode) ? TracerLocalSparsityDetector() : diffmode.sparsity_detector diff --git a/lib/BoundaryValueDiffEqCore/src/utils.jl b/lib/BoundaryValueDiffEqCore/src/utils.jl index a3096f8c0..2cc396b14 100644 --- a/lib/BoundaryValueDiffEqCore/src/utils.jl +++ b/lib/BoundaryValueDiffEqCore/src/utils.jl @@ -649,7 +649,7 @@ end Constructs the internal problem based on the type of the boundary value problem and the algorithm used. It returns either a `NonlinearProblem` or an `OptimizationProblem`. """ -function __construct_internal_problem(prob::AbstractBVProblem, alg, loss, jac, +function __construct_internal_problem(prob, pt::StandardBVProblem, alg, loss, jac, jac_prototype, resid_prototype, y, p, M::Int, N::Int) T = eltype(y) iip = SciMLBase.isinplace(prob) @@ -658,16 +658,19 @@ function __construct_internal_problem(prob::AbstractBVProblem, alg, loss, jac, jac_prototype = jac_prototype) return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) else - optf = OptimizationFunction{true}(__default_cost(prob.f), AutoFiniteDiff(), # Need to investigate the ForwardDiff dual problem + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), cons = loss, - cons_j = jac, cons_jac_prototype = jac_prototype) + cons_j = jac, + cons_jac_prototype = jac_prototype) lcons, ucons = __extract_lcons_ucons(prob, T, M, N) return __internal_optimization_problem( prob, optf, y, p; lcons = lcons, ucons = ucons) end end -function __construct_internal_problem(prob::TwoPointBVProblem, alg, loss, jac, +function __construct_internal_problem(prob, pt::TwoPointBVProblem, alg, loss, jac, jac_prototype, resid_prototype, y, p, M::Int, N::Int) T = eltype(y) iip = SciMLBase.isinplace(prob) @@ -676,27 +679,58 @@ function __construct_internal_problem(prob::TwoPointBVProblem, alg, loss, jac, jac_prototype = jac_prototype) return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) else - optf = OptimizationFunction{true}( - __default_cost(prob.f), get_dense_ad(alg.jac_alg.diffmode), - cons = loss, cons_j = jac, cons_jac_prototype = jac_prototype) + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = jac_prototype) lcons, ucons = __extract_lcons_ucons(prob, T, M, N) return __internal_optimization_problem( prob, optf, y, p; lcons = lcons, ucons = ucons) end end -# Multiple shooting always use inplace version internal problem constructor + +# Single shooting use diffmode for StandardBVProblem and TwoPointBVProblem function __construct_internal_problem(prob, alg, loss, jac, jac_prototype, resid_prototype, y, p, M::Int, N::Int, ::Nothing) T = eltype(y) + iip = SciMLBase.isinplace(prob) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{iip}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = jac_prototype) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N) + + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons) + end +end + +# Multiple shooting always use inplace version internal problem constructor +function __construct_internal_problem( + prob, pt::StandardBVProblem, alg, loss, jac, jac_prototype, + resid_prototype, y, p, M::Int, N::Int, ::Nothing) + T = eltype(y) if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) nlf = NonlinearFunction{true}(loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) else - optf = OptimizationFunction{true}( - __default_cost(prob.f), get_dense_ad(alg.jac_alg.diffmode), - cons = loss, cons_j = jac, cons_jac_prototype = jac_prototype) + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = jac_prototype) lcons, ucons = __extract_lcons_ucons(prob, T, M, N) return __internal_optimization_problem( @@ -704,18 +738,67 @@ function __construct_internal_problem(prob, alg, loss, jac, jac_prototype, end end function __construct_internal_problem( - prob::TwoPointBVProblem, alg, loss, jac, jac_prototype, + prob, pt::TwoPointBVProblem, alg, loss, jac, jac_prototype, resid_prototype, y, p, M::Int, N::Int, ::Nothing) T = eltype(y) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{true}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = jac_prototype) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N) + + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons) + end +end + +# Second order BVProblem +function __construct_internal_problem( + prob, pt::StandardSecondOrderBVProblem, alg, loss, jac, + jac_prototype, resid_prototype, y, p, M::Int, N::Int) + T = eltype(y) iip = SciMLBase.isinplace(prob) if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype) return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) else - optf = OptimizationFunction{true}( - __default_cost(prob.f), get_dense_ad(alg.jac_alg.nonbc_diffmode), - cons = loss, cons_j = jac, cons_jac_prototype = jac_prototype) + optf = OptimizationFunction{iip}(__default_cost(prob.f.f), + AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = jac_prototype) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N) + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons) + end +end + +# Two point BVProblem +function __construct_internal_problem( + prob, pt::TwoPointSecondOrderBVProblem, alg, loss, jac, + jac_prototype, resid_prototype, y, p, M::Int, N::Int) + T = eltype(y) + iip = SciMLBase.isinplace(prob) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{iip}(__default_cost(prob.f.f), + AutoSparse(get_dense_ad(alg.jac_alg.diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = jac_prototype) lcons, ucons = __extract_lcons_ucons(prob, T, M, N) return __internal_optimization_problem( diff --git a/lib/BoundaryValueDiffEqFIRK/Project.toml b/lib/BoundaryValueDiffEqFIRK/Project.toml index 559d5834c..44bc624a3 100644 --- a/lib/BoundaryValueDiffEqFIRK/Project.toml +++ b/lib/BoundaryValueDiffEqFIRK/Project.toml @@ -42,7 +42,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.38, 1" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" -JET = "0.9.18" +JET = "0.9.18, 0.11.0" LinearAlgebra = "1.10" LinearSolve = "2.36.2, 3" Mooncake = "0.4.146" diff --git a/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl b/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl index 7d6b6f916..f836f7917 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl @@ -23,7 +23,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm, __construct_internal_problem, __initial_guess_length, __initial_guess_on_mesh, __flatten_initial_guess, __build_solution, __Fix3, __split_kwargs, _sparse_like, - get_dense_ad, __internal_optimization_problem + get_dense_ad, __internal_optimization_problem, + __internal_solve using ConcreteStructs: @concrete using DiffEqBase: DiffEqBase diff --git a/lib/BoundaryValueDiffEqFIRK/src/firk.jl b/lib/BoundaryValueDiffEqFIRK/src/firk.jl index a6d7abb57..f74327e74 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/firk.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/firk.jl @@ -409,7 +409,7 @@ function __perform_firk_iteration(cache::Union{FIRKCacheExpand, FIRKCacheNested} solve_alg = __concrete_solve_algorithm(nlprob, cache.alg.nlsolve, cache.alg.optimize) kwargs = __concrete_kwargs( cache.alg.nlsolve, cache.alg.optimize, cache.nlsolve_kwargs, cache.optimize_kwargs) - sol_nlprob = __solve(nlprob, solve_alg; kwargs...) + sol_nlprob = __internal_solve(nlprob, solve_alg; kwargs...) recursive_unflatten!(cache.y₀, sol_nlprob.u) defect_norm = 2 * abstol @@ -494,6 +494,13 @@ function __construct_problem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{ u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol, trait) end + if !isnothing(cache.alg.optimize) + loss = @closure (du, + u, + p) -> __firk_loss!( + du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, trait) + end + return __construct_problem(cache, y, loss_bc, loss_collocation, loss, pt) end @@ -578,7 +585,8 @@ function __construct_problem( end resid_prototype = vcat(resid_bc, resid_collocation) - return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype, + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, resid_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1) end @@ -637,7 +645,8 @@ function __construct_problem( end resid_prototype = copy(resid) - return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype, + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, resid_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1) end @@ -716,8 +725,9 @@ function __construct_problem( end resid_prototype = vcat(resid_bc, resid_collocation) - return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype, - resid_prototype, y, cache.p, cache.M, N) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, y, cache.p, cache.M, N) end function __construct_problem( @@ -765,8 +775,9 @@ function __construct_problem( end resid_prototype = copy(resid) - return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype, - resid_prototype, y, cache.p, cache.M, N) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, y, cache.p, cache.M, N) end @views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, @@ -791,6 +802,16 @@ end return nothing end +# loss function for optimization based solvers +@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, + residual, mesh, cache, trait) where {BC} + bcresid = length(cache.bcresid_prototype) + __firk_loss_bc!(resid[1:bcresid], u, p, pt, bc!, y, mesh, cache, trait) + __firk_loss_collocation!( + resid[(bcresid + 1):end], u, p, y, mesh, residual, cache, trait) + return nothing +end + @views function __firk_loss!( resid, u, p, y::AbstractVector, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual, mesh, cache, _, trait::DiffCacheNeeded) where {BC1, BC2} @@ -816,6 +837,13 @@ end return nothing end +# loss function for optimization based solvers +@views function __firk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, + residual, mesh, cache, trait) where {BC1, BC2} + __firk_loss!(resid, u, p, y, pt, bc!, residual, mesh, cache, nothing, trait) + return nothing +end + @views function __firk_loss( u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, eval_sol, trait) where {BC} y_ = recursive_unflatten!(y, u) diff --git a/lib/BoundaryValueDiffEqMIRK/Project.toml b/lib/BoundaryValueDiffEqMIRK/Project.toml index f7ade90b4..c796bfac9 100644 --- a/lib/BoundaryValueDiffEqMIRK/Project.toml +++ b/lib/BoundaryValueDiffEqMIRK/Project.toml @@ -42,7 +42,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.38, 1" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" -JET = "0.9.18" +JET = "0.9.18, 0.11.0" LinearAlgebra = "1.10" LinearSolve = "2.36.2, 3" Mooncake = "0.4" diff --git a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl index 98d89646c..aafd5e6cc 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl @@ -277,6 +277,13 @@ function __construct_problem(cache::MIRKCache{iip}, y::AbstractVector, y₀::Abs u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol, trait) end + if !isnothing(cache.alg.optimize) + loss = @closure (du, + u, + p) -> __mirk_loss!( + du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, trait) + end + return __construct_problem(cache, y, loss_bc, loss_collocation, loss, pt) end @@ -303,6 +310,16 @@ end return nothing end +# loss function for optimization based solvers +@views function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, + residual, mesh, cache, trait) where {BC} + bcresid = length(cache.bcresid_prototype) + __mirk_loss_bc!(resid[1:bcresid], u, p, pt, bc!, y, mesh, cache, trait) + __mirk_loss_collocation!( + resid[(bcresid + 1):end], u, p, y, mesh, residual, cache, trait) + return nothing +end + @views function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual, mesh, cache, _, trait::DiffCacheNeeded) where {BC1, BC2} y_ = recursive_unflatten!(y, u) @@ -326,6 +343,13 @@ end return nothing end +# loss function for optimization based solvers +@views function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, + residual, mesh, cache, trait) where {BC1, BC2} + __mirk_loss!(resid, u, p, y, pt, bc!, residual, mesh, cache, nothing, trait) + return nothing +end + @views function __mirk_loss( u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, EvalSol, trait) where {BC} y_ = recursive_unflatten!(y, u) @@ -459,8 +483,9 @@ function __construct_problem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_colloca end resid_prototype = vcat(resid_bc, resid_collocation) - return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype, - resid_prototype, y, cache.p, cache.M, N) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, y, cache.p, cache.M, N) end function __mirk_mpoint_jacobian!( @@ -548,8 +573,9 @@ function __construct_problem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_colloca end resid_prototype = copy(resid) - return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype, - resid_prototype, y, cache.p, cache.M, N) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, y, cache.p, cache.M, N) end function __mirk_2point_jacobian!(J, x, diffmode, diffcache, loss_fn::L, resid, p) where {L} diff --git a/lib/BoundaryValueDiffEqMIRKN/Project.toml b/lib/BoundaryValueDiffEqMIRKN/Project.toml index adb360d86..d46510a19 100644 --- a/lib/BoundaryValueDiffEqMIRKN/Project.toml +++ b/lib/BoundaryValueDiffEqMIRKN/Project.toml @@ -42,7 +42,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.38, 1" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" -JET = "0.9.18" +JET = "0.9.18, 0.11.0" LinearAlgebra = "1.10" PreallocationTools = "0.4.24" PrecompileTools = "1.2" diff --git a/lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl b/lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl index fd90fa1d5..7207d4bb7 100644 --- a/lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl +++ b/lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl @@ -20,7 +20,7 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm, concrete_jacobian_algorithm, __default_coloring_algorithm, __default_sparsity_detector, interval, __split_kwargs, NoErrorControl, __construct_internal_problem, - __concrete_kwargs + __concrete_kwargs, __internal_solve using ConcreteStructs: @concrete using DiffEqBase: DiffEqBase diff --git a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl index 05bdb2cd2..507ff5e9a 100644 --- a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl +++ b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl @@ -119,7 +119,7 @@ function __perform_mirkn_iteration(cache::MIRKNCache) solve_alg = __concrete_solve_algorithm(nlprob, cache.alg.nlsolve, cache.alg.optimize) kwargs = __concrete_kwargs( cache.alg.nlsolve, cache.alg.optimize, cache.nlsolve_kwargs, cache.optimize_kwargs) - sol_nlprob = __solve(nlprob, solve_alg; kwargs...) + sol_nlprob = __internal_solve(nlprob, solve_alg; kwargs...) recursive_unflatten!(cache.y₀, sol_nlprob.u) return sol_nlprob, sol_nlprob.retcode @@ -230,8 +230,9 @@ function __construct_nlproblem(cache::MIRKNCache{iip}, y, loss_bc::BC, loss_coll cache_collocation, loss_bc, loss_collocation, L, cache.p) end resid_prototype = vcat(resid_bc, resid_collocation) - return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype, - resid_prototype, y, cache.p, cache.M, 2 * N) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, y, cache.p, cache.M, 2 * N) end function __construct_nlproblem(cache::MIRKNCache{iip}, y, loss_bc::BC, loss_collocation::C, @@ -272,8 +273,9 @@ function __construct_nlproblem(cache::MIRKNCache{iip}, y, loss_bc::BC, loss_coll end resid_prototype = copy(resid) - return __construct_internal_problem(cache.prob, cache.alg, loss, jac, jac_prototype, - resid_prototype, y, cache.p, cache.M, 2 * N) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, y, cache.p, cache.M, 2 * N) end function __mirkn_2point_jacobian!(J, x, diffmode, diffcache, loss_fn::L, resid, p) where {L} diff --git a/lib/BoundaryValueDiffEqShooting/Project.toml b/lib/BoundaryValueDiffEqShooting/Project.toml index c7aa7966d..c0e2435d1 100644 --- a/lib/BoundaryValueDiffEqShooting/Project.toml +++ b/lib/BoundaryValueDiffEqShooting/Project.toml @@ -37,7 +37,7 @@ FastClosures = "0.3.2" ForwardDiff = "0.10.38, 1" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" -JET = "0.9.18" +JET = "0.9.18, 0.11.0" LinearAlgebra = "1.10" OrdinaryDiffEqLowOrderRK = "1" OrdinaryDiffEqRosenbrock = "1" diff --git a/lib/BoundaryValueDiffEqShooting/src/BoundaryValueDiffEqShooting.jl b/lib/BoundaryValueDiffEqShooting/src/BoundaryValueDiffEqShooting.jl index d8d7faed9..9d6ced304 100644 --- a/lib/BoundaryValueDiffEqShooting/src/BoundaryValueDiffEqShooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/BoundaryValueDiffEqShooting.jl @@ -18,7 +18,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm, BVPJacobian NoDiffCacheNeeded, DiffCacheNeeded, __extract_mesh, __extract_u0, __has_initial_guess, __initial_guess_length, __initial_guess_on_mesh, __flatten_initial_guess, - __get_non_sparse_ad, __build_solution, __Fix3, get_dense_ad + __get_non_sparse_ad, __build_solution, __Fix3, get_dense_ad, + __internal_solve using ConcreteStructs: @concrete using DiffEqBase: DiffEqBase, solve diff --git a/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl b/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl index a384f4eda..a32392f55 100644 --- a/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl @@ -144,11 +144,11 @@ function __solve_nlproblem!( # NOTE: u_at_nodes is updated inplace nlprob = __construct_internal_problem( - prob, alg, loss_fn, jac_fn, jac_prototype, resid_prototype, - u_at_nodes, prob.p, M, length(nodes), nothing) + prob, prob.problem_type, alg, loss_fn, jac_fn, jac_prototype, + resid_prototype, u_at_nodes, prob.p, M, length(nodes), nothing) nlsolve_alg = __concrete_solve_algorithm(nlprob, alg.nlsolve, alg.optimize) - __solve(nlprob, nlsolve_alg; kwargs...) + __internal_solve(nlprob, nlsolve_alg; kwargs...) return nothing end @@ -220,8 +220,8 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_ # NOTE: u_at_nodes is updated inplace nlprob = __construct_internal_problem( - prob, alg, loss_fn, jac_fn, jac_prototype, resid_prototype, - u_at_nodes, prob.p, M, length(nodes), nothing) + prob, prob.problem_type, alg, loss_fn, jac_fn, jac_prototype, + resid_prototype, u_at_nodes, prob.p, M, length(nodes), nothing) nlsolve_alg = __concrete_solve_algorithm(nlprob, alg.nlsolve, alg.optimize) __solve(nlprob, nlsolve_alg; kwargs...) diff --git a/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl b/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl index c8e9a567d..b1723ad83 100644 --- a/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl @@ -76,11 +76,10 @@ function SciMLBase.__solve(prob::BVProblem, alg_::Shooting; abstol = 1e-6, end nlprob = __construct_internal_problem(prob, alg, loss_fn, jac_fn, jac_prototype, - resid_prototype, u0, prob.p, length(u0), 1) + resid_prototype, u0, prob.p, length(u0), 1, nothing) solve_alg = __concrete_solve_algorithm(nlprob, alg.nlsolve, alg.optimize) kwargs = __concrete_kwargs(alg.nlsolve, alg.optimize, nlsolve_kwargs, optimize_kwargs) - #TODO: add verbose kwarg - nlsol = __solve(nlprob, solve_alg; kwargs...) + nlsol = __internal_solve(nlprob, solve_alg; kwargs...) # There is no way to reinit with the same cache with different cache. But not saving # the internal values gives a significant speedup. So we just create a new cache