@@ -105,6 +105,19 @@ function _sym_givens(a::T, b::T) where {T <: AbstractFloat}
105
105
return (c, s, ρ)
106
106
end
107
107
108
# Compute the Givens reflections for inner iteration `inner_iter` and store them in place.
#
# If `__is_extension_loaded(Val(:KernelAbstractions))` is true, the work is delegated to
# `_fast_sym_givens!` and whatever it returns is returned unchanged. Otherwise
# `_sym_givens` is broadcast over `R[nr + inner_iter]` and `Hbis`, and for every
# `k in 1:bsize` the three components of the k-th result are written to
# `c[inner_iter][k]`, `s[inner_iter][k]` and `R[nr + inner_iter][k]` respectively.
#
# The fallback path returns the tuple `(c, s, R)`.
# NOTE(review): presumably `bsize` equals the length of each of these vectors — confirm
# against the caller.
function _sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int, Hbis)
    # Guard clause: prefer the extension-provided kernel when it is available.
    __is_extension_loaded(Val(:KernelAbstractions)) &&
        return _fast_sym_givens!(c, s, R, nr, inner_iter, bsize, Hbis)

    col = nr + inner_iter
    rotations = _sym_givens.(R[col], Hbis)

    # Element-by-element writes; the macro permits scalar indexing for this loop.
    GPUArraysCore.@allowscalar for k in 1:bsize
        g = rotations[k]
        c[inner_iter][k] = g[1]
        s[inner_iter][k] = g[2]
        R[col][k] = g[3]
    end

    return c, s, R
end
120
+
108
121
# `_no_preconditioner(P)` is `true` when `P` is `nothing`, an `IdentityOperator`, or a
# `UniformScaling`.
# NOTE(review): presumably these are the cases treated as "no preconditioning is applied";
# any fallback method for other types is not visible here.
function _no_preconditioner(::Nothing)
    return true
end

function _no_preconditioner(::IdentityOperator)
    return true
end

function _no_preconditioner(::UniformScaling)
    return true
end
@@ -221,15 +234,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache)
221
234
while ! (solved || tired || breakdown)
222
235
# Initialize workspace.
223
236
nr = 0 # Number of coefficients stored in Rₖ.
224
- #= TODO : Check that not zeroing out doesn't lead to incorrect results.
225
- foreach(V) do v
226
- v .= zero(T) # Orthogonal basis of Kₖ(MAN, Mr₀).
227
- end
228
- s .= zero(T) # Givens sines used for the factorization QₖRₖ = Hₖ₊₁.ₖ.
229
- c .= zero(T) # Givens cosines used for the factorization QₖRₖ = Hₖ₊₁.ₖ.
230
- R .= zero(T) # Upper triangular matrix Rₖ.
231
- z .= zero(T) # Right-hand of the least squares problem min ‖Hₖ₊₁.ₖyₖ - βe₁‖₂.
232
- =#
237
+ # TODO : Check that not zeroing out doesn't lead to incorrect results.
233
238
234
239
if restart
235
240
xr .= zero (T) # xr === Δx when restart is set to true
@@ -517,13 +522,7 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache)
517
522
# Compute and apply current Givens reflection Ωₖ.
518
523
# [cₖ sₖ] [ r̄ₖ.ₖ ] = [rₖ.ₖ]
519
524
# [s̄ₖ -cₖ] [hₖ₊₁.ₖ] [ 0 ]
520
- # FIXME : Write inplace kernel
521
- __res = _sym_givens .(R[nr + inner_iter], Hbis)
522
- foreach (1 : bsize) do i
523
- c[inner_iter][i] = __res[i][1 ]
524
- s[inner_iter][i] = __res[i][2 ]
525
- R[nr + inner_iter][i] = __res[i][3 ]
526
- end
525
+ _sym_givens! (c, s, R, nr, inner_iter, bsize, Hbis)
527
526
528
527
# Update zₖ = (Qₖ)ᴴβe₁
529
528
ζₖ₊₁ = conj .(s[inner_iter]) .* z[inner_iter]
@@ -567,15 +566,8 @@ function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache)
567
566
pos = pos - j + 1 # position of rᵢ.ⱼ₋₁
568
567
end
569
568
# Rₖ can be singular if the system is inconsistent
570
- # FIXME : Write with broadcasting
571
- GPUArraysCore. @allowscalar foreach (1 : bsize) do B
572
- if abs (R[pos][B]) ≤ btol
573
- y[i][B] = zero (T)
574
- inconsistent = true
575
- else
576
- y[i][B] /= R[pos][B]
577
- end
578
- end
569
+ y[i] .= ifelse .(abs .(R[pos]) .≤ btol, zero (T), y[i] ./ R[pos]) # yᵢ ← yᵢ / rᵢᵢ
570
+ inconsistent = any (abs .(R[pos]) .≤ btol)
579
571
end
580
572
581
573
# Form xₖ = NVₖyₖ
0 commit comments