Skip to content

Commit 131fd04

Browse files
committed
More tests for CUDA
1 parent ae77d20 commit 131fd04

File tree

2 files changed

+72
-13
lines changed

2 files changed

+72
-13
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2727

2828
[sources]
2929
GPUArrays = {rev = "master", url = "https://github.com/JuliaGPU/GPUArrays.jl"}
30-
MatrixAlgebraKit = {rev = "ksh/tk", url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl"}
3130
AMDGPU = {rev = "master", url = "https://github.com/JuliaGPU/AMDGPU.jl"}
32-
cuTENSOR = {subdir = "lib/cutensor", url = "https://github.com/JuliaGPU/CUDA.jl", rev="master"}
31+
MatrixAlgebraKit = {rev = "ksh/tk2", url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl"}
3332

3433
[extensions]
3534
TensorKitAMDGPUExt = "AMDGPU"

test/cuda/factorizations.jl

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
using LinearAlgebra, CUDA, Test, TestExtras, TensorKit, cuTENSOR
1+
using Test, TestExtras
2+
using TensorKit
3+
using LinearAlgebra: LinearAlgebra
4+
using CUDA, cuTENSOR
25

36
const CUDAExt = Base.get_extension(TensorKit, :TensorKitCUDAExt)
47
@assert !isnothing(CUDAExt)
@@ -55,15 +58,15 @@ for V in spacelist
5558
@test Q * R t
5659
@test isisometric(Q)
5760

58-
Q, R = @constinferred left_orth(t; kind = :qr)
61+
Q, R = @constinferred left_orth(t)
5962
@test Q * R t
6063
@test isisometric(Q)
6164

6265
N = @constinferred qr_null(t)
6366
@test isisometric(N)
6467
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
6568

66-
N = @constinferred left_null(t; kind = :qr)
69+
N = @constinferred left_null(t)
6770
@test isisometric(N)
6871
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
6972
end
@@ -82,7 +85,7 @@ for V in spacelist
8285
@test isisometric(Q)
8386
@test dim(Q) == dim(R) == dim(t)
8487

85-
Q, R = @constinferred left_orth(t; kind = :qr)
88+
Q, R = @constinferred left_orth(t)
8689
@test Q * R t
8790
@test isisometric(Q)
8891
@test dim(Q) == dim(R) == dim(t)
@@ -107,7 +110,7 @@ for V in spacelist
107110
@test L * Q t
108111
@test isisometric(Q; side = :right)
109112

110-
L, Q = @constinferred right_orth(t; kind = :lq)
113+
L, Q = @constinferred right_orth(t)
111114
@test L * Q t
112115
@test isisometric(Q; side = :right)
113116

@@ -130,7 +133,7 @@ for V in spacelist
130133
@test isisometric(Q; side = :right)
131134
@test dim(Q) == dim(L) == dim(t)
132135

133-
L, Q = @constinferred right_orth(t; kind = :lq)
136+
L, Q = @constinferred right_orth(t)
134137
@test L * Q t
135138
@test isisometric(Q; side = :right)
136139
@test dim(Q) == dim(L) == dim(t)
@@ -157,7 +160,7 @@ for V in spacelist
157160
# broken for T <: Complex
158161
@test isposdef(p)
159162

160-
w, p = @constinferred left_orth(t; kind = :polar)
163+
w, p = @constinferred left_orth(t; alg = :polar)
161164
@test w * p t
162165
@test isisometric(w)
163166
end
@@ -174,7 +177,7 @@ for V in spacelist
174177
@test isisometric(wᴴ; side = :right)
175178
@test isposdef(p)
176179

177-
p, wᴴ = @constinferred right_orth(t; kind = :polar)
180+
p, wᴴ = @constinferred right_orth(t; alg = :polar)
178181
@test p * wᴴ t
179182
@test isisometric(wᴴ; side = :right)
180183
end
@@ -204,15 +207,15 @@ for V in spacelist
204207
@test b s′[c]
205208
end
206209

207-
v, c = @constinferred left_orth(t; kind = :svd)
210+
v, c = @constinferred left_orth(t; alg = :svd)
208211
@test v * c t
209212
@test isisometric(v)
210213

211-
N = @constinferred left_null(t; kind = :svd)
214+
N = @constinferred left_null(t; alg = :svd)
212215
@test isisometric(N)
213216
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
214217

215-
Nᴴ = @constinferred right_null(t; kind = :svd)
218+
Nᴴ = @constinferred right_null(t; alg = :svd)
216219
@test isisometric(Nᴴ; side = :right)
217220
@test norm(t * Nᴴ') 0 atol = 100 * eps(norm(t))
218221
end
@@ -386,5 +389,62 @@ for V in spacelist
386389
@test cond(t) λmax / λmin
387390
end
388391
end
392+
393+
@testset "Hermitian projections" begin
394+
for T in eltypes,
395+
t in (
396+
curand(T, V1, V1), curand(T, W, W), curand(T, W, W)',
397+
CuDiagonalTensorMap(curand(T, reduceddim(V1)), V1),
398+
)
399+
normalize!(t)
400+
noisefactor = eps(real(T))^(3 / 4)
401+
402+
th = (t + t') / 2
403+
ta = (t - t') / 2
404+
tc = copy(t)
405+
406+
th′ = @constinferred project_hermitian(t)
407+
@test ishermitian(th′)
408+
@test th′ th
409+
@test t == tc
410+
th_approx = th + noisefactor * ta
411+
@test !ishermitian(th_approx) || (T <: Real && t isa CuDiagonalTensorMap)
412+
@test ishermitian(th_approx; atol = 10 * noisefactor)
413+
414+
ta′ = project_antihermitian(t)
415+
@test isantihermitian(ta′)
416+
@test ta′ ta
417+
@test t == tc
418+
ta_approx = ta + noisefactor * th
419+
@test !isantihermitian(ta_approx)
420+
@test isantihermitian(ta_approx; atol = 10 * noisefactor) || (T <: Real && t isa CuDiagonalTensorMap)
421+
end
422+
end
423+
424+
@testset "Isometric projections" begin
425+
for T in eltypes,
426+
t in (
427+
curandn(T, W, W), curandn(T, W, W)',
428+
curandn(T, W, V1), curandn(T, V1, W)',
429+
)
430+
t2 = project_isometric(t)
431+
@test isisometric(t2)
432+
t3 = project_isometric(t2)
433+
@test t3 t2 # stability of the projection
434+
@test t2 * (t2' * t) t
435+
436+
tc = similar(t)
437+
t3 = @constinferred project_isometric!(copy!(tc, t), t2)
438+
@test t3 === t2
439+
@test isisometric(t2)
440+
441+
# test that t2 is closer to A then any other isometry
442+
for k in 1:10
443+
δt = randn!(similar(t))
444+
t3 = project_isometric(t + δt / 100)
445+
@test norm(t - t3) > norm(t - t2)
446+
end
447+
end
448+
end
389449
end
390450
end

0 commit comments

Comments
 (0)