diff --git a/src/common.jl b/src/common.jl index de0e4d67d..f447df8e7 100644 --- a/src/common.jl +++ b/src/common.jl @@ -337,118 +337,3 @@ function SciMLBase.solve(prob::StaticLinearProblem, return SciMLBase.build_linear_solution( alg, u, nothing, prob; retcode = ReturnCode.Success) end - -# Here to make sure that StaticLinearProblems with Dual elements don't create a Dual linear cache -function SciMLBase.init(prob::StaticLinearProblem, alg::SciMLLinearSolveAlgorithm, - args...; - alias = LinearAliasSpecifier(), - abstol = default_tol(real(eltype(prob.b))), - reltol = default_tol(real(eltype(prob.b))), - maxiters::Int = length(prob.b), - verbose::Bool = false, - Pl = nothing, - Pr = nothing, - assumptions = OperatorAssumptions(issquare(prob.A)), - sensealg = LinearSolveAdjoint(), - kwargs...) - (; A, b, u0, p) = prob - - if haskey(kwargs, :alias_A) || haskey(kwargs, :alias_b) - aliases = LinearAliasSpecifier() - - if haskey(kwargs, :alias_A) - message = "`alias_A` keyword argument is deprecated, to set `alias_A`, - please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))" - Base.depwarn(message, :init) - Base.depwarn(message, :solve) - aliases = LinearAliasSpecifier(alias_A = values(kwargs).alias_A) - end - - if haskey(kwargs, :alias_b) - message = "`alias_b` keyword argument is deprecated, to set `alias_b`, - please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))" - Base.depwarn(message, :init) - Base.depwarn(message, :solve) - aliases = LinearAliasSpecifier( - alias_A = aliases.alias_A, alias_b = values(kwargs).alias_b) - end - else - if alias isa Bool - aliases = LinearAliasSpecifier(alias = alias) - else - aliases = alias - end - end - - if isnothing(aliases.alias_A) - alias_A = default_alias_A(alg, prob.A, prob.b) - else - alias_A = aliases.alias_A - end - - if isnothing(aliases.alias_b) - alias_b = default_alias_b(alg, prob.A, prob.b) - else - alias_b = aliases.alias_b - end - - A = if alias_A || A isa SMatrix - A - elseif A isa Array - copy(A) - elseif issparsematrixcsc(A) - make_SparseMatrixCSC(A) - else - deepcopy(A) - end - - b = if issparsematrix(b) && !(A isa Diagonal) - Array(b) # the solution to a linear solve will always be dense! - elseif alias_b || b isa SVector - b - elseif b isa Array - copy(b) - elseif issparsematrixcsc(b) - # Extension must be loaded if issparsematrixcsc returns true - make_SparseMatrixCSC(b) - else - deepcopy(b) - end - - u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b) - - # Guard against type mismatch for user-specified reltol/abstol - reltol = real(eltype(prob.b))(reltol) - abstol = real(eltype(prob.b))(abstol) - - precs = if hasproperty(alg, :precs) - isnothing(alg.precs) ? DEFAULT_PRECS : alg.precs - else - DEFAULT_PRECS - end - _Pl, _Pr = precs(A, p) - if isnothing(Pl) - Pl = _Pl - else - # TODO: deprecate once all docs are updated to the new form - #@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm." - end - if isnothing(Pr) - Pr = _Pr - else - # TODO: deprecate once all docs are updated to the new form - #@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm." - end - cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose, - assumptions) - isfresh = true - precsisfresh = false - Tc = typeof(cacheval) - - cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc, - typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), - typeof(sensealg)}( - A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol, - maxiters, verbose, assumptions, sensealg) - return cache -end \ No newline at end of file diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index 97fef8d6f..3d4a035f3 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -2,7 +2,6 @@ using LinearSolve using ForwardDiff using Test using SparseArrays -using StaticArrays function h(p) (A = [p[1] p[2]+1 p[2]^3; @@ -194,39 +193,3 @@ overload_x_p = solve(prob, UMFPACKFactorization()) backslash_x_p = A \ b @test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) - - -# Test StaticArrays -# They don't go through the overloads -# But we should test that the overloads don't mess anything up -function static_h(p) - A = @SMatrix [p[1] p[2]+1 p[2]^3; - 3*p[1] p[1]+5 p[2] * p[1]-4; - p[2]^2 9*p[1] p[2]] - - b = SA[p[1] + 1, p[2] * 2, p[1]^2] - - (A, b) -end - -function static_linsolve(p) - A, b = static_h(p) - prob = LinearProblem(A, b) - solve(prob) -end - -function static_backslash(p) - A, b = static_h(p) - A \ b -end - -@test (ForwardDiff.jacobian(static_linsolve, [5.0, 5.0]) ≈ - ForwardDiff.jacobian(static_backslash, [5.0, 5.0])) - -#Test to make sure that the cache is not a DualLinearCache -A, b = static_h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) - -static_dual_prob = LinearProblem(A, b) -static_dual_cache = init(static_dual_prob) - -@test static_dual_cache isa LinearSolve.LinearCache \ No newline at end of file