Skip to content

Commit b8fdd8d

Browse files
Merge pull request #650 from jClugstor/forwarddiff_static_tests
Add test for ForwardDiff with StaticArrays
2 parents 873c3ae + ff84c9a commit b8fdd8d

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed

src/common.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,118 @@ function SciMLBase.solve(prob::StaticLinearProblem,
337337
return SciMLBase.build_linear_solution(
338338
alg, u, nothing, prob; retcode = ReturnCode.Success)
339339
end
340+
341+
# Here to make sure that StaticLinearProblems with Dual elements don't create a Dual linear cache
342+
function SciMLBase.init(prob::StaticLinearProblem, alg::SciMLLinearSolveAlgorithm,
343+
args...;
344+
alias = LinearAliasSpecifier(),
345+
abstol = default_tol(real(eltype(prob.b))),
346+
reltol = default_tol(real(eltype(prob.b))),
347+
maxiters::Int = length(prob.b),
348+
verbose::Bool = false,
349+
Pl = nothing,
350+
Pr = nothing,
351+
assumptions = OperatorAssumptions(issquare(prob.A)),
352+
sensealg = LinearSolveAdjoint(),
353+
kwargs...)
354+
(; A, b, u0, p) = prob
355+
356+
if haskey(kwargs, :alias_A) || haskey(kwargs, :alias_b)
357+
aliases = LinearAliasSpecifier()
358+
359+
if haskey(kwargs, :alias_A)
360+
message = "`alias_A` keyword argument is deprecated, to set `alias_A`,
361+
please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))"
362+
Base.depwarn(message, :init)
363+
Base.depwarn(message, :solve)
364+
aliases = LinearAliasSpecifier(alias_A = values(kwargs).alias_A)
365+
end
366+
367+
if haskey(kwargs, :alias_b)
368+
message = "`alias_b` keyword argument is deprecated, to set `alias_b`,
369+
please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))"
370+
Base.depwarn(message, :init)
371+
Base.depwarn(message, :solve)
372+
aliases = LinearAliasSpecifier(
373+
alias_A = aliases.alias_A, alias_b = values(kwargs).alias_b)
374+
end
375+
else
376+
if alias isa Bool
377+
aliases = LinearAliasSpecifier(alias = alias)
378+
else
379+
aliases = alias
380+
end
381+
end
382+
383+
if isnothing(aliases.alias_A)
384+
alias_A = default_alias_A(alg, prob.A, prob.b)
385+
else
386+
alias_A = aliases.alias_A
387+
end
388+
389+
if isnothing(aliases.alias_b)
390+
alias_b = default_alias_b(alg, prob.A, prob.b)
391+
else
392+
alias_b = aliases.alias_b
393+
end
394+
395+
A = if alias_A || A isa SMatrix
396+
A
397+
elseif A isa Array
398+
copy(A)
399+
elseif issparsematrixcsc(A)
400+
make_SparseMatrixCSC(A)
401+
else
402+
deepcopy(A)
403+
end
404+
405+
b = if issparsematrix(b) && !(A isa Diagonal)
406+
Array(b) # the solution to a linear solve will always be dense!
407+
elseif alias_b || b isa SVector
408+
b
409+
elseif b isa Array
410+
copy(b)
411+
elseif issparsematrixcsc(b)
412+
# Extension must be loaded if issparsematrixcsc returns true
413+
make_SparseMatrixCSC(b)
414+
else
415+
deepcopy(b)
416+
end
417+
418+
u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b)
419+
420+
# Guard against type mismatch for user-specified reltol/abstol
421+
reltol = real(eltype(prob.b))(reltol)
422+
abstol = real(eltype(prob.b))(abstol)
423+
424+
precs = if hasproperty(alg, :precs)
425+
isnothing(alg.precs) ? DEFAULT_PRECS : alg.precs
426+
else
427+
DEFAULT_PRECS
428+
end
429+
_Pl, _Pr = precs(A, p)
430+
if isnothing(Pl)
431+
Pl = _Pl
432+
else
433+
# TODO: deprecate once all docs are updated to the new form
434+
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
435+
end
436+
if isnothing(Pr)
437+
Pr = _Pr
438+
else
439+
# TODO: deprecate once all docs are updated to the new form
440+
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
441+
end
442+
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
443+
assumptions)
444+
isfresh = true
445+
precsisfresh = false
446+
Tc = typeof(cacheval)
447+
448+
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
449+
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
450+
typeof(sensealg)}(
451+
A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
452+
maxiters, verbose, assumptions, sensealg)
453+
return cache
454+
end

test/forwarddiff_overloads.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using LinearSolve
22
using ForwardDiff
33
using Test
44
using SparseArrays
5+
using StaticArrays
56

67
function h(p)
78
(A = [p[1] p[2]+1 p[2]^3;
@@ -193,3 +194,39 @@ overload_x_p = solve(prob, UMFPACKFactorization())
193194
backslash_x_p = A \ b
194195

195196
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
197+
198+
199+
# Test StaticArrays
200+
# They don't go through the overloads
201+
# But we should test that the overloads don't mess anything up
202+
function static_h(p)
203+
A = @SMatrix [p[1] p[2]+1 p[2]^3;
204+
3*p[1] p[1]+5 p[2] * p[1]-4;
205+
p[2]^2 9*p[1] p[2]]
206+
207+
b = SA[p[1] + 1, p[2] * 2, p[1]^2]
208+
209+
(A, b)
210+
end
211+
212+
function static_linsolve(p)
213+
A, b = static_h(p)
214+
prob = LinearProblem(A, b)
215+
solve(prob)
216+
end
217+
218+
function static_backslash(p)
219+
A, b = static_h(p)
220+
A \ b
221+
end
222+
223+
@test (ForwardDiff.jacobian(static_linsolve, [5.0, 5.0])
224+
ForwardDiff.jacobian(static_backslash, [5.0, 5.0]))
225+
226+
#Test to make sure that the cache is not a DualLinearCache
227+
A, b = static_h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
228+
229+
static_dual_prob = LinearProblem(A, b)
230+
static_dual_cache = init(static_dual_prob)
231+
232+
@test static_dual_cache isa LinearSolve.LinearCache

0 commit comments

Comments
 (0)