Skip to content

Commit caaedab

Browse files
Merge pull request #444 from avik-pal/ap/static_arrays
Proper handling of static arrays
2 parents 9e22a72 + a31f99f commit caaedab

File tree

6 files changed

+99
-60
lines changed

6 files changed

+99
-60
lines changed

Project.toml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,13 @@ LinearSolvePardisoExt = "Pardiso"
5959
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
6060

6161
[compat]
62+
AllocCheck = "0.1"
6263
Aqua = "0.8"
6364
ArrayInterface = "7.4.11"
6465
BandedMatrices = "1"
6566
BlockDiagonals = "0.1"
66-
ConcreteStructs = "0.2"
6767
CUDA = "5"
68+
ConcreteStructs = "0.2"
6869
DocStringExtensions = "0.9"
6970
EnumX = "1"
7071
Enzyme = "0.11"
@@ -77,15 +78,15 @@ GPUArraysCore = "0.1"
7778
HYPRE = "1.4.0"
7879
InteractiveUtils = "1.6"
7980
IterativeSolvers = "0.9.3"
80-
Libdl = "1.6"
81-
LinearAlgebra = "1.9"
8281
JET = "0.8"
8382
KLU = "0.3.0, 0.4"
8483
KernelAbstractions = "0.9"
8584
Krylov = "0.9"
8685
KrylovKit = "0.6"
87-
Metal = "0.5"
86+
Libdl = "1.6"
87+
LinearAlgebra = "1.9"
8888
MPI = "0.20"
89+
Metal = "0.5"
8990
MultiFloats = "1"
9091
Pardiso = "0.5"
9192
Pkg = "1"
@@ -102,13 +103,14 @@ SciMLOperators = "0.3"
102103
Setfield = "1"
103104
SparseArrays = "1.9"
104105
Sparspak = "0.3.6"
105-
StaticArraysCore = "1"
106106
StaticArrays = "1"
107+
StaticArraysCore = "1"
107108
Test = "1"
108109
UnPack = "1"
109110
julia = "1.9"
110111

111112
[extras]
113+
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
112114
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
113115
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
114116
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
@@ -133,4 +135,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
133135
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
134136

135137
[targets]
136-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays"]
138+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck"]

src/common.jl

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,13 @@ default_alias_b(::Any, ::Any, ::Any) = false
119119
default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
120120
default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
121121

122+
# Build the initial guess used when the problem carries no `u0`.
#
# Generic method: allocate an array "like" `b` with `size(A, 2)` entries and
# overwrite every entry with `false` (i.e. zero in the array's element type).
function __init_u0_from_Ab(A, b)
    return fill!(similar(b, size(A, 2)), false)
end

# Static method: no mutation is involved; an all-zero `SVector` is returned whose
# length is the second size parameter `S2` of `A` and whose element type is
# `eltype(b)`.
function __init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2}
    return zeros(SVector{S2, eltype(b)})
end
128+
122129
function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
123130
args...;
124131
alias_A = default_alias_A(alg, prob.A, prob.b),
@@ -133,7 +140,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
133140
kwargs...)
134141
@unpack A, b, u0, p = prob
135142

136-
A = if alias_A
143+
A = if alias_A || A isa SMatrix
137144
A
138145
elseif A isa Array || A isa SparseMatrixCSC
139146
copy(A)
@@ -143,55 +150,28 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
143150

144151
b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
145152
Array(b) # the solution to a linear solve will always be dense!
146-
elseif alias_b
153+
elseif alias_b || b isa SVector
147154
b
148155
elseif b isa Array || b isa SparseMatrixCSC
149156
copy(b)
150157
else
151158
deepcopy(b)
152159
end
153160

154-
u0 = if u0 !== nothing
155-
u0
156-
else
157-
u0 = similar(b, size(A, 2))
158-
fill!(u0, false)
159-
end
161+
u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b)
160162

161163
# Guard against type mismatch for user-specified reltol/abstol
162164
reltol = real(eltype(prob.b))(reltol)
163165
abstol = real(eltype(prob.b))(abstol)
164166

165-
cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose,
167+
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
166168
assumptions)
167169
isfresh = true
168170
Tc = typeof(cacheval)
169171

170-
cache = LinearCache{
171-
typeof(A),
172-
typeof(b),
173-
typeof(u0),
174-
typeof(p),
175-
typeof(alg),
176-
Tc,
177-
typeof(Pl),
178-
typeof(Pr),
179-
typeof(reltol),
180-
typeof(assumptions.issq),
181-
}(A,
182-
b,
183-
u0,
184-
p,
185-
alg,
186-
cacheval,
187-
isfresh,
188-
Pl,
189-
Pr,
190-
abstol,
191-
reltol,
192-
maxiters,
193-
verbose,
194-
assumptions)
172+
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
173+
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
174+
p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
195175
return cache
196176
end
197177

@@ -208,3 +188,33 @@ end
208188
# Solve using the algorithm stored on the cache: forwards to the
# `solve!(cache, alg, args...; kwargs...)` method for `cache.alg`.
function SciMLBase.solve!(cache::LinearCache, args...; kwargs...)
    alg = cache.alg
    return solve!(cache, alg, args...; kwargs...)
end
191+
192+
# Special Case for StaticArrays
#
# Alias matching every `LinearProblem` whose third type parameter is an `SMatrix`
# and whose fourth is an `SMatrix` or `SVector`; the first two parameters are left
# free. NOTE(review): the third/fourth parameters are presumably the types of
# `prob.A` and `prob.b` -- confirm against the `LinearProblem` definition.
const StaticLinearProblem = LinearProblem{
    uType,
    iip,
    <:SMatrix,
    <:Union{<:SMatrix, <:SVector},
} where {uType, iip}
195+
196+
# No algorithm given for a static problem: re-dispatch with `nothing` inserted in
# the algorithm position, forwarding all remaining arguments and keywords.
SciMLBase.solve(prob::StaticLinearProblem, args...; kwargs...) =
    SciMLBase.solve(prob, nothing, args...; kwargs...)
199+
200+
# Solve a static-array linear problem.
#
# For the algorithms listed below the solution is computed directly from `prob.A`
# and `prob.b` without building a cache; `args` and `kwargs` are not consulted on
# these branches. Any other algorithm goes through `init` + `solve!`, and the
# result of `solve!` is returned as-is.
function SciMLBase.solve(prob::StaticLinearProblem,
        alg::Union{Nothing, SciMLLinearSolveAlgorithm}, args...; kwargs...)
    A = prob.A
    b = prob.b
    u = if alg === nothing || alg isa DirectLdiv!
        A \ b
    elseif alg isa LUFactorization
        lu(A) \ b
    elseif alg isa QRFactorization
        qr(A) \ b
    elseif alg isa CholeskyFactorization
        cholesky(A) \ b
    elseif alg isa NormalCholeskyFactorization
        # Normal equations: factorize A'A and solve against A'b.
        cholesky(Symmetric(A' * A)) \ (A' * b)
    elseif alg isa SVDFactorization
        svd(A) \ b
    else
        # Slower Path but handles all cases
        return solve!(init(prob, alg, args...; kwargs...))
    end
    # Wrap the directly computed `u`; `nothing` is passed as the third argument.
    return SciMLBase.build_linear_solution(alg, u, nothing, prob)
end

src/default.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ function defaultalg(A, b, assump::OperatorAssumptions{Nothing})
3636
defaultalg(A, b, OperatorAssumptions(issq, assump.condition))
3737
end
3838

39+
# Default algorithm for static matrices, chosen from the size parameters alone:
# `LUFactorization()` when `S1 == S2`, otherwise `SVDFactorization()`.
# `b` and `assump` take part in dispatch only; they are not read.
function defaultalg(A::SMatrix{S1, S2}, b,
        assump::OperatorAssumptions{Bool}) where {S1, S2}
    # QR(...) \ b is not defined currently
    S1 != S2 && return SVDFactorization()
    return LUFactorization()
end
46+
3947
function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool})
4048
if assump.issq
4149
DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization)
@@ -175,10 +183,6 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
175183
DefaultAlgorithmChoice.LUFactorization
176184
end
177185

178-
# For static arrays GMRES allocates a lot. Use factorization
179-
elseif A isa StaticArray
180-
DefaultAlgorithmChoice.LUFactorization
181-
182186
# This catches the cases where a factorization overload could exist
183187
# For example, BlockBandedMatrix
184188
elseif A !== nothing && ArrayInterface.isstructured(A)
@@ -190,9 +194,6 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
190194
end
191195
elseif assump.condition === OperatorCondition.WellConditioned
192196
DefaultAlgorithmChoice.NormalCholeskyFactorization
193-
elseif A isa StaticArray
194-
# Static Array doesn't have QR() \ b defined
195-
DefaultAlgorithmChoice.SVDFactorization
196197
elseif assump.condition === OperatorCondition.IllConditioned
197198
if is_underdetermined(A)
198199
# Underdetermined
@@ -269,8 +270,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Nothing,
269270
args...;
270271
assumptions = OperatorAssumptions(issquare(prob.A)),
271272
kwargs...)
272-
alg = defaultalg(prob.A, prob.b, assumptions)
273-
SciMLBase.init(prob, alg, args...; assumptions, kwargs...)
273+
SciMLBase.init(prob, defaultalg(prob.A, prob.b, assumptions), args...; assumptions, kwargs...)
274274
end
275275

276276
function SciMLBase.solve!(cache::LinearCache, alg::Nothing,

src/factorization.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,6 @@ end
215215
function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl, Pr,
216216
maxiters::Int, abstol, reltol, verbose::Bool,
217217
assumptions::OperatorAssumptions) where {S1, S2}
218-
# StaticArrays doesn't have the pivot argument. Prevent generic fallback.
219-
# CholeskyFactorization is part of DefaultLinearSolver, so it is possible that `A` is
220-
# not Hermitian.
221-
(!issquare(A) || !ishermitian(A)) && return nothing
222218
cholesky(A)
223219
end
224220

@@ -979,11 +975,17 @@ function init_cacheval(alg::NormalCholeskyFactorization,
979975
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
980976
end
981977

978+
# Cache value for `NormalCholeskyFactorization` on a static matrix: the Cholesky
# factorization of `Symmetric(A' * A)`. Only `A` is read; the remaining arguments
# are accepted to match the common `init_cacheval` signature.
function init_cacheval(alg::NormalCholeskyFactorization, A::SMatrix, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    normal_matrix = Symmetric(A' * A)
    return cholesky(normal_matrix)
end
983+
982984
function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
983985
maxiters::Int, abstol, reltol, verbose::Bool,
984986
assumptions::OperatorAssumptions)
985987
A_ = convert(AbstractMatrix, A)
986-
ArrayInterface.cholesky_instance(Symmetric((A)' * A, :L), alg.pivot)
988+
return ArrayInterface.cholesky_instance(Symmetric((A)' * A), alg.pivot)
987989
end
988990

989991
function init_cacheval(alg::NormalCholeskyFactorization,
@@ -997,17 +999,20 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
997999
A = cache.A
9981000
A = convert(AbstractMatrix, A)
9991001
if cache.isfresh
1000-
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray
1001-
fact = cholesky(Symmetric((A)' * A, :L); check = false)
1002+
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
1003+
fact = cholesky(Symmetric((A)' * A); check = false)
10021004
else
1003-
fact = cholesky(Symmetric((A)' * A, :L), alg.pivot; check = false)
1005+
fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
10041006
end
10051007
cache.cacheval = fact
10061008
cache.isfresh = false
10071009
end
10081010
if A isa SparseMatrixCSC
10091011
cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
10101012
y = cache.u
1013+
elseif A isa StaticArray
1014+
cache.u = @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
1015+
y = cache.u
10111016
else
10121017
y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
10131018
end

src/iterative_wrappers.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
284284

285285
# Copy the solution to the allocated output vector
286286
cacheval = @get_cacheval(cache, :KrylovJL_GMRES)
287-
if cache.u !== cacheval.x
287+
if cache.u !== cacheval.x && ArrayInterface.can_setindex(cache.u)
288288
cache.u .= cacheval.x
289+
else
290+
cache.u = convert(typeof(cache.u), cacheval.x)
289291
end
290292

291293
return SciMLBase.build_linear_solution(alg, cache.u, resid, cache;

test/static_arrays.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,44 @@
1-
using LinearSolve, StaticArrays, LinearAlgebra
1+
using LinearSolve, StaticArrays, LinearAlgebra, Test
2+
using AllocCheck
23

34
A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I))
45
b = SVector{5}(rand(5))
56

7+
# Wrapper around `solve` defined under `@check_allocs`, used below to exercise the
# solve call for each algorithm.
@check_allocs function __solve_no_alloc(A, b, alg)
    return solve(LinearProblem(A, b), alg)
end
8+
9+
# `true` for algorithms whose `__solve_no_alloc` check is marked `@test_broken`
# in the loop below: `SVDFactorization` and any `KrylovJL`.
__non_native_static_array_alg(alg) = alg isa Union{SVDFactorization, KrylovJL}
12+
613
for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(),
7-
KrylovJL_GMRES())
14+
NormalCholeskyFactorization(), KrylovJL_GMRES())
815
sol = solve(LinearProblem(A, b), alg)
16+
@inferred solve(LinearProblem(A, b), alg)
17+
@test norm(A * sol .- b) < 1e-10
18+
19+
if __non_native_static_array_alg(alg)
20+
@test_broken __solve_no_alloc(A, b, alg)
21+
else
22+
@test_nowarn __solve_no_alloc(A, b, alg)
23+
end
24+
25+
cache = init(LinearProblem(A, b), alg)
26+
sol = solve!(cache)
927
@test norm(A * sol .- b) < 1e-10
1028
end
1129

1230
# 7x5 system: check that `solve` is inferrable and emits no warnings.
A = SMatrix{7, 5}(rand(7, 5))
b = SVector{7}(rand(7))

for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
    prob = LinearProblem(A, b)
    @inferred solve(prob, alg)
    @test_nowarn solve(prob, alg)
end
1837

1938
# 5x7 system: check that `solve` is inferrable and emits no warnings.
A = SMatrix{5, 7}(rand(5, 7))
b = SVector{5}(rand(5))

for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
    prob = LinearProblem(A, b)
    @inferred solve(prob, alg)
    @test_nowarn solve(prob, alg)
end

0 commit comments

Comments
 (0)