Skip to content

Commit dad747b

Browse files
authored
S should be a concrete type in the workspaces (#977)
* S should be a concrete type in the workspaces * Fix a typo * Fix the GPU tests
1 parent 843465c commit dad747b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+102
-62
lines changed

src/bicgstab.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ kwargs_bicgstab = (:c, :M, :N, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose,
131131

132132
# Check type consistency
133133
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
134-
ktypeof(b) <: S || error("ktypeof(b) is not a subtype of $S")
135-
ktypeof(c) <: S || error("ktypeof(c) is not a subtype of $S")
134+
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
135+
ktypeof(c) == S || error("ktypeof(c) must be equal to $S")
136136

137137
# Set up workspace.
138138
allocate_if(!MisI, solver, :t , S, solver.x) # The length of t is n

src/bilq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ kwargs_bilq = (:c, :transfer_to_bicg, :M, :N, :ldiv, :atol, :rtol, :itmax, :time
124124

125125
# Check type consistency
126126
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
127-
ktypeof(b) <: S || error("ktypeof(b) is not a subtype of $S")
128-
ktypeof(c) <: S || error("ktypeof(c) is not a subtype of $S")
127+
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
128+
ktypeof(c) == S || error("ktypeof(c) must be equal to $S")
129129

130130
# Compute the adjoint of A, M and N
131131
Aᴴ = A'

src/bilqr.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ kwargs_bilqr = (:transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose, :hi
119119

120120
# Check type consistency
121121
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
122-
ktypeof(b) <: S || error("ktypeof(b) is not a subtype of $S")
123-
ktypeof(c) <: S || error("ktypeof(c) is not a subtype of $S")
122+
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
123+
ktypeof(c) == S || error("ktypeof(c) must be equal to $S")
124124

125125
# Compute the adjoint of A
126126
Aᴴ = A'

src/block_gmres.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto
109109

110110
# Check type consistency
111111
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-matrix products."
112-
ktypeof(B) <: SM || error("ktypeof(B) is not a subtype of $SM")
112+
ktypeof(B) == SM || error("ktypeof(B) must be equal to $SM")
113113

114114
# Set up workspace.
115115
allocate_if(!MisI , solver, :Q , SM, solver.n, solver.p)

src/block_krylov_solvers.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ mutable struct BlockMinresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM}
4242
end
4343

4444
function BlockMinresSolver(m::Integer, n::Integer, p::Integer, SV::Type, SM::Type)
45-
FC = eltype(SV)
46-
T = real(FC)
47-
ΔX = SM(undef, 0, 0)
48-
X = SM(undef, n, p)
49-
P = SM(undef, 0, 0)
50-
Q = SM(undef, n, p)
51-
C = SM(undef, p, p)
52-
D = SM(undef, 2p, p)
53-
Φ = SM(undef, p, p)
45+
FC = eltype(SV)
46+
T = real(FC)
47+
ΔX = SM(undef, 0, 0)
48+
X = SM(undef, n, p)
49+
P = SM(undef, 0, 0)
50+
Q = SM(undef, n, p)
51+
C = SM(undef, p, p)
52+
D = SM(undef, 2p, p)
53+
Φ = SM(undef, p, p)
5454
Vₖ₋₁ = SM(undef, n, p)
5555
Vₖ = SM(undef, n, p)
5656
wₖ₋₂ = SM(undef, n, p)
@@ -59,6 +59,8 @@ function BlockMinresSolver(m::Integer, n::Integer, p::Integer, SV::Type, SM::Typ
5959
Hₖ₋₁ = SM(undef, 2p, p)
6060
τₖ₋₂ = SV(undef, p)
6161
τₖ₋₁ = SV(undef, p)
62+
SV = isconcretetype(SV) ? SV : typeof(τₖ₋₁)
63+
SM = isconcretetype(SM) ? SM : typeof(X)
6264
stats = SimpleStats(0, false, false, false, T[], T[], T[], 0.0, "unknown")
6365
solver = BlockMinresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, P, Q, C, D, Φ, Vₖ₋₁, Vₖ, wₖ₋₂, wₖ₋₁, Hₖ₋₂, Hₖ₋₁, τₖ₋₂, τₖ₋₁, false, stats)
6466
return solver
@@ -119,6 +121,8 @@ function BlockGmresSolver(m::Integer, n::Integer, p::Integer, SV::Type, SM::Type
119121
R = SM[SM(undef, p, p) for i = 1 : div(memory * (memory+1), 2)]
120122
H = SM[SM(undef, 2p, p) for i = 1 : memory]
121123
τ = SV[SV(undef, p) for i = 1 : memory]
124+
SV = isconcretetype(SV) ? SV : typeof(τ)
125+
SM = isconcretetype(SM) ? SM : typeof(X)
122126
stats = SimpleStats(0, false, false, false, T[], T[], T[], 0.0, "unknown")
123127
solver = BlockGmresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, V, Z, R, H, τ, false, stats)
124128
return solver

src/block_minres.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
101101

102102
# Check type consistency
103103
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-matrix products."
104-
ktypeof(B) <: SM || error("ktypeof(B) is not a subtype of $SM")
104+
ktypeof(B) == SM || error("ktypeof(B) must be equal to $SM")
105105

106106
# Set up workspace.
107107
Vₖ₋₁, Vₖ = solver.Vₖ₋₁, solver.Vₖ

src/car.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ kwargs_car = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :history, :ca
113113

114114
# Check type consistency
115115
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
116-
ktypeof(b) <: S || error("ktypeof(b) is not a subtype of $S")
116+
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
117117

118118
# Set up workspace.
119119
allocate_if(!MisI, solver, :Mu, S, solver.x) # The length of Mu is n

src/cg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ kwargs_cg = (:M, :ldiv, :radius, :linesearch, :atol, :rtol, :itmax, :timemax, :v
123123

124124
# Check type consistency
125125
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
126-
ktypeof(b) <: S || error("ktypeof(b) is not a subtype of $S")
126+
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
127127

128128
# Set up workspace.
129129
allocate_if(!MisI, solver, :z, S, solver.x) # The length of z is n

src/cg_lanczos.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ kwargs_cg_lanczos = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax, :timemax
118118

119119
# Check type consistency
120120
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
121-
ktypeof(b) <: S || error("ktypeof(b) is not a subtype of $S")
121+
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
122122

123123
# Set up workspace.
124124
allocate_if(!MisI, solver, :v, S, solver.x) # The length of v is n

src/cg_lanczos_shift.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ kwargs_cg_lanczos_shift = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax, :t
114114

115115
# Check type consistency
116116
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
117-
ktypeof(b) <: S || error("ktypeof(b) is not a subtype of $S")
117+
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
118118

119119
# Set up workspace.
120120
allocate_if(!MisI, solver, :v, S, solver.Mv) # The length of v is n

0 commit comments

Comments
 (0)