Skip to content

Commit c8d5d9e

Browse files
Merge pull request #198 from ChrisRackauckas-Claude/fix-gpu-complex-union
Fix GPU expv! to support complex t via shared implementation
2 parents 1fb5a12 + 5fd46c3 commit c8d5d9e

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

src/krylov_phiv.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,9 @@ function expv!(w::AbstractVector{Complex{Tw}}, t::Complex{Tt}, Ks::KrylovSubspac
129129
lmul!(beta, mul!(w, @view(V[:, 1:m]), compatible_multiplicative_operand(V, expHe))) # exp(A) ≈ norm(b) * V * exp(H)e
130130
end
131131

132-
function ExponentialUtilities.expv!(w::GPUArraysCore.AbstractGPUVector{Tw},
133-
t::Real, Ks::KrylovSubspace{T, U};
134-
cache = nothing,
135-
expmethod = ExpMethodHigham2005Base()) where {Tw, T, U}
132+
# Internal GPU implementation shared by Real and Complex t methods
133+
function _expv_gpu_impl!(w::GPUArraysCore.AbstractGPUVector, t, Ks::KrylovSubspace{T, U},
134+
cache, expmethod) where {T, U}
136135
m, beta, V, H = Ks.m, Ks.beta, getV(Ks), getH(Ks)
137136
@assert length(w)==size(V, 1) "Dimension mismatch"
138137
if isnothing(cache)
@@ -150,18 +149,31 @@ function ExponentialUtilities.expv!(w::GPUArraysCore.AbstractGPUVector{Tw},
150149
if ishermitian(cache)
151150
# Optimize the case for symtridiagonal H
152151
F = eigen!(SymTridiagonal(cache))
153-
expHe = F.vectors * (exp.(lmul!(t, F.values)) .* @view(F.vectors[1, :]))
152+
expHe = F.vectors * (exp.(t * F.values) .* @view(F.vectors[1, :]))
154153
else
155-
#lmul!(t, cache)
156-
#expH = exponential!(cache, expmethod)
157-
#expHe = @view(expH[:, 1])
158154
expH = exponential!(t * cache, expmethod)
159155
expHe = @view(expH[:, 1])
160156
end
161157

162158
lmul!(beta, mul!(w, @view(V[:, 1:m]), Adapt.adapt(parameterless_type(w), expHe))) # exp(A) ≈ norm(b) * V * exp(H)e
163159
end
164160

161+
# GPU expv! for Real t
162+
function ExponentialUtilities.expv!(w::GPUArraysCore.AbstractGPUVector{Tw},
163+
t::Real, Ks::KrylovSubspace{T, U};
164+
cache = nothing,
165+
expmethod = ExpMethodHigham2005Base()) where {Tw, T, U}
166+
_expv_gpu_impl!(w, t, Ks, cache, expmethod)
167+
end
168+
169+
# GPU expv! for Complex t
170+
function ExponentialUtilities.expv!(w::GPUArraysCore.AbstractGPUVector{Complex{Tw}},
171+
t::Complex{Tt}, Ks::KrylovSubspace{T, U};
172+
cache = nothing,
173+
expmethod = ExpMethodHigham2005Base()) where {Tw, Tt, T, U}
174+
_expv_gpu_impl!(w, t, Ks, cache, expmethod)
175+
end
176+
165177
compatible_multiplicative_operand(::AbstractArray, source::AbstractArray) = source
166178

167179
############################

test/gpu/gputests.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using LinearAlgebra
22
using SparseArrays
33
using CUDA
44
using CUDA.CUSPARSE
5+
using ExponentialUtilities
56
using ExponentialUtilities: inplace_add!,
67
exponential!, ExpMethodHigham2005, expv,
78
expv_timestep
@@ -55,3 +56,20 @@ E2 = Array(expv(t, A_gpu, b_gpu))
5556
E1 = expv_timestep(ts, A, b)
5657
E2 = Array(expv_timestep(ts, A_gpu, b_gpu))
5758
@test E1 ≈ E2
59+
60+
@testset "GPU expv! with complex t" begin
61+
T = ComplexF64
62+
v0 = randn(T, 4)
63+
cuv0 = cu(v0)
64+
A = randn(T, 4, 4)
65+
cuA = cu(A)
66+
67+
Ks = ExponentialUtilities.arnoldi(A, v0; tol = 1e-7, ishermitian = false, opnorm = 1.0)
68+
cuKs = ExponentialUtilities.arnoldi(cuA, cuv0; tol = 1e-7, ishermitian = false,
69+
opnorm = 1.0)
70+
71+
dt = 0.01im
72+
ExponentialUtilities.expv!(v0, dt, Ks)
73+
ExponentialUtilities.expv!(cuv0, dt, cuKs)
74+
@test v0 ≈ collect(cuv0)
75+
end

0 commit comments

Comments
 (0)