Skip to content

Commit a31f99f

Browse files
committed
Special handling for staticarrays
1 parent 3c56ab4 commit a31f99f

File tree

4 files changed

+67
-31
lines changed

4 files changed

+67
-31
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ LinearSolvePardisoExt = "Pardiso"
5959
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
6060

6161
[compat]
62+
AllocCheck = "0.1"
6263
Aqua = "0.8"
6364
ArrayInterface = "7.4.11"
6465
BandedMatrices = "1"
@@ -109,6 +110,7 @@ UnPack = "1"
109110
julia = "1.9"
110111

111112
[extras]
113+
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
112114
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
113115
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
114116
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
@@ -133,4 +135,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
133135
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
134136

135137
[targets]
136-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays"]
138+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck"]

src/common.jl

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -169,31 +169,9 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
169169
isfresh = true
170170
Tc = typeof(cacheval)
171171

172-
cache = LinearCache{
173-
typeof(A),
174-
typeof(b),
175-
typeof(u0_),
176-
typeof(p),
177-
typeof(alg),
178-
Tc,
179-
typeof(Pl),
180-
typeof(Pr),
181-
typeof(reltol),
182-
typeof(assumptions.issq),
183-
}(A,
184-
b,
185-
u0_,
186-
p,
187-
alg,
188-
cacheval,
189-
isfresh,
190-
Pl,
191-
Pr,
192-
abstol,
193-
reltol,
194-
maxiters,
195-
verbose,
196-
assumptions)
172+
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
173+
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
174+
p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
197175
return cache
198176
end
199177

@@ -210,3 +188,33 @@ end
210188
function SciMLBase.solve!(cache::LinearCache, args...; kwargs...)
211189
solve!(cache, cache.alg, args...; kwargs...)
212190
end
191+
192+
# Special Case for StaticArrays
193+
const StaticLinearProblem = LinearProblem{uType, iip, <:SMatrix,
194+
<:Union{<:SMatrix, <:SVector}} where {uType, iip}
195+
196+
function SciMLBase.solve(prob::StaticLinearProblem, args...; kwargs...)
197+
return SciMLBase.solve(prob, nothing, args...; kwargs...)
198+
end
199+
200+
function SciMLBase.solve(prob::StaticLinearProblem,
201+
alg::Union{Nothing, SciMLLinearSolveAlgorithm}, args...; kwargs...)
202+
if alg === nothing || alg isa DirectLdiv!
203+
u = prob.A \ prob.b
204+
elseif alg isa LUFactorization
205+
u = lu(prob.A) \ prob.b
206+
elseif alg isa QRFactorization
207+
u = qr(prob.A) \ prob.b
208+
elseif alg isa CholeskyFactorization
209+
u = cholesky(prob.A) \ prob.b
210+
elseif alg isa NormalCholeskyFactorization
211+
u = cholesky(Symmetric(prob.A' * prob.A)) \ (prob.A' * prob.b)
212+
elseif alg isa SVDFactorization
213+
u = svd(prob.A) \ prob.b
214+
else
215+
# Slower Path but handles all cases
216+
cache = init(prob, alg, args...; kwargs...)
217+
return solve!(cache)
218+
end
219+
return SciMLBase.build_linear_solution(alg, u, nothing, prob)
220+
end

src/factorization.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -975,11 +975,17 @@ function init_cacheval(alg::NormalCholeskyFactorization,
975975
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
976976
end
977977

978+
function init_cacheval(alg::NormalCholeskyFactorization, A::SMatrix, b, u, Pl, Pr,
979+
maxiters::Int, abstol, reltol, verbose::Bool,
980+
assumptions::OperatorAssumptions)
981+
return cholesky(Symmetric((A)' * A))
982+
end
983+
978984
function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
979985
maxiters::Int, abstol, reltol, verbose::Bool,
980986
assumptions::OperatorAssumptions)
981987
A_ = convert(AbstractMatrix, A)
982-
ArrayInterface.cholesky_instance(Symmetric((A)' * A, :L), alg.pivot)
988+
return ArrayInterface.cholesky_instance(Symmetric((A)' * A), alg.pivot)
983989
end
984990

985991
function init_cacheval(alg::NormalCholeskyFactorization,
@@ -993,17 +999,20 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
993999
A = cache.A
9941000
A = convert(AbstractMatrix, A)
9951001
if cache.isfresh
996-
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray
997-
fact = cholesky(Symmetric((A)' * A, :L); check = false)
1002+
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
1003+
fact = cholesky(Symmetric((A)' * A); check = false)
9981004
else
999-
fact = cholesky(Symmetric((A)' * A, :L), alg.pivot; check = false)
1005+
fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
10001006
end
10011007
cache.cacheval = fact
10021008
cache.isfresh = false
10031009
end
10041010
if A isa SparseMatrixCSC
10051011
cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
10061012
y = cache.u
1013+
elseif A isa StaticArray
1014+
cache.u = @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
1015+
y = cache.u
10071016
else
10081017
y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
10091018
end

test/static_arrays.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
using LinearSolve, StaticArrays, LinearAlgebra, Test
2+
using AllocCheck
23

34
A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I))
45
b = SVector{5}(rand(5))
56

7+
@check_allocs __solve_no_alloc(A, b, alg) = solve(LinearProblem(A, b), alg)
8+
9+
function __non_native_static_array_alg(alg)
10+
return alg isa SVDFactorization || alg isa KrylovJL
11+
end
12+
613
for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(),
7-
KrylovJL_GMRES())
14+
NormalCholeskyFactorization(), KrylovJL_GMRES())
815
sol = solve(LinearProblem(A, b), alg)
916
@inferred solve(LinearProblem(A, b), alg)
1017
@test norm(A * sol .- b) < 1e-10
18+
19+
if __non_native_static_array_alg(alg)
20+
@test_broken __solve_no_alloc(A, b, alg)
21+
else
22+
@test_nowarn __solve_no_alloc(A, b, alg)
23+
end
24+
25+
cache = init(LinearProblem(A, b), alg)
26+
sol = solve!(cache)
27+
@test norm(A * sol .- b) < 1e-10
1128
end
1229

1330
A = SMatrix{7, 5}(rand(7, 5))

0 commit comments

Comments
 (0)