Skip to content

Commit d4d6fe7

Browse files
committed
use the const TK in tests where appropriate
1 parent 084ab72 commit d4d6fe7

File tree

5 files changed

+70
-70
lines changed

5 files changed

+70
-70
lines changed

test/ad.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap,
2929
end
3030

3131
# make sure that norms are computed correctly:
32-
function FiniteDifferences.to_vec(t::TensorKit.SectorDict)
32+
function FiniteDifferences.to_vec(t::TK.SectorDict)
3333
T = scalartype(valtype(t))
3434
vec = mapreduce(vcat, t; init=T[]) do (c, b)
3535
return reshape(b, :) .* sqrt(dim(c))
@@ -39,7 +39,7 @@ function FiniteDifferences.to_vec(t::TensorKit.SectorDict)
3939
function from_vec(x_real)
4040
x = T <: Real ? x_real : reinterpret(T, x_real)
4141
ctr = 0
42-
return TensorKit.SectorDict(c => (n = length(b);
42+
return TK.SectorDict(c => (n = length(b);
4343
b′ = reshape(view(x, ctr .+ (1:n)), size(b)) ./
4444
sqrt(dim(c));
4545
ctr += n;
@@ -61,25 +61,25 @@ end
6161

6262
# rrules for functions that destroy inputs
6363
# ----------------------------------------
64-
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...; kwargs...)
64+
function ChainRulesCore.rrule(::typeof(TK.tsvd), args...; kwargs...)
6565
return ChainRulesCore.rrule(tsvd!, args...; kwargs...)
6666
end
6767
function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals), args...; kwargs...)
6868
return ChainRulesCore.rrule(svdvals!, args...; kwargs...)
6969
end
70-
function ChainRulesCore.rrule(::typeof(TensorKit.eig), args...; kwargs...)
70+
function ChainRulesCore.rrule(::typeof(TK.eig), args...; kwargs...)
7171
return ChainRulesCore.rrule(eig!, args...; kwargs...)
7272
end
73-
function ChainRulesCore.rrule(::typeof(TensorKit.eigh), args...; kwargs...)
73+
function ChainRulesCore.rrule(::typeof(TK.eigh), args...; kwargs...)
7474
return ChainRulesCore.rrule(eigh!, args...; kwargs...)
7575
end
7676
function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals), args...; kwargs...)
7777
return ChainRulesCore.rrule(eigvals!, args...; kwargs...)
7878
end
79-
function ChainRulesCore.rrule(::typeof(TensorKit.leftorth), args...; kwargs...)
79+
function ChainRulesCore.rrule(::typeof(TK.leftorth), args...; kwargs...)
8080
return ChainRulesCore.rrule(leftorth!, args...; kwargs...)
8181
end
82-
function ChainRulesCore.rrule(::typeof(TensorKit.rightorth), args...; kwargs...)
82+
function ChainRulesCore.rrule(::typeof(TK.rightorth), args...; kwargs...)
8383
return ChainRulesCore.rrule(rightorth!, args...; kwargs...)
8484
end
8585

@@ -134,7 +134,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
134134
ℂ[FibonacciAnyon](:I => 2, => 3),
135135
ℂ[FibonacciAnyon](:I => 2, => 2)))
136136

137-
@timedtestset "Automatic Differentiation with spacetype $(TensorKit.type_repr(eltype(V)))" verbose = true for V in
137+
@timedtestset "Automatic Differentiation with spacetype $(TK.type_repr(eltype(V)))" verbose = true for V in
138138
Vlist
139139
eltypes = isreal(sectortype(eltype(V))) ? (Float64, ComplexF64) : (ComplexF64,)
140140
symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding
@@ -149,9 +149,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
149149

150150
test_rrule(copy, T1)
151151
test_rrule(copy, T2)
152-
test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
152+
test_rrule(TK.copy_oftype, T1, ComplexF64)
153153
if symmetricbraiding
154-
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
154+
test_rrule(TK.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
155155

156156
test_rrule(convert, Array, T1)
157157
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
@@ -364,13 +364,13 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
364364
H = (H + H') / 2
365365
atol = precision(T)
366366

367-
for alg in (TensorKit.QR(), TensorKit.QRpos())
367+
for alg in (TK.QR(), TK.QRpos())
368368
test_rrule(leftorth, A; fkwargs=(; alg=alg), atol)
369369
test_rrule(leftorth, B; fkwargs=(; alg=alg), atol)
370370
test_rrule(leftorth, C; fkwargs=(; alg=alg), atol)
371371
end
372372

373-
for alg in (TensorKit.LQ(), TensorKit.LQpos())
373+
for alg in (TK.LQ(), TK.LQpos())
374374
test_rrule(rightorth, A; fkwargs=(; alg=alg), atol)
375375
test_rrule(rightorth, B; fkwargs=(; alg=alg), atol)
376376
test_rrule(rightorth, C; fkwargs=(; alg=alg), atol)
@@ -428,7 +428,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
428428
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
429429
test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
430430

431-
Vtrunc = spacetype(S)(TensorKit.SectorDict(c => ceil(Int, size(b, 1) / 2)
431+
Vtrunc = spacetype(S)(TK.SectorDict(c => ceil(Int, size(b, 1) / 2)
432432
for (c, b) in blocks(S)))
433433

434434
U, S, V, ϵ = tsvd(B; trunc=truncspace(Vtrunc))
@@ -447,7 +447,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
447447
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
448448
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
449449

450-
c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])),
450+
c, = TK.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])),
451451
blocks(S))
452452
trunc = truncdim(round(Int, 2 * dim(c)))
453453
U, S, V, ϵ = tsvd(C; trunc)

test/diagonal.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
3030
b2 = @constinferred block(t, first(blocksectors(t)))
3131
@test b1 == b2
3232
@test eltype(bs) === Pair{typeof(c),typeof(b1)}
33-
@test typeof(b1) === TensorKit.blocktype(t)
33+
@test typeof(b1) === TK.blocktype(t)
3434
# basic linear algebra
3535
@test isa(@constinferred(norm(t)), real(T))
3636
@test norm(t)^2 dot(t, t)
@@ -201,7 +201,7 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
201201
zip(values(LinearAlgebra.eigvals(D)),
202202
values(LinearAlgebra.eigvals(t))))
203203
end
204-
@testset "leftorth with $alg" for alg in (TensorKit.QR(), TensorKit.QL())
204+
@testset "leftorth with $alg" for alg in (TK.QR(), TK.QL())
205205
Q, R = @constinferred leftorth(t; alg=alg)
206206
QdQ = Q' * Q
207207
@test QdQ one(QdQ)
@@ -210,7 +210,7 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
210210
@test isposdef(R)
211211
end
212212
end
213-
@testset "rightorth with $alg" for alg in (TensorKit.RQ(), TensorKit.LQ())
213+
@testset "rightorth with $alg" for alg in (TK.RQ(), TK.LQ())
214214
L, Q = @constinferred rightorth(t; alg=alg)
215215
QQd = Q * Q'
216216
@test QQd one(QQd)
@@ -219,7 +219,7 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
219219
@test isposdef(L)
220220
end
221221
end
222-
@testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD())
222+
@testset "tsvd with $alg" for alg in (TK.SVD(), TK.SDD())
223223
U, S, Vᴴ = @constinferred tsvd(t; alg=alg)
224224
UdU = U' * U
225225
@test UdU one(UdU)

test/fusiontrees.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ println("------------------------------------")
22
println("Fusion Trees")
33
println("------------------------------------")
44
ti = time()
5-
@timedtestset "Fusion trees for $(TensorKit.type_repr(I))" verbose = true for I in
5+
@timedtestset "Fusion trees for $(TK.type_repr(I))" verbose = true for I in
66
sectorlist
77

8-
Istr = TensorKit.type_repr(I)
8+
Istr = TK.type_repr(I)
99
N = 5
1010
out = ntuple(n -> randsector(I), N)
1111
isdual = ntuple(n -> rand(Bool), N)
@@ -433,12 +433,12 @@ ti = time()
433433
ip = invperm(p)
434434
ip1, ip2 = ip[1:N], ip[(N + 1):(2N)]
435435

436-
d = @constinferred TensorKit.permute(f1, f2, p1, p2)
436+
d = @constinferred TK.permute(f1, f2, p1, p2)
437437
@test dim(incoming)
438438
sum(abs2(coef) * dim(f1.coupled) for ((f1, f2), coef) in d)
439439
d2 = Dict{typeof((f1, f2)),valtype(d)}()
440440
for ((f1′, f2′), coeff) in d
441-
d′ = TensorKit.permute(f1′, f2′, ip1, ip2)
441+
d′ = TK.permute(f1′, f2′, ip1, ip2)
442442
for ((f1′′, f2′′), coeff2) in d′
443443
d2[(f1′′, f2′′)] = get(d2, (f1′′, f2′′), zero(coeff)) +
444444
coeff2 * coeff
@@ -567,7 +567,7 @@ ti = time()
567567
end
568568
end
569569
end
570-
TensorKit.empty_globalcaches!()
570+
TK.empty_globalcaches!()
571571
end
572572
tf = time()
573573
printstyled("Finished fusion tree tests in ",

test/spaces.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ println("------------------------------------")
6464
@test @constinferred(sectortype(V)) == Trivial
6565
@test ((@constinferred sectors(V))...,) == (Trivial(),)
6666
@test length(sectors(V)) == 1
67-
@test @constinferred(TensorKit.hassector(V, Trivial()))
67+
@test @constinferred(TK.hassector(V, Trivial()))
6868
@test @constinferred(dim(V)) == d == @constinferred(dim(V, Trivial()))
6969
@test dim(@constinferred(zerospace(V))) == 0
7070
@test (sectors(zerospace(V))...,) == ()
71-
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(d)
71+
@test @constinferred(TK.axes(V)) == Base.OneTo(d)
7272
@test^d == ℝ[](d) == CartesianSpace(d) == typeof(V)(d)
7373
W = @constinferred^1
7474
@test @constinferred(unitspace(V)) == W == unitspace(typeof(V))
@@ -111,11 +111,11 @@ println("------------------------------------")
111111
@test @constinferred(sectortype(V)) == Trivial
112112
@test ((@constinferred sectors(V))...,) == (Trivial(),)
113113
@test length(sectors(V)) == 1
114-
@test @constinferred(TensorKit.hassector(V, Trivial()))
114+
@test @constinferred(TK.hassector(V, Trivial()))
115115
@test @constinferred(dim(V)) == d == @constinferred(dim(V, Trivial()))
116116
@test dim(@constinferred(zerospace(V))) == 0
117117
@test (sectors(zerospace(V))...,) == ()
118-
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(d)
118+
@test @constinferred(TK.axes(V)) == Base.OneTo(d)
119119
@test^d == Vect[Trivial](d) == Vect[](Trivial() => d) == ℂ[](d) == typeof(V)(d)
120120
W = @constinferred^1
121121
@test @constinferred(unitspace(V)) == W == unitspace(typeof(V))
@@ -153,10 +153,10 @@ println("------------------------------------")
153153
@test isdual(V')
154154
@test !isdual(conj(V))
155155
@test isdual(conj(V'))
156-
@test !TensorKit.isconj(V)
157-
@test !TensorKit.isconj(V')
158-
@test TensorKit.isconj(conj(V))
159-
@test TensorKit.isconj(conj(V'))
156+
@test !TK.isconj(V)
157+
@test !TK.isconj(V')
158+
@test TK.isconj(conj(V))
159+
@test TK.isconj(conj(V'))
160160
@test isa(V, VectorSpace)
161161
@test isa(V, ElementarySpace)
162162
@test !isa(InnerProductStyle(V), HasInnerProduct)
@@ -165,20 +165,20 @@ println("------------------------------------")
165165
@test @constinferred(dual(V)) != @constinferred(conj(V)) != V
166166
@test @constinferred(field(V)) ==
167167
@test @constinferred(sectortype(V)) == Trivial
168-
@test @constinferred(TensorKit.hassector(V, Trivial()))
168+
@test @constinferred(TK.hassector(V, Trivial()))
169169
@test @constinferred(dim(V)) == d == @constinferred(dim(V, Trivial()))
170-
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(d)
170+
@test @constinferred(TK.axes(V)) == Base.OneTo(d)
171171
end
172172

173-
@timedtestset "ElementarySpace: $(TensorKit.type_repr(Vect[I]))" for I in sectorlist
173+
@timedtestset "ElementarySpace: $(TK.type_repr(Vect[I]))" for I in sectorlist
174174
if Base.IteratorSize(values(I)) === Base.IsInfinite()
175175
set = unique(vcat(unit(I), [randsector(I) for k in 1:10]))
176176
gen = (c => 2 for c in set)
177177
else
178178
gen = (values(I)[k] => (k + 1) for k in 1:length(values(I)))
179179
end
180180
V = GradedSpace(gen)
181-
@test eval(Meta.parse(TensorKit.type_repr(typeof(V)))) == typeof(V)
181+
@test eval(Meta.parse(TK.type_repr(typeof(V)))) == typeof(V)
182182
@test eval(Meta.parse(sprint(show, V))) == V
183183
@test eval(Meta.parse(sprint(show, V'))) == V'
184184
@test V' == GradedSpace(gen; dual=true)
@@ -225,12 +225,12 @@ println("------------------------------------")
225225
@test @constinferred(field(V)) ==
226226
@test @constinferred(sectortype(V)) == I
227227
slist = @constinferred sectors(V)
228-
@test @constinferred(TensorKit.hassector(V, first(slist)))
228+
@test @constinferred(TK.hassector(V, first(slist)))
229229
@test @constinferred(dim(V)) == sum(dim(s) * dim(V, s) for s in slist)
230230
@test @constinferred(reduceddim(V)) == sum(dim(V, s) for s in slist)
231231
@constinferred dim(V, first(slist))
232232
if hasfusiontensor(I)
233-
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(dim(V))
233+
@test @constinferred(TK.axes(V)) == Base.OneTo(dim(V))
234234
end
235235
@test @constinferred((V, zerospace(V))) == V
236236
@test @constinferred((V, V)) == Vect[I](c => 2dim(V, c) for c in sectors(V))
@@ -407,7 +407,7 @@ println("------------------------------------")
407407

408408
@timedtestset "HomSpace" begin
409409
for (V1, V2, V3, V4, V5) in (Vtr, Vℤ₃, VSU₂)
410-
W = TensorKit.HomSpace(V1 V2, V3 V4 V5)
410+
W = TK.HomSpace(V1 V2, V3 V4 V5)
411411
@test W == (V3 V4 V5 V1 V2)
412412
@test W == (V1 V2 V3 V4 V5)
413413
@test W' == (V1 V2 V3 V4 V5)
@@ -424,7 +424,7 @@ println("------------------------------------")
424424
@test W == deepcopy(W)
425425
@test W == @constinferred permute(W, ((1, 2), (3, 4, 5)))
426426
@test permute(W, ((2, 4, 5), (3, 1))) == (V2 V4' V5' V3 V1')
427-
@test (V1 V2 V1 V2) == @constinferred TensorKit.compose(W, W')
427+
@test (V1 V2 V1 V2) == @constinferred TK.compose(W, W')
428428
@test (V1 V2 V3 V4 V5 unitspace(V5)) ==
429429
@constinferred(insertleftunit(W)) ==
430430
@constinferred(insertrightunit(W))
@@ -444,5 +444,5 @@ println("------------------------------------")
444444
@test_throws BoundsError insertleftunit(one(V1) V1, 0)
445445
end
446446
end
447-
TensorKit.empty_globalcaches!()
447+
TK.empty_globalcaches!()
448448
end

0 commit comments

Comments
 (0)