Skip to content

Commit 95973fd

Browse files
committed
Comments
1 parent ec43b27 commit 95973fd

File tree

4 files changed

+25
-5
lines changed

4 files changed

+25
-5
lines changed

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using TensorKit.Factorizations: AbstractAlgorithm
1212
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
1313
import TensorKit: randisometry, rand, randn
1414

15-
using TensorKit.MatrixAlgebraKit
15+
using TensorKit: MatrixAlgebraKit
1616

1717
using Random
1818

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ for randfun in (:curand, :curandn)
8484
$randfun!(rng, t)
8585
return t
8686
end
87+
88+
function $randfun!(rng::Random.AbstractRNG, t::CuTensorMap)
89+
for (_, b) in blocks(t)
90+
$randfun!(rng, b)
91+
end
92+
return t
93+
end
8794
end
8895
end
8996

@@ -112,7 +119,7 @@ function LinearAlgebra.isposdef(t::CuTensorMap)
112119
InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false
113120
for (c, b) in blocks(t)
114121
# do our own hermitian check
115-
isherm = TensorKit.MatrixAlgebraKit.ishermitian(b; atol = eps(real(eltype(b))), rtol = eps(real(eltype(b))))
122+
isherm = MatrixAlgebraKit.ishermitian(b; atol = eps(real(eltype(b))), rtol = eps(real(eltype(b))))
116123
isherm || return false
117124
isposdef(Hermitian(b)) || return false
118125
end
@@ -135,6 +142,7 @@ end
135142
function TensorKit.exp!(t::CuTensorMap)
136143
domain(t) == codomain(t) ||
137144
error("Exponential of a tensor only exist when domain == codomain.")
145+
!MatrixAlgebraKit.ishermitian(t) && throw(ArgumentError("`exp!` is only supported on hermitian CUDA tensors"))
138146
for (c, b) in blocks(t)
139147
copy!(b, parent(Base.exp(Hermitian(b))))
140148
end
@@ -146,7 +154,8 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
146154
sf = string(f)
147155
@eval function Base.$f(t::CuTensorMap)
148156
domain(t) == codomain(t) ||
149-
throw(SpaceMismatch("`$($sf)` of a tensor only exist when domain == codomain"))
157+
throw(SpaceMismatch("`$($sf)` of a tensor only exists when domain == codomain"))
158+
!MatrixAlgebraKit.ishermitian(t) && throw(ArgumentError("`$($sf)` is only supported on hermitian CUDA tensors"))
150159
T = complex(float(scalartype(t)))
151160
tf = similar(t, T)
152161
for (c, b) in blocks(t)

src/tensors/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ function LinearAlgebra.rank(
289289
)
290290
r = 0 * dim(first(allunits(sectortype(t))))
291291
dim(t) == 0 && return r
292-
S = LinearAlgebra.svdvals(t)
292+
S = svd_vals(t)
293293
tol = max(atol, rtol * maximum(parent(S)))
294294
for (c, b) in pairs(S)
295295
if !isempty(b)
@@ -307,7 +307,7 @@ function LinearAlgebra.cond(t::AbstractTensorMap, p::Real = 2)
307307
throw(SpaceMismatch("`cond` requires domain and codomain to be the same"))
308308
return zero(real(float(scalartype(t))))
309309
end
310-
S = LinearAlgebra.svdvals(t)
310+
S = svd_vals(t)
311311
maxS = maximum(parent(S))
312312
minS = minimum(parent(S))
313313
return iszero(maxS) ? oftype(maxS, Inf) : (maxS / minS)

test/cuda/tensors.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ const CUDAExt = Base.get_extension(TensorKit, :TensorKitCUDAExt)
77
const CuTensorMap = getglobal(CUDAExt, :CuTensorMap)
88
const curand = getglobal(CUDAExt, :curand)
99
const curandn = getglobal(CUDAExt, :curandn)
10+
const curand! = getglobal(CUDAExt, :curand!)
11+
const curandn! = getglobal(CUDAExt, :curandn!)
1012

1113
@isdefined(TestSetup) || include("../setup.jl")
1214
using .TestSetup
@@ -61,6 +63,15 @@ for V in spacelist
6163
@test domain(t) == one(W)
6264
@test typeof(t) == TensorMap{Float64, spacetype(t), 5, 0, CuVector{Float64, CUDA.DeviceMemory}}
6365
end
66+
for f! in (curand!, curandn!)
67+
t = @constinferred CUDA.zeros(W)
68+
f!(t)
69+
@test scalartype(t) == Float64
70+
@test codomain(t) == W
71+
@test space(t) == (W one(W))
72+
@test domain(t) == one(W)
73+
@test typeof(t) == TensorMap{Float64, spacetype(t), 5, 0, CuVector{Float64, CUDA.DeviceMemory}}
74+
end
6475
for T in (Int, Float32, Float64, ComplexF32, ComplexF64)
6576
t = @constinferred CUDA.zeros(T, W)
6677
CUDA.@allowscalar begin

0 commit comments

Comments
 (0)