diff --git a/src/common.jl b/src/common.jl index f447df8e7..de0e4d67d 100644 --- a/src/common.jl +++ b/src/common.jl @@ -337,3 +337,118 @@ function SciMLBase.solve(prob::StaticLinearProblem, return SciMLBase.build_linear_solution( alg, u, nothing, prob; retcode = ReturnCode.Success) end + +# Here to make sure that StaticLinearProblems with Dual elements don't create a Dual linear cache +function SciMLBase.init(prob::StaticLinearProblem, alg::SciMLLinearSolveAlgorithm, + args...; + alias = LinearAliasSpecifier(), + abstol = default_tol(real(eltype(prob.b))), + reltol = default_tol(real(eltype(prob.b))), + maxiters::Int = length(prob.b), + verbose::Bool = false, + Pl = nothing, + Pr = nothing, + assumptions = OperatorAssumptions(issquare(prob.A)), + sensealg = LinearSolveAdjoint(), + kwargs...) + (; A, b, u0, p) = prob + + if haskey(kwargs, :alias_A) || haskey(kwargs, :alias_b) + aliases = LinearAliasSpecifier() + + if haskey(kwargs, :alias_A) + message = "`alias_A` keyword argument is deprecated, to set `alias_A`, + please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))" + Base.depwarn(message, :init) + Base.depwarn(message, :solve) + aliases = LinearAliasSpecifier(alias_A = values(kwargs).alias_A) + end + + if haskey(kwargs, :alias_b) + message = "`alias_b` keyword argument is deprecated, to set `alias_b`, + please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))" + Base.depwarn(message, :init) + Base.depwarn(message, :solve) + aliases = LinearAliasSpecifier( + alias_A = aliases.alias_A, alias_b = values(kwargs).alias_b) + end + else + if alias isa Bool + aliases = LinearAliasSpecifier(alias = alias) + else + aliases = alias + end + end + + if isnothing(aliases.alias_A) + alias_A = default_alias_A(alg, prob.A, prob.b) + else + alias_A = aliases.alias_A + end + + if isnothing(aliases.alias_b) + alias_b = default_alias_b(alg, prob.A, prob.b) + else + alias_b = aliases.alias_b + end + + A = if alias_A || A isa SMatrix + A + elseif A isa Array + copy(A) + elseif issparsematrixcsc(A) + make_SparseMatrixCSC(A) + else + deepcopy(A) + end + + b = if issparsematrix(b) && !(A isa Diagonal) + Array(b) # the solution to a linear solve will always be dense! + elseif alias_b || b isa SVector + b + elseif b isa Array + copy(b) + elseif issparsematrixcsc(b) + # Extension must be loaded if issparsematrixcsc returns true + make_SparseMatrixCSC(b) + else + deepcopy(b) + end + + u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b) + + # Guard against type mismatch for user-specified reltol/abstol + reltol = real(eltype(prob.b))(reltol) + abstol = real(eltype(prob.b))(abstol) + + precs = if hasproperty(alg, :precs) + isnothing(alg.precs) ? DEFAULT_PRECS : alg.precs + else + DEFAULT_PRECS + end + _Pl, _Pr = precs(A, p) + if isnothing(Pl) + Pl = _Pl + else + # TODO: deprecate once all docs are updated to the new form + #@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm." + end + if isnothing(Pr) + Pr = _Pr + else + # TODO: deprecate once all docs are updated to the new form + #@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm." + end + cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose, + assumptions) + isfresh = true + precsisfresh = false + Tc = typeof(cacheval) + + cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc, + typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), + typeof(sensealg)}( + A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol, + maxiters, verbose, assumptions, sensealg) + return cache +end \ No newline at end of file diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 3d4a035f3..97fef8d6f 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -2,6 +2,7 @@ using LinearSolve using ForwardDiff using Test using SparseArrays +using StaticArrays function h(p) (A = [p[1] p[2]+1 p[2]^3; @@ -193,3 +194,39 @@ overload_x_p = solve(prob, UMFPACKFactorization()) backslash_x_p = A \ b @test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) + + +# Test StaticArrays +# They don't go through the overloads +# But we should test that the overloads don't mess anything up +function static_h(p) + A = @SMatrix [p[1] p[2]+1 p[2]^3; + 3*p[1] p[1]+5 p[2] * p[1]-4; + p[2]^2 9*p[1] p[2]] + + b = SA[p[1] + 1, p[2] * 2, p[1]^2] + + (A, b) +end + +function static_linsolve(p) + A, b = static_h(p) + prob = LinearProblem(A, b) + solve(prob) +end + +function static_backslash(p) + A, b = static_h(p) + A \ b +end + +@test (ForwardDiff.jacobian(static_linsolve, [5.0, 5.0]) ≈ + ForwardDiff.jacobian(static_backslash, [5.0, 5.0])) + +#Test to make sure that the cache is not a DualLinearCache +A, b = static_h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +static_dual_prob = LinearProblem(A, b) +static_dual_cache = init(static_dual_prob) + +@test static_dual_cache isa LinearSolve.LinearCache \ No newline at end of file