Skip to content

Commit 9cc5696

Browse files
committed
fix dim and revert unnecessary Int converts
1 parent f73b24a commit 9cc5696

File tree

5 files changed

+22
-22
lines changed

5 files changed

+22
-22
lines changed

src/spaces/gradedspace.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ Base.hash(V::GradedSpace, h::UInt) = hash(V.dual, hash(V.dims, h))
9292
field(::Type{<:GradedSpace}) =
9393
InnerProductStyle(::Type{<:GradedSpace}) = EuclideanInnerProduct()
9494
function dim(V::GradedSpace)
95-
T = promote_type(Int, real(sectorscalartype(sectortype(V))))
96-
return sum(dim(V, c) * dim(c) for c in sectors(V); init = zero(T))
95+
s = sectors(V)
96+
return isempty(s) ? dim(first(allunits(sectortype(V)))) * 0 : sum(c -> dim(c) * dim(V, c), s)
9797
end
9898
function dim(V::GradedSpace{I, <:AbstractDict}, c::I) where {I <: Sector}
9999
return get(V.dims, isdual(V) ? dual(c) : c, 0)

src/tensors/abstracttensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
617617
dom = domain(t)
618618
T = sectorscalartype(I) <: Complex ? complex(scalartype(t)) :
619619
sectorscalartype(I) <: Integer ? scalartype(t) : float(scalartype(t))
620-
A = zeros(T, Int.(dims(t))...)
620+
A = zeros(T, dims(t)...)
621621
for (f₁, f₂) in fusiontrees(t)
622622
F = convert(Array, (f₁, f₂))
623623
Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...]

src/tensors/tensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ function TensorMap(
333333
# dimension check
334334
codom = codomain(V)
335335
dom = domain(V)
336-
arraysize = Int.(dims(V))
336+
arraysize = dims(V)
337337
matsize = (dim(codom), dim(dom))
338338

339339
if !(size(data) == arraysize || size(data) == matsize)

test/tensors/factorizations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ for V in spacelist
237237
@test isisometry(Vᴴ; side = :right)
238238

239239
#FIXME: dimension of S is a float, might be a real issue if it's a decimal
240-
trunc = truncrank(Int(dim(domain(S)) ÷ 2))
240+
trunc = truncrank(dim(domain(S)) ÷ 2)
241241
U1, S1, Vᴴ1 = @constinferred svd_trunc(t; trunc)
242242
@test t * Vᴴ1' U1 * S1
243243
@test isisometry(U1)
@@ -269,7 +269,7 @@ for V in spacelist
269269
@test isisometry(Vᴴ4; side = :right)
270270
@test norm(t - U4 * S4 * Vᴴ4) <= 0.5
271271

272-
trunc = truncrank(Int(dim(domain(S)) ÷ 2)) & trunctol(; atol = λ - 10eps(λ))
272+
trunc = truncrank(dim(domain(S)) ÷ 2) & trunctol(; atol = λ - 10eps(λ))
273273
U5, S5, Vᴴ5 = @constinferred svd_trunc(t; trunc)
274274
@test t * Vᴴ5' U5 * S5
275275
@test isisometry(U5)
@@ -299,7 +299,7 @@ for V in spacelist
299299
@test @constinferred isposdef(vdv)
300300
t isa DiagonalTensorMap || @test !isposdef(t) # unlikely for non-hermitian map
301301

302-
d, v = @constinferred eig_trunc(t; trunc = truncrank(Int(dim(domain(t)) ÷ 2)))
302+
d, v = @constinferred eig_trunc(t; trunc = truncrank(dim(domain(t)) ÷ 2))
303303
@test t * v v * d
304304
@test dim(domain(d)) dim(domain(t)) ÷ 2
305305

@@ -330,7 +330,7 @@ for V in spacelist
330330
@test isposdef(t - λ * one(t) + 0.1 * one(t))
331331
@test !isposdef(t - λ * one(t) - 0.1 * one(t))
332332

333-
d, v = @constinferred eigh_trunc(t; trunc = truncrank(Int(dim(domain(t)) ÷ 2)))
333+
d, v = @constinferred eigh_trunc(t; trunc = truncrank(dim(domain(t)) ÷ 2))
334334
@test t * v v * d
335335
@test dim(domain(d)) dim(domain(t)) ÷ 2
336336
end

test/tensors/tensors.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ for V in spacelist
8585
t = @constinferred randn(T, W)
8686
end
8787
a = @constinferred convert(Array, t)
88-
b = reshape(a, Int(dim(codomain(W))), Int(dim(domain(W)))) # no init in dim makes reshape error for su2
88+
b = reshape(a, dim(codomain(W)), dim(domain(W)))
8989
@test t @constinferred TensorMap(a, W)
9090
@test t @constinferred TensorMap(b, W)
9191
@test t === @constinferred TensorMap(t.data, W)
9292
end
9393
end
9494
for T in (Int, Float32, ComplexF64)
95-
t = randn(T, V1 V2 zerospace(V1)) # no init in dim makes zerospace call error for z2
95+
t = randn(T, V1 V2 zerospace(V1))
9696
a = convert(Array, t)
9797
@test norm(a) == 0
9898
end
@@ -427,8 +427,8 @@ for V in spacelist
427427
t1 = rand(T, W1 W1)
428428
t2 = rand(T, W2, W2)
429429
t = rand(T, W1 W2)
430-
d1 = Int(dim(W1))
431-
d2 = Int(dim(W2))
430+
d1 = dim(W1)
431+
d2 = dim(W2)
432432
At1 = reshape(convert(Array, t1), d1, d1)
433433
At2 = reshape(convert(Array, t2), d2, d2)
434434
At = reshape(convert(Array, t), d1, d2)
@@ -469,7 +469,7 @@ for V in spacelist
469469
W = V1 V2
470470
for T in (Float64, ComplexF64)
471471
t = randn(T, W, W)
472-
s = Int(dim(W))
472+
s = dim(W)
473473
expt = @constinferred exp(t)
474474
@test reshape(convert(Array, expt), (s, s))
475475
exp(reshape(convert(Array, t), (s, s)))
@@ -529,7 +529,7 @@ for V in spacelist
529529
@test norm(tA * t + t * tB + tC) <
530530
(norm(tA) + norm(tB) + norm(tC)) * eps(real(T))^(2 / 3)
531531
if BraidingStyle(I) isa Bosonic && hasfusiontensor(I)
532-
matrix(x) = reshape(convert(Array, x), Int(dim(codomain(x))), Int(dim(domain(x))))
532+
matrix(x) = reshape(convert(Array, x), dim(codomain(x)), dim(domain(x)))
533533
@test matrix(t) sylvester(matrix(tA), matrix(tB), matrix(tC))
534534
end
535535
end
@@ -553,10 +553,10 @@ for V in spacelist
553553
t1 = rand(T, V2 V3 V1, V1)
554554
t2 = rand(T, V2 V1 V3, V2)
555555
t = @constinferred (t1 t2)
556-
d1 = Int(dim(codomain(t1)))
557-
d2 = Int(dim(codomain(t2)))
558-
d3 = Int(dim(domain(t1)))
559-
d4 = Int(dim(domain(t2)))
556+
d1 = dim(codomain(t1))
557+
d2 = dim(codomain(t2))
558+
d3 = dim(domain(t1))
559+
d4 = dim(domain(t2))
560560
At = convert(Array, t)
561561
@test reshape(At, (d1, d2, d3, d4))
562562
reshape(convert(Array, t1), (d1, 1, d3, 1)) .*
@@ -618,10 +618,10 @@ end
618618
t1 = rand(T, V1 V2, V3' V4)
619619
t2 = rand(T, W2, W1 W1')
620620
t = @constinferred (t1 t2)
621-
d1 = Int(dim(codomain(t1)))
622-
d2 = Int(dim(codomain(t2)))
623-
d3 = Int(dim(domain(t1)))
624-
d4 = Int(dim(domain(t2)))
621+
d1 = dim(codomain(t1))
622+
d2 = dim(codomain(t2))
623+
d3 = dim(domain(t1))
624+
d4 = dim(domain(t2))
625625
At = convert(Array, t)
626626
@test reshape(At, (d1, d2, d3, d4))
627627
reshape(convert(Array, t1), (d1, 1, d3, 1)) .*

0 commit comments

Comments
 (0)