Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,118 @@ function SciMLBase.solve(prob::StaticLinearProblem,
return SciMLBase.build_linear_solution(
alg, u, nothing, prob; retcode = ReturnCode.Success)
end

# Here to make sure that StaticLinearProblems with Dual elements don't create a Dual linear cache
function SciMLBase.init(prob::StaticLinearProblem, alg::SciMLLinearSolveAlgorithm,
args...;
alias = LinearAliasSpecifier(),
abstol = default_tol(real(eltype(prob.b))),
reltol = default_tol(real(eltype(prob.b))),
maxiters::Int = length(prob.b),
verbose::Bool = false,
Pl = nothing,
Pr = nothing,
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
kwargs...)
(; A, b, u0, p) = prob

if haskey(kwargs, :alias_A) || haskey(kwargs, :alias_b)
aliases = LinearAliasSpecifier()

if haskey(kwargs, :alias_A)
message = "`alias_A` keyword argument is deprecated, to set `alias_A`,
please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))"
Base.depwarn(message, :init)
Base.depwarn(message, :solve)
aliases = LinearAliasSpecifier(alias_A = values(kwargs).alias_A)
end

if haskey(kwargs, :alias_b)
message = "`alias_b` keyword argument is deprecated, to set `alias_b`,
please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))"
Base.depwarn(message, :init)
Base.depwarn(message, :solve)
aliases = LinearAliasSpecifier(
alias_A = aliases.alias_A, alias_b = values(kwargs).alias_b)
end
else
if alias isa Bool
aliases = LinearAliasSpecifier(alias = alias)
else
aliases = alias
end
end

if isnothing(aliases.alias_A)
alias_A = default_alias_A(alg, prob.A, prob.b)
else
alias_A = aliases.alias_A
end

if isnothing(aliases.alias_b)
alias_b = default_alias_b(alg, prob.A, prob.b)
else
alias_b = aliases.alias_b
end

A = if alias_A || A isa SMatrix
A
elseif A isa Array
copy(A)
elseif issparsematrixcsc(A)
make_SparseMatrixCSC(A)
else
deepcopy(A)
end

b = if issparsematrix(b) && !(A isa Diagonal)
Array(b) # the solution to a linear solve will always be dense!
elseif alias_b || b isa SVector
b
elseif b isa Array
copy(b)
elseif issparsematrixcsc(b)
# Extension must be loaded if issparsematrixcsc returns true
make_SparseMatrixCSC(b)
else
deepcopy(b)
end

u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b)

# Guard against type mismatch for user-specified reltol/abstol
reltol = real(eltype(prob.b))(reltol)
abstol = real(eltype(prob.b))(abstol)

precs = if hasproperty(alg, :precs)
isnothing(alg.precs) ? DEFAULT_PRECS : alg.precs
else
DEFAULT_PRECS
end
_Pl, _Pr = precs(A, p)
if isnothing(Pl)
Pl = _Pl
else
# TODO: deprecate once all docs are updated to the new form
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
end
if isnothing(Pr)
Pr = _Pr
else
# TODO: deprecate once all docs are updated to the new form
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
end
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
isfresh = true
precsisfresh = false
Tc = typeof(cacheval)

cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(
A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
end
37 changes: 37 additions & 0 deletions test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using LinearSolve
using ForwardDiff
using Test
using SparseArrays
using StaticArrays

function h(p)
(A = [p[1] p[2]+1 p[2]^3;
Expand Down Expand Up @@ -193,3 +194,39 @@ overload_x_p = solve(prob, UMFPACKFactorization())
backslash_x_p = A \ b

@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)


# Test StaticArrays
# They don't go through the overloads
# But we should test that the overloads don't mess anything up
function static_h(p)
A = @SMatrix [p[1] p[2]+1 p[2]^3;
3*p[1] p[1]+5 p[2] * p[1]-4;
p[2]^2 9*p[1] p[2]]

b = SA[p[1] + 1, p[2] * 2, p[1]^2]

(A, b)
end

function static_linsolve(p)
A, b = static_h(p)
prob = LinearProblem(A, b)
solve(prob)
end

function static_backslash(p)
A, b = static_h(p)
A \ b
end

@test (ForwardDiff.jacobian(static_linsolve, [5.0, 5.0]) ≈
ForwardDiff.jacobian(static_backslash, [5.0, 5.0]))

#Test to make sure that the cache is not a DualLinearCache
A, b = static_h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

static_dual_prob = LinearProblem(A, b)
static_dual_cache = init(static_dual_prob)

@test static_dual_cache isa LinearSolve.LinearCache
Loading