Skip to content

Commit 8b4e443

Browse files
committed
add init for StaticLinearProblem
1 parent 949e8e8 commit 8b4e443

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed

src/common.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,117 @@ function SciMLBase.solve(prob::StaticLinearProblem,
337337
return SciMLBase.build_linear_solution(
338338
alg, u, nothing, prob; retcode = ReturnCode.Success)
339339
end
340+
341+
function SciMLBase.init(prob::StaticLinearProblem, alg::SciMLLinearSolveAlgorithm,
342+
args...;
343+
alias = LinearAliasSpecifier(),
344+
abstol = default_tol(real(eltype(prob.b))),
345+
reltol = default_tol(real(eltype(prob.b))),
346+
maxiters::Int = length(prob.b),
347+
verbose::Bool = false,
348+
Pl = nothing,
349+
Pr = nothing,
350+
assumptions = OperatorAssumptions(issquare(prob.A)),
351+
sensealg = LinearSolveAdjoint(),
352+
kwargs...)
353+
(; A, b, u0, p) = prob
354+
355+
if haskey(kwargs, :alias_A) || haskey(kwargs, :alias_b)
356+
aliases = LinearAliasSpecifier()
357+
358+
if haskey(kwargs, :alias_A)
359+
message = "`alias_A` keyword argument is deprecated, to set `alias_A`,
360+
please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))"
361+
Base.depwarn(message, :init)
362+
Base.depwarn(message, :solve)
363+
aliases = LinearAliasSpecifier(alias_A = values(kwargs).alias_A)
364+
end
365+
366+
if haskey(kwargs, :alias_b)
367+
message = "`alias_b` keyword argument is deprecated, to set `alias_b`,
368+
please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))"
369+
Base.depwarn(message, :init)
370+
Base.depwarn(message, :solve)
371+
aliases = LinearAliasSpecifier(
372+
alias_A = aliases.alias_A, alias_b = values(kwargs).alias_b)
373+
end
374+
else
375+
if alias isa Bool
376+
aliases = LinearAliasSpecifier(alias = alias)
377+
else
378+
aliases = alias
379+
end
380+
end
381+
382+
if isnothing(aliases.alias_A)
383+
alias_A = default_alias_A(alg, prob.A, prob.b)
384+
else
385+
alias_A = aliases.alias_A
386+
end
387+
388+
if isnothing(aliases.alias_b)
389+
alias_b = default_alias_b(alg, prob.A, prob.b)
390+
else
391+
alias_b = aliases.alias_b
392+
end
393+
394+
A = if alias_A || A isa SMatrix
395+
A
396+
elseif A isa Array
397+
copy(A)
398+
elseif issparsematrixcsc(A)
399+
make_SparseMatrixCSC(A)
400+
else
401+
deepcopy(A)
402+
end
403+
404+
b = if issparsematrix(b) && !(A isa Diagonal)
405+
Array(b) # the solution to a linear solve will always be dense!
406+
elseif alias_b || b isa SVector
407+
b
408+
elseif b isa Array
409+
copy(b)
410+
elseif issparsematrixcsc(b)
411+
# Extension must be loaded if issparsematrixcsc returns true
412+
make_SparseMatrixCSC(b)
413+
else
414+
deepcopy(b)
415+
end
416+
417+
u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b)
418+
419+
# Guard against type mismatch for user-specified reltol/abstol
420+
reltol = real(eltype(prob.b))(reltol)
421+
abstol = real(eltype(prob.b))(abstol)
422+
423+
precs = if hasproperty(alg, :precs)
424+
isnothing(alg.precs) ? DEFAULT_PRECS : alg.precs
425+
else
426+
DEFAULT_PRECS
427+
end
428+
_Pl, _Pr = precs(A, p)
429+
if isnothing(Pl)
430+
Pl = _Pl
431+
else
432+
# TODO: deprecate once all docs are updated to the new form
433+
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
434+
end
435+
if isnothing(Pr)
436+
Pr = _Pr
437+
else
438+
# TODO: deprecate once all docs are updated to the new form
439+
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
440+
end
441+
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
442+
assumptions)
443+
isfresh = true
444+
precsisfresh = false
445+
Tc = typeof(cacheval)
446+
447+
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
448+
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
449+
typeof(sensealg)}(
450+
A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
451+
maxiters, verbose, assumptions, sensealg)
452+
return cache
453+
end

0 commit comments

Comments
 (0)