Skip to content

Commit 9d8bb18

Browse files
committed
_isidentity_struct
1 parent 1c4d194 commit 9d8bb18

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

src/LinearSolve.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ needs_concrete_A(alg::AbstractFactorization) = true
4242
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
4343
needs_concrete_A(alg::AbstractSolveFunction) = false
4444

45+
# Util
46+
_isidentity_struct(A) = false
47+
_isidentity_struct::Number) = isone(λ)
48+
_isidentity_struct(A::UniformScaling) = isone(A.λ)
49+
_isidentity_struct(::IterativeSolvers.Identity) = true
50+
_isidentity_struct(::SciMLBase.IdentityOperator) = true
51+
_isidentity_struct(::SciMLBase.DiffEqIdentity) = true
52+
4553
# Code
4654

4755
const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)

src/iterative_wrappers.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,9 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
150150
M = cache.Pl
151151
N = cache.Pr
152152

153-
M = (M === Identity()) ? I : InvPreconditioner(M)
154-
N = (N === Identity()) ? I : InvPreconditioner(N)
153+
# use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity
154+
M = _isidentity_struct(M) ? I : InvPreconditioner(M)
155+
N = _isidentity_struct(M) ? I : InvPreconditioner(N)
155156

156157
atol = float(cache.abstol)
157158
rtol = float(cache.reltol)
@@ -234,15 +235,15 @@ function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters::Int,
234235
alg.kwargs...)
235236

236237
iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
237-
Pr !== Identity() &&
238+
! _isidentity_struct(Pr) &&
238239
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
239240
alg.generate_iterator(u, A, b, Pl;
240241
kwargs...)
241242
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
242243
alg.generate_iterator(u, A, b; Pl = Pl, Pr = Pr, restart = restart,
243244
kwargs...)
244245
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
245-
Pr !== Identity() &&
246+
! _isidentity_struct(Pr) &&
246247
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
247248
alg.generate_iterator(u, A, b, alg.args...; Pl = Pl,
248249
abstol = abstol, reltol = reltol,

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ if GROUP == "All" || GROUP == "Core"
2626
@time @safetestset "Non-Square Tests" begin include("nonsquare.jl") end
2727
@time @safetestset "SparseVector b Tests" begin include("sparse_vector.jl") end
2828
@time @safetestset "Default Alg Tests" begin include("default_algs.jl") end
29+
@time @safetestset "Traits" begin include("traits.jl") end
2930
end
3031

3132
if GROUP == "LinearSolveCUDA"

test/traits.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#
2+
using LinearSolve, LinearAlgebra, Test
3+
using LinearSolve: _isidentity_struct
4+
5+
N = 4
6+
7+
@testset "Traits" begin
8+
9+
@test _isidentity_struct(I)
10+
@test _isidentity_struct(1.0 * I)
11+
@test _isidentity_struct(SciMLBase.IdentityOperator{N}())
12+
@test _isidentity_struct(SciMLBase.DiffEqIdentity(rand(4)))
13+
@test ! _isidentity_struct(2.0 * I)
14+
@test ! _isidentity_struct(rand(N, N))
15+
@test ! _isidentity_struct(Matrix(I, N, N))
16+
end

0 commit comments

Comments
 (0)