Skip to content

Commit 8b38973

Browse files
authored
Braidingtensor improvements (#179)
* Make braidingtensor behave * fix `treebraider` * expand BraidingTensor tests * Remove duplicate copy * rename `V1 -> W` * Apply suggestions from code review * Remove stackoverflow
1 parent f467a21 commit 8b38973

File tree

3 files changed

+86
-50
lines changed

3 files changed

+86
-50
lines changed

src/tensors/braidingtensor.jl

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,19 @@ end
7878
throw(SectorMismatch())
7979
end
8080
@inbounds begin
81-
d = (dims(V2 V1, f₁.uncoupled)..., dims(V1 V2, f₂.uncoupled)...)
81+
d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...)
8282
n1 = d[1] * d[2]
8383
n2 = d[3] * d[4]
84-
data = storagetype(b)(undef, (n1, n2))
84+
data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d)
8585
fill!(data, zero(eltype(b)))
86-
a1, a2 = f₂.uncoupled
87-
if f₁.uncoupled == (a2, a1)
86+
if f₁.uncoupled == reverse(f₂.uncoupled)
8887
braiddict = artin_braid(f₂, 1; inv=b.adjoint)
8988
r = get(braiddict, f₁, zero(valtype(braiddict)))
90-
si = 1 + d[1] * d[2] * d[3]
91-
sj = d[1] + d[1] * d[2]
92-
@inbounds for i in 1:d[1], j in 1:d[2]
93-
data[(i - 1) * si + (j - 1) * sj + 1] = r
89+
@inbounds for i in axes(data, 1), j in axes(data, 2)
90+
data[i, j, j, i] = r
9491
end
9592
end
96-
return sreshape(StridedView(data), d)
93+
return data
9794
end
9895
end
9996
@inline function Base.getindex(b::BraidingTensor, ::Nothing, ::Nothing)
@@ -104,31 +101,9 @@ end
104101
# efficient copy constructor
105102
Base.copy(b::BraidingTensor) = b
106103

107-
function Base.copy!(t::TensorMap, b::BraidingTensor)
108-
space(t) == space(b) || throw(SectorMismatch())
109-
fill!(t, zero(scalartype(t)))
110-
for (f₁, f₂) in fusiontrees(t)
111-
data = t[f₁, f₂]
112-
if sectortype(t) == Trivial
113-
r = one(scalartype(t))
114-
else
115-
a1, a2 = f₂.uncoupled
116-
c = f₂.coupled
117-
f₁.uncoupled == (a2, a1) || continue
118-
braiddict = artin_braid(f₂, 1; inv=b.adjoint)
119-
r = convert(scalartype(t), get(braiddict, f₁, zero(valtype(braiddict))))
120-
end
121-
@inbounds for i in axes(data, 1), j in axes(data, 2)
122-
data[i, j, j, i] = r
123-
end
124-
end
125-
return t
126-
end
127104
TensorMap(b::BraidingTensor) = copy!(similar(b), b)
128105
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)
129106

130-
# TODO: fix this!
131-
# block(b::BraidingTensor, s::Sector) = block(TensorMap(b), s)
132107
function block(b::BraidingTensor, s::Sector)
133108
sectortype(b) == typeof(s) || throw(SectorMismatch())
134109

@@ -141,7 +116,7 @@ function block(b::BraidingTensor, s::Sector)
141116

142117
data = fill!(data, zero(eltype(b)))
143118

144-
V1, V2 = domain(b)
119+
V1, V2 = codomain(b)
145120
if sectortype(b) === Trivial
146121
d1, d2 = dim(V1), dim(V2)
147122
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
@@ -174,8 +149,6 @@ function block(b::BraidingTensor, s::Sector)
174149
return data
175150
end
176151

177-
blocks(b::BraidingTensor) = blocks(TensorMap(b))
178-
179152
# Index manipulations
180153
# -------------------
181154
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false

src/tensors/treetransformers.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,34 @@ function TreeTransformer(transform::Function, Vsrc::HomSpace{S},
6969
end
7070
end
7171

72+
# braid is special because it has levels
73+
const treebraidercache = LRU{Any,Any}(; maxsize=10^5)
74+
const usetreebraidercache = Ref{Bool}(true)
75+
@noinline function _get_treebraider(A, key)
76+
d::A = get!(treebraidercache, key) do
77+
return _treebraider(key)
78+
end
79+
return d
80+
end
81+
function _treebraider((Vdst, Vsrc, p, levels))
82+
fusiontreebraider(f1, f2) = braid(f1, f2, levels..., p...)
83+
return TreeTransformer(fusiontreebraider, Vsrc, Vdst)
84+
end
85+
function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p, levels)
86+
return fusiontreetransform(f1, f2) = braid(f1, f2, levels..., p...)
87+
end
88+
function treebraider(tdst::TensorMap, tsrc::TensorMap, p, levels)
89+
if usetreebraidercache[]
90+
key = (space(tdst), space(tsrc), p, levels)
91+
A = treetransformertype(space(tdst), space(tsrc))
92+
return _get_treebraider(A, key)
93+
else
94+
return _treebraider((space(tdst), space(tsrc), p, levels))
95+
end
96+
end
97+
7298
for (transform, transformer) in
73-
((:permute, :permuter), (:braid, :braider), (:transpose, :transposer))
99+
((:permute, :permuter), (:transpose, :transposer))
74100
treetransformcache = Symbol("tree", transformer, "cache")
75101
usetreetransformcache = Symbol("usetree", transformer, "cache")
76102
treetransformer = Symbol("tree", transformer)

test/planar.jl

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,59 @@ function force_planar(tsrc::TensorMap{<:Any,<:GradedSpace})
3030
return tdst
3131
end
3232

33+
Vtr = (ℂ^3,
34+
(ℂ^2)',
35+
^5,
36+
^6,
37+
(ℂ^7)')
38+
VU₁ = (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2),
39+
ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1),
40+
ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)',
41+
ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 3),
42+
ℂ[U1Irrep](0 => 1, 1 => 3, -1 => 3)')
43+
VfU₁ = (ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 2),
44+
ℂ[FermionNumber](0 => 3, 1 => 1, -1 => 1),
45+
ℂ[FermionNumber](0 => 2, 1 => 2, -1 => 1)',
46+
ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 3),
47+
ℂ[FermionNumber](0 => 1, 1 => 3, -1 => 3)')
48+
VfSU₂ = (ℂ[FermionSpin](0 => 3, 1 // 2 => 1),
49+
ℂ[FermionSpin](0 => 2, 1 => 1),
50+
ℂ[FermionSpin](1 // 2 => 1, 1 => 1)',
51+
ℂ[FermionSpin](0 => 2, 1 // 2 => 2),
52+
ℂ[FermionSpin](0 => 1, 1 // 2 => 1, 3 // 2 => 1)')
53+
Vfib = (Vect[FibonacciAnyon](:I => 1, => 2),
54+
Vect[FibonacciAnyon](:I => 2, => 1),
55+
Vect[FibonacciAnyon](:I => 1, => 1),
56+
Vect[FibonacciAnyon](:I => 1, => 1),
57+
Vect[FibonacciAnyon](:I => 1, => 1))
3358
@testset "Braiding tensor" begin
34-
V1 =^2 ^3 ^3 ^2
35-
t1 = @constinferred BraidingTensor(V1)
36-
@test space(t1) == V1
37-
@test codomain(t1) == codomain(V1)
38-
@test domain(t1) == domain(V1)
39-
@test scalartype(t1) == Float64
40-
@test storagetype(t1) == Vector{Float64}
41-
t2 = @constinferred BraidingTensor{ComplexF64}(V1)
42-
@test scalartype(t2) == ComplexF64
43-
@test storagetype(t2) == Vector{ComplexF64}
44-
45-
V2 =^2 ^3 ^2 ^3
46-
@test_throws SpaceMismatch BraidingTensor(V2)
47-
48-
@test adjoint(t1) isa BraidingTensor
59+
for V in (Vtr, VU₁, VfU₁, VfSU₂, Vfib)
60+
W = V[1] V[2] V[2] V[1]
61+
t1 = @constinferred BraidingTensor(W)
62+
@test space(t1) == W
63+
@test codomain(t1) == codomain(W)
64+
@test domain(t1) == domain(W)
65+
@test scalartype(t1) == (isreal(sectortype(W)) ? Float64 : ComplexF64)
66+
@test storagetype(t1) == Vector{scalartype(t1)}
67+
t2 = @constinferred BraidingTensor{ComplexF64}(W)
68+
@test scalartype(t2) == ComplexF64
69+
@test storagetype(t2) == Vector{ComplexF64}
70+
71+
W2 = reverse(codomain(W)) domain(W)
72+
@test_throws SpaceMismatch BraidingTensor(W2)
73+
74+
@test adjoint(t1) isa BraidingTensor
75+
76+
t3 = @inferred TensorMap(t2)
77+
t4 = braid(id(storagetype(t2), domain(t2)), ((2, 1), (3, 4)), (1, 2, 3, 4))
78+
@test t1 t4
79+
for (c, b) in blocks(t1)
80+
@test block(t1, c) b block(t3, c)
81+
end
82+
for (f1, f2) in fusiontrees(t1)
83+
@test t1[f1, f2] t3[f1, f2]
84+
end
85+
end
4986
end
5087

5188
@testset "planar methods" verbose = true begin

0 commit comments

Comments
 (0)