1- using LinearAlgebra, CUDA, Test, TestExtras, TensorKit, cuTENSOR
1+ using Test, TestExtras
2+ using TensorKit
3+ using LinearAlgebra: LinearAlgebra
4+ using CUDA, cuTENSOR
25
36const CUDAExt = Base. get_extension (TensorKit, :TensorKitCUDAExt )
47@assert ! isnothing (CUDAExt)
@@ -55,15 +58,15 @@ for V in spacelist
5558 @test Q * R ≈ t
5659 @test isisometric (Q)
5760
58- Q, R = @constinferred left_orth (t; kind = :qr )
61+ Q, R = @constinferred left_orth (t)
5962 @test Q * R ≈ t
6063 @test isisometric (Q)
6164
6265 N = @constinferred qr_null (t)
6366 @test isisometric (N)
6467 @test norm (N' * t) ≈ 0 atol = 100 * eps (norm (t))
6568
66- N = @constinferred left_null (t; kind = :qr )
69+ N = @constinferred left_null (t)
6770 @test isisometric (N)
6871 @test norm (N' * t) ≈ 0 atol = 100 * eps (norm (t))
6972 end
@@ -82,7 +85,7 @@ for V in spacelist
8285 @test isisometric (Q)
8386 @test dim (Q) == dim (R) == dim (t)
8487
85- Q, R = @constinferred left_orth (t; kind = :qr )
88+ Q, R = @constinferred left_orth (t)
8689 @test Q * R ≈ t
8790 @test isisometric (Q)
8891 @test dim (Q) == dim (R) == dim (t)
@@ -107,7 +110,7 @@ for V in spacelist
107110 @test L * Q ≈ t
108111 @test isisometric (Q; side = :right )
109112
110- L, Q = @constinferred right_orth (t; kind = :lq )
113+ L, Q = @constinferred right_orth (t)
111114 @test L * Q ≈ t
112115 @test isisometric (Q; side = :right )
113116
@@ -130,7 +133,7 @@ for V in spacelist
130133 @test isisometric (Q; side = :right )
131134 @test dim (Q) == dim (L) == dim (t)
132135
133- L, Q = @constinferred right_orth (t; kind = :lq )
136+ L, Q = @constinferred right_orth (t)
134137 @test L * Q ≈ t
135138 @test isisometric (Q; side = :right )
136139 @test dim (Q) == dim (L) == dim (t)
@@ -157,7 +160,7 @@ for V in spacelist
157160 # broken for T <: Complex
158161 @test isposdef (p)
159162
160- w, p = @constinferred left_orth (t; kind = :polar )
163+ w, p = @constinferred left_orth (t; alg = :polar )
161164 @test w * p ≈ t
162165 @test isisometric (w)
163166 end
@@ -174,7 +177,7 @@ for V in spacelist
174177 @test isisometric (wᴴ; side = :right )
175178 @test isposdef (p)
176179
177- p, wᴴ = @constinferred right_orth (t; kind = :polar )
180+ p, wᴴ = @constinferred right_orth (t; alg = :polar )
178181 @test p * wᴴ ≈ t
179182 @test isisometric (wᴴ; side = :right )
180183 end
@@ -204,15 +207,15 @@ for V in spacelist
204207 @test b ≈ s′[c]
205208 end
206209
207- v, c = @constinferred left_orth (t; kind = :svd )
210+ v, c = @constinferred left_orth (t; alg = :svd )
208211 @test v * c ≈ t
209212 @test isisometric (v)
210213
211- N = @constinferred left_null (t; kind = :svd )
214+ N = @constinferred left_null (t; alg = :svd )
212215 @test isisometric (N)
213216 @test norm (N' * t) ≈ 0 atol = 100 * eps (norm (t))
214217
215- Nᴴ = @constinferred right_null (t; kind = :svd )
218+ Nᴴ = @constinferred right_null (t; alg = :svd )
216219 @test isisometric (Nᴴ; side = :right )
217220 @test norm (t * Nᴴ' ) ≈ 0 atol = 100 * eps (norm (t))
218221 end
@@ -386,5 +389,62 @@ for V in spacelist
386389 @test cond (t) ≈ λmax / λmin
387390 end
388391 end
392+
393+ @testset " Hermitian projections" begin
394+ for T in eltypes,
395+ t in (
396+ curand (T, V1, V1), curand (T, W, W), curand (T, W, W)' ,
397+ CuDiagonalTensorMap (curand (T, reduceddim (V1)), V1),
398+ )
399+ normalize! (t)
400+ noisefactor = eps (real (T))^ (3 / 4 )
401+
402+ th = (t + t' ) / 2
403+ ta = (t - t' ) / 2
404+ tc = copy (t)
405+
406+ th′ = @constinferred project_hermitian (t)
407+ @test ishermitian (th′)
408+ @test th′ ≈ th
409+ @test t == tc
410+ th_approx = th + noisefactor * ta
411+ @test ! ishermitian (th_approx) || (T <: Real && t isa CuDiagonalTensorMap)
412+ @test ishermitian (th_approx; atol = 10 * noisefactor)
413+
414+ ta′ = project_antihermitian (t)
415+ @test isantihermitian (ta′)
416+ @test ta′ ≈ ta
417+ @test t == tc
418+ ta_approx = ta + noisefactor * th
419+ @test ! isantihermitian (ta_approx)
420+ @test isantihermitian (ta_approx; atol = 10 * noisefactor) || (T <: Real && t isa CuDiagonalTensorMap)
421+ end
422+ end
423+
424+ @testset " Isometric projections" begin
425+ for T in eltypes,
426+ t in (
427+ curandn (T, W, W), curandn (T, W, W)' ,
428+ curandn (T, W, V1), curandn (T, V1, W)' ,
429+ )
430+ t2 = project_isometric (t)
431+ @test isisometric (t2)
432+ t3 = project_isometric (t2)
433+ @test t3 ≈ t2 # stability of the projection
434+ @test t2 * (t2' * t) ≈ t
435+
436+ tc = similar (t)
437+ t3 = @constinferred project_isometric! (copy! (tc, t), t2)
438+ @test t3 === t2
439+ @test isisometric (t2)
440+
441+ # test that t2 is closer to A then any other isometry
442+ for k in 1 : 10
443+ δt = randn! (similar (t))
444+ t3 = project_isometric (t + δt / 100 )
445+ @test norm (t - t3) > norm (t - t2)
446+ end
447+ end
448+ end
389449 end
390450end
0 commit comments