Skip to content

Commit 6bfa1ca

Browse files
committed
Comments and coverage
1 parent 24f2f77 commit 6bfa1ca

File tree

5 files changed

+51
-10
lines changed

5 files changed

+51
-10
lines changed

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using TensorKit.Factorizations
1010
using TensorKit.Strided
1111
using TensorKit.Factorizations: AbstractAlgorithm
1212
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, _project_symmetric_and_check
13-
import TensorKit: randisometry
13+
import TensorKit: randisometry, rand, randn
1414

1515
using TensorKit.MatrixAlgebraKit
1616

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ const CuTensor{T, S, N} = CuTensorMap{T, S, N, 0}
33

44
const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, CuTensorMap{T, S, N₁, N₂}}
55

6-
function CuTensorMap{T, S, N₁, N₂}(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
6+
function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
77
return CuTensorMap{T, S, N₁, N₂}(CuArray(t.data), t.space)
88
end
99

1010
# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
11-
function TensorKit._project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}}
11+
function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}}
1212
h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V)
1313
h_t = TensorKit.project_symmetric!(h_t, Array(data))
1414
# verify result
@@ -146,8 +146,9 @@ for randfun in (:rand, :randn, :randisometry)
146146
end
147147
end
148148

149-
function Base.convert(::Type{CuTensorMap}, t::AbstractTensorMap)
150-
return copy!(CuTensorMap{scalartype(t)}(undef, space(t)), t)
149+
function Base.convert(::Type{CuTensorMap}, t::AbstractTensorMap{T, S}) where {T, S}
150+
d_data = CuArray(t.data)
151+
return TensorMap{T}(d_data, t.space)
151152
end
152153

153154
# Scalar implementation

src/tensors/linalg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ function LinearAlgebra.rank(
299299
r = 0 * dim(first(allunits(sectortype(t))))
300300
dim(t) == 0 && return r
301301
S = LinearAlgebra.svdvals(t)
302-
tol = max(atol, rtol * mapreduce(maximum, max, values(S)))
302+
tol = max(atol, rtol * maximum(S.data))
303303
for (c, b) in pairs(S)
304304
if !isempty(b)
305305
r += dim(c) * count(>(tol), b)
@@ -317,8 +317,8 @@ function LinearAlgebra.cond(t::AbstractTensorMap, p::Real = 2)
317317
return zero(real(float(scalartype(t))))
318318
end
319319
S = LinearAlgebra.svdvals(t)
320-
maxS = mapreduce(maximum, max, values(S))
321-
minS = mapreduce(minimum, min, values(S))
320+
maxS = maximum(S.data)
321+
minS = minimum(S.data)
322322
return iszero(maxS) ? oftype(maxS, Inf) : (maxS / minS)
323323
else
324324
throw(ArgumentError("cond currently only defined for p=2"))

src/tensors/tensor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ TensorMap(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...)
198198
Tensor(data::AbstractArray, codom::TensorSpace; kwargs...) =
199199
TensorMap(data, codom one(codom); kwargs...)
200200

201-
function _project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: DenseVector{T}}
201+
function project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: DenseVector{T}}
202202
t = TensorMapWithStorage{T, A}(undef, V)
203203
t = project_symmetric!(t, data)
204204
# verify result
@@ -218,7 +218,7 @@ function TensorMapWithStorage{T, A}(
218218
sectortype(V) === Trivial &&
219219
return tensormaptype(spacetype(V), numout(V), numin(V), A)(reshape(data, length(data)), V)
220220

221-
return _project_symmetric_and_check(T, A, data, V; tol)
221+
return project_symmetric_and_check(T, A, data, V; tol)
222222
end
223223
TensorMapWithStorage{T, A}(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...) where {T, A} =
224224
TensorMapWithStorage{T, A}(data, codom dom; kwargs...)

test/cuda/tensors.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ using TensorKit, Combinatorics
44
ad = adapt(Array)
55
const CUDAExt = Base.get_extension(TensorKit, :TensorKitCUDAExt)
66
@assert !isnothing(CUDAExt)
7+
const CuTensorMap = getglobal(CUDAExt, :CuTensorMap)
8+
const curand = getglobal(CUDAExt, :curand)
9+
const curandn = getglobal(CUDAExt, :curandn)
710

811
@isdefined(TestSetup) || include("../setup.jl")
912
using .TestSetup
@@ -41,6 +44,23 @@ for V in spacelist
4144
V1, V2, V3, V4, V5 = V
4245
@timedtestset "Basic tensor properties" begin
4346
W = V1 V2 V3 V4 V5
47+
# test default pass-throughs
48+
for f in (CUDA.zeros, CUDA.ones, curand, curandn)
49+
t = @constinferred f(W)
50+
@test scalartype(t) == Float64
51+
@test codomain(t) == W
52+
@test space(t) == (W one(W))
53+
@test domain(t) == one(W)
54+
@test typeof(t) == TensorMap{Float64, spacetype(t), 5, 0, CuVector{Float64, CUDA.DeviceMemory}}
55+
end
56+
for f in (rand, randn)
57+
t = @constinferred f(CuVector{Float64}, W)
58+
@test scalartype(t) == Float64
59+
@test codomain(t) == W
60+
@test space(t) == (W one(W))
61+
@test domain(t) == one(W)
62+
@test typeof(t) == TensorMap{Float64, spacetype(t), 5, 0, CuVector{Float64, CUDA.DeviceMemory}}
63+
end
4464
for T in (Int, Float32, Float64, ComplexF32, ComplexF64)
4565
t = @constinferred CUDA.zeros(T, W)
4666
CUDA.@allowscalar begin
@@ -64,6 +84,26 @@ for V in spacelist
6484
@test typeof(c) === sectortype(t)
6585
end
6686
end
87+
@timedtestset "Conversion to/from host" begin
88+
W = V1 V2 V3 V4 V5
89+
for T in (Int, Float32, ComplexF64)
90+
h_t = @constinferred rand(T, W)
91+
t1 = convert(CuTensorMap, h_t)
92+
@test collect(t1.data) == h_t.data
93+
@test t1.space == h_t.space
94+
@test scalartype(t1) == T
95+
@test codomain(t1) == W
96+
@test space(t1) == (W one(W))
97+
@test domain(t1) == one(W)
98+
t2 = CuTensorMap(h_t)
99+
@test collect(t2.data) == h_t.data
100+
@test t2.space == h_t.space
101+
@test scalartype(t2) == T
102+
@test codomain(t2) == W
103+
@test space(t2) == (W one(W))
104+
@test domain(t2) == one(W)
105+
end
106+
end
67107
@timedtestset "Tensor Dict conversion" begin
68108
W = V1 V2 V3 V4 V5
69109
for T in (Int, Float32, ComplexF64)

0 commit comments

Comments
 (0)