Skip to content

Commit 1f63100

Browse files
lkdvosJutho
authored andcommitted
Fix block(BraidingTensor)
1 parent 6dc4a28 commit 1f63100

File tree

1 file changed

+35
-26
lines changed

1 file changed

+35
-26
lines changed

src/tensors/braidingtensor.jl

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function Base.copy!(t::TensorMap, b::BraidingTensor)
117117
braiddict = artin_braid(f₂, 1; inv=b.adjoint)
118118
r = convert(scalartype(t), get(braiddict, f₁, zero(valtype(braiddict))))
119119
end
120-
for i in 1:size(data, 1), j in 1:size(data, 2)
120+
@inbounds for i in axes(data, 1), j in axes(data, 2)
121121
data[i, j, j, i] = r
122122
end
123123
end
@@ -126,39 +126,48 @@ end
126126
TensorMap(b::BraidingTensor) = copy!(similar(b), b)
127127
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)
128128

129+
# TODO: fix this!
130+
# block(b::BraidingTensor, s::Sector) = block(TensorMap(b), s)
129131
function block(b::BraidingTensor, s::Sector)
130132
sectortype(b) == typeof(s) || throw(SectorMismatch())
131-
(V1, V2) = domain(b)
132-
if sectortype(b) == Trivial
133+
134+
# TODO: probably always square?
135+
m = blockdim(codomain(b), s)
136+
n = blockdim(domain(b), s)
137+
data = Matrix{eltype(b)}(undef, (m, n))
138+
139+
length(data) == 0 && return data # s ∉ blocksectors(b)
140+
141+
data = fill!(data, zero(scalartype(b)))
142+
143+
V1, V2 = domain(b)
144+
if sectortype(b) === Trivial
133145
d1, d2 = dim(V1), dim(V2)
134-
n = d1 * d2
135-
data = fill!(storagetype(b)(undef, (n, n)), zero(scalartype(b)))
136-
si = 1 + d2 * d1 * d1
137-
sj = d2 + d2 * d1
138-
@inbounds for i in 1:d2, j in 1:d1
139-
data[(i - 1) * si + (j - 1) * sj + 1] = one(scalartype(b))
146+
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
147+
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
148+
subblock[i, j, j, i] = one(scalartype(b))
140149
end
141150
return data
142151
end
143-
n = blockdim(domain(b), s)
144-
data = fill!(storagetype(b)(undef, (n, n)), zero(scalartype(b)))
145-
iter = fusiontrees(b) # actually contains information about ranges as well
146-
for (f₂, r2) in iter.colr[s]
147-
for (f₁, r1) in iter.rowr[s]
148-
a1, a2 = f₂.uncoupled
149-
d1 = dim(V1, a1)
150-
d2 = dim(V2, a2)
151-
f₁.uncoupled == (a2, a1) || continue
152-
braiddict = artin_braid(f₂, 1; inv=b.adjoint)
153-
r = convert(scalartype(b), get(braiddict, f₁, zero(valtype(braiddict))))
154-
si = 1 + n * d1
155-
sj = d2 + n
156-
start = first(r1) + (first(r2) - 1) * n
157-
@inbounds for i in 1:d2, j in 1:d1
158-
data[(i - 1) * si + (j - 1) * sj + start] = r
159-
end
152+
153+
structure = fusionblockstructure(b)
154+
for ((f1, f2), (sz, str, _)) in
155+
zip(structure.fusiontreelist, structure.fusiontreestructure)
156+
if (f1.uncoupled != reverse(f2.uncoupled)) || !(f1.coupled == f2.coupled == s)
157+
continue
158+
end
159+
160+
braiddict = artin_braid(f2, 1; inv=b.adjoint)
161+
haskey(braiddict, f1) || continue
162+
r = braiddict[f1]
163+
164+
# discard offset because single block
165+
subblock = StridedView(data, sz, str)
166+
@inbounds for i in axes(subblock, 1), j in axes(subblock, 2)
167+
subblock[i, j, j, i] = r
160168
end
161169
end
170+
162171
return data
163172
end
164173

0 commit comments

Comments
 (0)