@@ -129,10 +129,9 @@ function expv!(w::AbstractVector{Complex{Tw}}, t::Complex{Tt}, Ks::KrylovSubspac
129129 lmul! (beta, mul! (w, @view (V[:, 1 : m]), compatible_multiplicative_operand (V, expHe))) # exp(A) ≈ norm(b) * V * exp(H)e
130130end
131131
132- function ExponentialUtilities. expv! (w:: GPUArraysCore.AbstractGPUVector{Tw} ,
133- t:: Real , Ks:: KrylovSubspace{T, U} ;
134- cache = nothing ,
135- expmethod = ExpMethodHigham2005Base ()) where {Tw, T, U}
132+ # Internal GPU implementation shared by Real and Complex t methods
133+ function _expv_gpu_impl! (w:: GPUArraysCore.AbstractGPUVector , t, Ks:: KrylovSubspace{T, U} ,
134+ cache, expmethod) where {T, U}
136135 m, beta, V, H = Ks. m, Ks. beta, getV (Ks), getH (Ks)
137136 @assert length (w)== size (V, 1 ) " Dimension mismatch"
138137 if isnothing (cache)
@@ -150,18 +149,31 @@ function ExponentialUtilities.expv!(w::GPUArraysCore.AbstractGPUVector{Tw},
150149 if ishermitian (cache)
151150 # Optimize the case for symtridiagonal H
152151 F = eigen! (SymTridiagonal (cache))
153- expHe = F. vectors * (exp .(lmul! (t, F. values) ) .* @view (F. vectors[1 , :]))
152+ expHe = F. vectors * (exp .(t * F. values) .* @view (F. vectors[1 , :]))
154153 else
155- # lmul!(t, cache)
156- # expH = exponential!(cache, expmethod)
157- # expHe = @view(expH[:, 1])
158154 expH = exponential! (t * cache, expmethod)
159155 expHe = @view (expH[:, 1 ])
160156 end
161157
162158 lmul! (beta, mul! (w, @view (V[:, 1 : m]), Adapt. adapt (parameterless_type (w), expHe))) # exp(A) ≈ norm(b) * V * exp(H)e
163159end
164160
161+ # GPU expv! for Real t
162+ function ExponentialUtilities. expv! (w:: GPUArraysCore.AbstractGPUVector{Tw} ,
163+ t:: Real , Ks:: KrylovSubspace{T, U} ;
164+ cache = nothing ,
165+ expmethod = ExpMethodHigham2005Base ()) where {Tw, T, U}
166+ _expv_gpu_impl! (w, t, Ks, cache, expmethod)
167+ end
168+
169+ # GPU expv! for Complex t
170+ function ExponentialUtilities. expv! (w:: GPUArraysCore.AbstractGPUVector{Complex{Tw}} ,
171+ t:: Complex{Tt} , Ks:: KrylovSubspace{T, U} ;
172+ cache = nothing ,
173+ expmethod = ExpMethodHigham2005Base ()) where {Tw, Tt, T, U}
174+ _expv_gpu_impl! (w, t, Ks, cache, expmethod)
175+ end
176+
165177compatible_multiplicative_operand (:: AbstractArray , source:: AbstractArray ) = source
166178
167179# ###########################
0 commit comments