Skip to content

Commit 9e78a03

Browse files
authored
Various improvements and fixes (#149)
* Improvements for BraidingTensor * Fix catdomain and catcodomain * scalartype from AbstractTensorMap type instead of storagetype * improve type stability allocator * add `similar(::Type{AbstractTensorMap}, ...)`
1 parent 0acfb13 commit 9e78a03

File tree

7 files changed

+56
-10
lines changed

7 files changed

+56
-10
lines changed

src/TensorKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ export CompositeSpace, ProductSpace # composite spaces
2626
export FusionTree
2727
export IndexSpace, TensorSpace, TensorMapSpace
2828
export AbstractTensorMap, AbstractTensor, TensorMap, Tensor, TrivialTensorMap # tensors and tensor properties
29+
export BraidingTensor
2930
export TruncationScheme
3031
export SpaceMismatch, SectorMismatch, IndexError # error types
3132

src/tensors/abstracttensor.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,15 @@ function Base.similar(::AbstractTensorMap, ::Type{TorA},
278278
return TT(undef, codomain(P), domain(P))
279279
end
280280

281+
# implementation in type-domain
282+
function Base.similar(::Type{TT}, P::TensorMapSpace) where {TT<:AbstractTensorMap}
283+
return TensorMap{scalartype(TT)}(undef, P)
284+
end
285+
function Base.similar(::Type{TT}, cod::TensorSpace{S},
286+
dom::TensorSpace{S}) where {TT<:AbstractTensorMap,S}
287+
return TensorMap{scalartype(TT)}(undef, cod, dom)
288+
end
289+
281290
# Equality and approximality
282291
#----------------------------
283292
function Base.:(==)(t1::AbstractTensorMap, t2::AbstractTensorMap)

src/tensors/braidingtensor.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ struct BraidingTensor{T,S} <: AbstractTensorMap{T,S,2,2}
2828
# partial construction: only construct rowr and colr when needed
2929
end
3030
end
31+
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool=false) where {T,S<:IndexSpace}
32+
return BraidingTensor{T,S}(V1, V2, adjoint)
33+
end
3134
function BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
3235
if BraidingStyle(sectortype(S)) isa SymmetricBraiding
3336
return BraidingTensor{Float64,S}(V1, V2, adjoint)
@@ -38,7 +41,12 @@ end
3841
function BraidingTensor(V::HomSpace, adjoint::Bool=false)
3942
domain(V) == reverse(codomain(V)) ||
4043
throw(SpaceMismatch("Cannot define a braiding on $V"))
41-
return BraidingTensor(V[1], V[2], adjoint)
44+
return BraidingTensor(V[2], V[1], adjoint)
45+
end
46+
function BraidingTensor{T}(V::HomSpace, adjoint::Bool=false) where {T}
47+
domain(V) == reverse(codomain(V)) ||
48+
throw(SpaceMismatch("Cannot define a braiding on $V"))
49+
return BraidingTensor{T}(V[2], V[1], adjoint)
4250
end
4351
function Base.adjoint(b::BraidingTensor{T,S}) where {T,S}
4452
return BraidingTensor{T,S}(b.V1, b.V2, !b.adjoint)
@@ -54,6 +62,10 @@ blocksectors(b::BraidingTensor) = blocksectors(b.V1 ⊗ b.V2)
5462
hasblock(b::BraidingTensor, s::Sector) = s blocksectors(b)
5563

5664
function fusiontrees(b::BraidingTensor)
65+
if sectortype(b) === Trivial
66+
return ((nothing, nothing),)
67+
end
68+
5769
codom = codomain(b)
5870
dom = domain(b)
5971
I = sectortype(b)
@@ -71,7 +83,6 @@ function fusiontrees(b::BraidingTensor)
7183
offset1 = last(r)
7284
end
7385
end
74-
dim1 = offset1
7586
offset2 = 0
7687
for s2 in sectors(dom)
7788
for f₂ in fusiontrees(s2, c, map(isdual, dom.spaces))
@@ -80,7 +91,6 @@ function fusiontrees(b::BraidingTensor)
8091
offset2 = last(r)
8192
end
8293
end
83-
dim2 = offset2
8494
push!(rowr, c => rowrc)
8595
push!(colr, c => colrc)
8696
end
@@ -124,6 +134,10 @@ end
124134
return sreshape(StridedView(data), d)
125135
end
126136
end
137+
@inline function Base.getindex(b::BraidingTensor, ::Nothing, ::Nothing)
138+
sectortype(b) === Trivial || throw(SectorMismatch())
139+
return getindex(b)
140+
end
127141

128142
# efficient copy constructor
129143
Base.copy(b::BraidingTensor) = b

src/tensors/linalg.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
401401
end
402402

403403
# concatenate tensors
404-
function catdomain(t1::T, t2::T) where {S,N₁,T<:AbstractTensorMap{<:Any,S,N₁,1}}
404+
function catdomain(t1::TT, t2::TT) where {S,N₁,TT<:AbstractTensorMap{<:Any,S,N₁,1}}
405405
codomain(t1) == codomain(t2) ||
406406
throw(SpaceMismatch("codomains of tensors to concatenate must match:\n" *
407407
"$(codomain(t1))$(codomain(t2))"))
@@ -411,14 +411,15 @@ function catdomain(t1::T, t2::T) where {S,N₁,T<:AbstractTensorMap{<:Any,S,N₁
411411
throw(SpaceMismatch("cannot horizontally concatenate tensors whose domain has non-matching duality"))
412412

413413
V = V1 V2
414-
t = TensorMap(undef, promote_type(scalartype(t1), scalartype(t2)), codomain(t1), V)
414+
T = promote_type(scalartype(t1), scalartype(t2))
415+
t = TensorMap{T}(undef, codomain(t1), V)
415416
for c in sectors(V)
416417
block(t, c)[:, 1:dim(V1, c)] .= block(t1, c)
417418
block(t, c)[:, dim(V1, c) .+ (1:dim(V2, c))] .= block(t2, c)
418419
end
419420
return t
420421
end
421-
function catcodomain(t1::T, t2::T) where {S,N₂,T<:AbstractTensorMap{<:Any,S,1,N₂}}
422+
function catcodomain(t1::TT, t2::TT) where {S,N₂,TT<:AbstractTensorMap{<:Any,S,1,N₂}}
422423
domain(t1) == domain(t2) ||
423424
throw(SpaceMismatch("domains of tensors to concatenate must match:\n" *
424425
"$(domain(t1))$(domain(t2))"))
@@ -428,7 +429,8 @@ function catcodomain(t1::T, t2::T) where {S,N₂,T<:AbstractTensorMap{<:Any,S,1,
428429
throw(SpaceMismatch("cannot vertically concatenate tensors whose codomain has non-matching duality"))
429430

430431
V = V1 V2
431-
t = TensorMap(undef, promote_type(scalartype(t1), scalartype(t2)), V, domain(t1))
432+
T = promote_type(scalartype(t1), scalartype(t2))
433+
t = TensorMap{T}(undef, V, domain(t1))
432434
for c in sectors(V)
433435
block(t, c)[1:dim(V1, c), :] .= block(t1, c)
434436
block(t, c)[dim(V1, c) .+ (1:dim(V2, c)), :] .= block(t2, c)

src/tensors/tensoroperations.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function TO.tensoralloc(::Type{TT}, structure::TensorMapSpace{S,N₁,N₂}, iste
2525
colr, coldims = _buildblockstructure(domain(structure), blocksectoriterator)
2626
A = storagetype(TT)
2727
blockallocator(c) = TO.tensoralloc(A, (rowdims[c], coldims[c]), istemp, allocator)
28-
data = SectorDict(c => blockallocator(c) for c in blocksectoriterator)
28+
data = SectorDict{sectortype(TT),A}(c => blockallocator(c) for c in blocksectoriterator)
2929
return TT(data, codomain(structure), domain(structure), rowr, colr)
3030
end
3131

@@ -127,7 +127,8 @@ function TO.tensorcontract_type(TC,
127127
B::AbstractTensorMap, ::Index2Tuple, ::Bool,
128128
::Index2Tuple{N₁,N₂}) where {N₁,N₂}
129129
M = similarstoragetype(A, TC)
130-
M == similarstoragetype(B, TC) || throw(ArgumentError("incompatible storage types"))
130+
M == similarstoragetype(B, TC) ||
131+
throw(ArgumentError("incompatible storage types:\n$(M)$(similarstoragetype(B, TC))"))
131132
spacetype(A) == spacetype(B) || throw(SpaceMismatch("incompatible space types"))
132133
return tensormaptype(spacetype(A), N₁, N₂, M)
133134
end

src/tensors/vectorinterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# scalartype
22
#------------
3-
VectorInterface.scalartype(T::Type{<:AbstractTensorMap}) = scalartype(storagetype(T))
3+
VectorInterface.scalartype(::Type{TT}) where {T,TT<:AbstractTensorMap{T}} = scalartype(T)
44

55
# zerovector & zerovector!!
66
#---------------------------

test/planar.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using TensorKit, TensorOperations, Test
2+
using TensorKit: BraidingTensor
23
using TensorKit: planaradd!, planartrace!, planarcontract!
34
using TensorKit: PlanarTrivial, ℙ
45

@@ -29,6 +30,24 @@ function force_planar(tsrc::TensorMap{<:Any,<:GradedSpace})
2930
return tdst
3031
end
3132

33+
@testset "Braiding tensor" begin
34+
V1 =^2 ^3 ^3 ^2
35+
t1 = @constinferred BraidingTensor(V1)
36+
@test space(t1) == V1
37+
@test codomain(t1) == codomain(V1)
38+
@test domain(t1) == domain(V1)
39+
@test scalartype(t1) == Float64
40+
@test storagetype(t1) == Matrix{Float64}
41+
t2 = @constinferred BraidingTensor{ComplexF64}(V1)
42+
@test scalartype(t2) == ComplexF64
43+
@test storagetype(t2) == Matrix{ComplexF64}
44+
45+
V2 =^2 ^3 ^2 ^3
46+
@test_throws SpaceMismatch BraidingTensor(V2)
47+
48+
@test adjoint(t1) isa BraidingTensor
49+
end
50+
3251
@testset "planar methods" verbose = true begin
3352
@testset "planaradd" begin
3453
A = randn(ℂ^2 ^3 ^6 ^5 ^4)

0 commit comments

Comments
 (0)