|
| 1 | +module Utils |
| 2 | + |
| 3 | +using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff |
| 4 | +using ArrayInterface: ArrayInterface |
| 5 | +using DifferentiationInterface: DifferentiationInterface |
| 6 | +using FastClosures: @closure |
| 7 | +using LinearAlgebra: LinearAlgebra, I, diagind |
| 8 | +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem |
| 9 | +using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, |
| 10 | + NonlinearFunction |
| 11 | +using StaticArraysCore: StaticArray, SArray, SMatrix, SVector |
| 12 | + |
| 13 | +const DI = DifferentiationInterface |
| 14 | + |
| 15 | +const safe_similar = NonlinearSolveBase.Utils.safe_similar |
| 16 | + |
| 17 | +pickchunksize(n::Int) = min(n, 12) |
| 18 | + |
| 19 | +can_dual(::Type{<:Real}) = true |
| 20 | +can_dual(::Type) = false |
| 21 | + |
| 22 | +maybe_unaliased(x::Union{Number, SArray}, ::Bool) = x |
| 23 | +function maybe_unaliased(x::T, alias::Bool) where {T <: AbstractArray} |
| 24 | + (alias || !ArrayInterface.can_setindex(T)) && return x |
| 25 | + return copy(x) |
| 26 | +end |
| 27 | + |
| 28 | +function get_concrete_autodiff(_, ad::AbstractADType) |
| 29 | + DI.check_available(ad) && return ad |
| 30 | + error("AD Backend $(ad) is not available. This could be because you haven't loaded the \ |
| 31 | + actual backend (See [Differentiation Inferface Docs](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/) \ |
| 32 | + for more details) or the backend might not be supported by DifferentiationInferface.jl.") |
| 33 | +end |
| 34 | +function get_concrete_autodiff( |
| 35 | + prob, ad::Union{AutoForwardDiff{nothing}, AutoPolyesterForwardDiff{nothing}}) |
| 36 | + return get_concrete_autodiff(prob, |
| 37 | + ArrayInterface.parameterless_type(ad)(; |
| 38 | + chunksize = pickchunksize(length(prob.u0)), ad.tag)) |
| 39 | +end |
| 40 | +function get_concrete_autodiff(prob, ::Nothing) |
| 41 | + if can_dual(eltype(prob.u0)) && DI.check_available(AutoForwardDiff()) |
| 42 | + return AutoForwardDiff(; chunksize = pickchunksize(length(prob.u0))) |
| 43 | + end |
| 44 | + DI.check_available(AutoFiniteDiff()) && return AutoFiniteDiff() |
| 45 | + error("Default AD backends are not available. Please load either FiniteDiff or \ |
| 46 | + ForwardDiff for default AD selection to work. Else provide a specific AD \ |
| 47 | + backend (instead of `nothing`) to the solver.") |
| 48 | +end |
| 49 | + |
| 50 | +# NOTE: This doesn't initialize the `f(x)` but just returns a buffer of the same size |
| 51 | +function get_fx(prob::NonlinearLeastSquaresProblem, x) |
| 52 | + if SciMLBase.isinplace(prob) && prob.f.resid_prototype === nothing |
| 53 | + error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype` to be \ |
| 54 | + specified.") |
| 55 | + end |
| 56 | + return get_fx(prob.f, x, prob.p) |
| 57 | +end |
| 58 | +function get_fx(prob::Union{ImmutableNonlinearProblem, NonlinearProblem}, x) |
| 59 | + return get_fx(prob.f, x, prob.p) |
| 60 | +end |
| 61 | +function get_fx(f::NonlinearFunction, x, p) |
| 62 | + if SciMLBase.isinplace(f) |
| 63 | + f.resid_prototype === nothing && return eltype(x).(f.resid_prototype) |
| 64 | + return safe_similar(x) |
| 65 | + end |
| 66 | + return f(x, p) |
| 67 | +end |
| 68 | + |
| 69 | +function eval_f(prob, fx, x) |
| 70 | + SciMLBase.isinplace(prob) || return prob.f(x, prob.p) |
| 71 | + prob.f(fx, x, prob.p) |
| 72 | + return fx |
| 73 | +end |
| 74 | + |
| 75 | +function fixed_parameter_function(prob::AbstractNonlinearProblem) |
| 76 | + SciMLBase.isinplace(prob) && return @closure (du, u) -> prob.f(du, u, prob.p) |
| 77 | + return Base.Fix2(prob.f, prob.p) |
| 78 | +end |
| 79 | + |
| 80 | +# __init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α) |
| 81 | +# function __init_identity_jacobian(u, fu, α = true) |
| 82 | +# J = __similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u)) |
| 83 | +# fill!(J, zero(eltype(J))) |
| 84 | +# J[diagind(J)] .= eltype(J)(α) |
| 85 | +# return J |
| 86 | +# end |
| 87 | +# function __init_identity_jacobian(u::StaticArray, fu, α = true) |
| 88 | +# S1, S2 = length(fu), length(u) |
| 89 | +# J = SMatrix{S1, S2, eltype(u)}(I * α) |
| 90 | +# return J |
| 91 | +# end |
| 92 | + |
| 93 | +identity_jacobian!!(J::Number) = one(J) |
| 94 | +function identity_jacobian!!(J::AbstractVector) |
| 95 | + ArrayInterface.can_setindex(J) || return one.(J) |
| 96 | + fill!(J, true) |
| 97 | + return J |
| 98 | +end |
| 99 | +function identity_jacobian!!(J::AbstractMatrix) |
| 100 | + ArrayInterface.can_setindex(J) || return convert(typeof(J), I) |
| 101 | + J[diagind(J)] .= true |
| 102 | + return J |
| 103 | +end |
| 104 | +identity_jacobian!!(::SMatrix{S1, S2, T}) where {S1, S2, T} = SMatrix{S1, S2, T}(I) |
| 105 | +identity_jacobian!!(::SVector{S1, T}) where {S1, T} = ones(SVector{S1, T}) |
| 106 | + |
| 107 | +end |
0 commit comments