Skip to content

Commit 6d678a9

Browse files
Merge pull request #318 from SciML/inference
add JET tests and rely less on constant prop
2 parents 208cbc6 + ab451fe commit 6d678a9

File tree

6 files changed

+68
-73
lines changed

6 files changed

+68
-73
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.1.0"
4+
version = "2.1.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -60,6 +60,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6060
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
6161
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
6262
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
63+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
6364
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
6465
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
6566
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
@@ -69,7 +70,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
6970
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7071

7172
[targets]
72-
test = ["Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI"]
73+
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI"]
7374

7475
[weakdeps]
7576
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/LinearSolve.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,29 @@ _isidentity_struct(::SciMLBase.DiffEqIdentity) = true
5555

5656
const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
5757

58+
EnumX.@enumx DefaultAlgorithmChoice begin
59+
LUFactorization
60+
QRFactorization
61+
DiagonalFactorization
62+
DirectLdiv!
63+
SparspakFactorization
64+
KLUFactorization
65+
UMFPACKFactorization
66+
KrylovJL_GMRES
67+
GenericLUFactorization
68+
RFLUFactorization
69+
LDLtFactorization
70+
BunchKaufmanFactorization
71+
CHOLMODFactorization
72+
SVDFactorization
73+
CholeskyFactorization
74+
NormalCholeskyFactorization
75+
end
76+
77+
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
78+
alg::DefaultAlgorithmChoice.T
79+
end
80+
5881
include("common.jl")
5982
include("factorization.jl")
6083
include("simplelu.jl")
@@ -74,7 +97,7 @@ include("deprecated.jl")
7497
cache.cacheval = fact
7598
cache.isfresh = false
7699
end
77-
y = _ldiv!(cache.u, get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
100+
y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
78101
cache.b)
79102

80103
#=

src/common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
121121
verbose::Bool = false,
122122
Pl = IdentityOperator(size(prob.A)[1]),
123123
Pr = IdentityOperator(size(prob.A)[2]),
124-
assumptions = OperatorAssumptions(Val(issquare(prob.A))),
124+
assumptions = OperatorAssumptions(issquare(prob.A)),
125125
kwargs...)
126126
@unpack A, b, u0, p = prob
127127

src/default.jl

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,3 @@
1-
EnumX.@enumx DefaultAlgorithmChoice begin
2-
LUFactorization
3-
QRFactorization
4-
DiagonalFactorization
5-
DirectLdiv!
6-
SparspakFactorization
7-
KLUFactorization
8-
UMFPACKFactorization
9-
KrylovJL_GMRES
10-
GenericLUFactorization
11-
RFLUFactorization
12-
LDLtFactorization
13-
BunchKaufmanFactorization
14-
CHOLMODFactorization
15-
SVDFactorization
16-
CholeskyFactorization
17-
NormalCholeskyFactorization
18-
end
19-
20-
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
21-
alg::DefaultAlgorithmChoice.T
22-
end
23-
241
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
252
T13, T14, T15, T16}
263
LUFactorization::T1
@@ -313,35 +290,6 @@ cache.cacheval = NamedTuple(LUFactorization = cache of LUFactorization, ...)
313290
Expr(:call, :DefaultLinearSolverInit, caches...)
314291
end
315292

316-
"""
317-
if algsym === :LUFactorization
318-
cache.cacheval.LUFactorization = ...
319-
else
320-
...
321-
end
322-
"""
323-
@generated function get_cacheval(cache::LinearCache, algsym::Symbol)
324-
ex = :()
325-
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
326-
ex = if ex == :()
327-
Expr(:elseif, :(algsym === $(Meta.quot(alg))),
328-
:(getfield(cache.cacheval, $(Meta.quot(alg)))))
329-
else
330-
Expr(:elseif, :(algsym === $(Meta.quot(alg))),
331-
:(getfield(cache.cacheval, $(Meta.quot(alg)))), ex)
332-
end
333-
end
334-
ex = Expr(:if, ex.args...)
335-
336-
quote
337-
if cache.alg isa DefaultLinearSolver
338-
$ex
339-
else
340-
cache.cacheval
341-
end
342-
end
343-
end
344-
345293
function defaultalg_symbol(::Type{T}) where {T}
346294
Symbol(split(string(SciMLBase.parameterless_type(T)), ".")[end])
347295
end

src/factorization.jl

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
macro get_cacheval(cache, algsym)
2+
quote
3+
if $(esc(cache)).alg isa DefaultLinearSolver
4+
getfield($(esc(cache)).cacheval, $algsym)
5+
else
6+
$(esc(cache)).cacheval
7+
end
8+
end
9+
end
10+
111
_ldiv!(x, A, b) = ldiv!(x, A, b)
212

313
function _ldiv!(x::Vector, A::Factorization, b::Vector)
@@ -712,11 +722,11 @@ function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs.
712722
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
713723
cache.cacheval.colptr &&
714724
SuiteSparse.decrement(SparseArrays.getrowval(A)) ==
715-
get_cacheval(cache, :UMFPACKFactorization).rowval)
725+
@get_cacheval(cache, :UMFPACKFactorization).rowval)
716726
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
717727
nonzeros(A)))
718728
else
719-
fact = lu!(get_cacheval(cache, :UMFPACKFactorization),
729+
fact = lu!(@get_cacheval(cache, :UMFPACKFactorization),
720730
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
721731
nonzeros(A)))
722732
end
@@ -727,7 +737,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs.
727737
cache.isfresh = false
728738
end
729739

730-
y = ldiv!(cache.u, get_cacheval(cache, :UMFPACKFactorization), cache.b)
740+
y = ldiv!(cache.u, @get_cacheval(cache, :UMFPACKFactorization), cache.b)
731741
SciMLBase.build_linear_solution(alg, y, nothing, cache)
732742
end
733743

@@ -782,7 +792,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
782792
A = convert(AbstractMatrix, A)
783793

784794
if cache.isfresh
785-
cacheval = get_cacheval(cache, :KLUFactorization)
795+
cacheval = @get_cacheval(cache, :KLUFactorization)
786796
if cacheval !== nothing && alg.reuse_symbolic
787797
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
788798
cacheval.colptr &&
@@ -811,7 +821,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
811821
cache.isfresh = false
812822
end
813823

814-
y = ldiv!(cache.u, get_cacheval(cache, :KLUFactorization), cache.b)
824+
y = ldiv!(cache.u, @get_cacheval(cache, :KLUFactorization), cache.b)
815825
SciMLBase.build_linear_solution(alg, y, nothing, cache)
816826
end
817827

@@ -863,7 +873,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs.
863873
A = convert(AbstractMatrix, A)
864874

865875
if cache.isfresh
866-
cacheval = get_cacheval(cache, :CHOLMODFactorization)
876+
cacheval = @get_cacheval(cache, :CHOLMODFactorization)
867877
fact = cholesky(A; check = false)
868878
if !LinearAlgebra.issuccess(fact)
869879
ldlt!(fact, A; check = false)
@@ -872,7 +882,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs.
872882
cache.isfresh = false
873883
end
874884

875-
cache.u .= get_cacheval(cache, :CHOLMODFactorization) \ cache.b
885+
cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
876886
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
877887
end
878888

@@ -928,7 +938,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::RFLUFactorization{P, T};
928938
kwargs...) where {P, T}
929939
A = cache.A
930940
A = convert(AbstractMatrix, A)
931-
fact, ipiv = get_cacheval(cache, :RFLUFactorization)
941+
fact, ipiv = @get_cacheval(cache, :RFLUFactorization)
932942
if cache.isfresh
933943
if length(ipiv) != min(size(A)...)
934944
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
@@ -937,7 +947,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::RFLUFactorization{P, T};
937947
cache.cacheval = (fact, ipiv)
938948
cache.isfresh = false
939949
end
940-
y = ldiv!(cache.u, get_cacheval(cache, :RFLUFactorization)[1], cache.b)
950+
y = ldiv!(cache.u, @get_cacheval(cache, :RFLUFactorization)[1], cache.b)
941951
SciMLBase.build_linear_solution(alg, y, nothing, cache)
942952
end
943953

@@ -1025,10 +1035,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
10251035
cache.isfresh = false
10261036
end
10271037
if A isa SparseMatrixCSC
1028-
cache.u .= get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
1038+
cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
10291039
y = cache.u
10301040
else
1031-
y = ldiv!(cache.u, get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
1041+
y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
10321042
end
10331043
SciMLBase.build_linear_solution(alg, y, nothing, cache)
10341044
end
@@ -1072,7 +1082,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalBunchKaufmanFactorizati
10721082
cache.cacheval = fact
10731083
cache.isfresh = false
10741084
end
1075-
y = ldiv!(cache.u, get_cacheval(cache, :NormalBunchKaufmanFactorization), A' * cache.b)
1085+
y = ldiv!(cache.u, @get_cacheval(cache, :NormalBunchKaufmanFactorization), A' * cache.b)
10761086
SciMLBase.build_linear_solution(alg, y, nothing, cache)
10771087
end
10781088

@@ -1131,7 +1141,7 @@ end
11311141
function SciMLBase.solve!(cache::LinearCache, alg::FastLUFactorization; kwargs...)
11321142
A = cache.A
11331143
A = convert(AbstractMatrix, A)
1134-
ws_and_fact = get_cacheval(cache, :FastLUFactorization)
1144+
ws_and_fact = @get_cacheval(cache, :FastLUFactorization)
11351145
if cache.isfresh
11361146
# we will fail here if A is a different *size* than in a previous version of the same cache.
11371147
# it may instead be desirable to resize the workspace.
@@ -1201,7 +1211,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::FastQRFactorization{P};
12011211
kwargs...) where {P}
12021212
A = cache.A
12031213
A = convert(AbstractMatrix, A)
1204-
ws_and_fact = get_cacheval(cache, :FastQRFactorization)
1214+
ws_and_fact = @get_cacheval(cache, :FastQRFactorization)
12051215
if cache.isfresh
12061216
# we will fail here if A is a different *size* than in a previous version of the same cache.
12071217
# it may instead be desirable to resize the workspace.
@@ -1281,7 +1291,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs
12811291
A = cache.A
12821292
if cache.isfresh
12831293
if cache.cacheval !== nothing && alg.reuse_symbolic
1284-
fact = sparspaklu!(get_cacheval(cache, :SparspakFactorization),
1294+
fact = sparspaklu!(@get_cacheval(cache, :SparspakFactorization),
12851295
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
12861296
nonzeros(A)))
12871297
else
@@ -1291,6 +1301,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs
12911301
cache.cacheval = fact
12921302
cache.isfresh = false
12931303
end
1294-
y = ldiv!(cache.u, get_cacheval(cache, :SparspakFactorization), cache.b)
1304+
y = ldiv!(cache.u, @get_cacheval(cache, :SparspakFactorization), cache.b)
12951305
SciMLBase.build_linear_solution(alg, y, nothing, cache)
12961306
end

test/default_algs.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearSolve, LinearAlgebra, SparseArrays, Test
1+
using LinearSolve, LinearAlgebra, SparseArrays, Test, JET
22
@test LinearSolve.defaultalg(nothing, zeros(3)).alg ===
33
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
44
@test LinearSolve.defaultalg(nothing, zeros(50)).alg ===
@@ -22,6 +22,19 @@ using LinearSolve, LinearAlgebra, SparseArrays, Test
2222
A = rand(4, 4)
2323
b = rand(4)
2424
prob = LinearProblem(A, b)
25+
JET.@test_opt init(prob, nothing)
26+
JET.@test_opt solve(prob, LUFactorization())
27+
JET.@test_opt solve(prob, GenericLUFactorization())
28+
JET.@test_opt solve(prob, QRFactorization())
29+
JET.@test_opt solve(prob, DiagonalFactorization())
30+
#JET.@test_opt solve(prob, SVDFactorization())
31+
#JET.@test_opt solve(prob, KrylovJL_GMRES())
32+
33+
prob = LinearProblem(sparse(A), b)
34+
#JET.@test_opt solve(prob, UMFPACKFactorization())
35+
#JET.@test_opt solve(prob, KLUFactorization())
36+
#JET.@test_opt solve(prob, SparspakFactorization())
37+
#JET.@test_opt solve(prob)
2538
@inferred solve(prob)
2639
@inferred init(prob, nothing)
2740
end

0 commit comments

Comments
 (0)