Skip to content

Commit d5301ee

Browse files
committed
update fusiontensor overload
1 parent 37622f7 commit d5301ee

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

src/TensorKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ using ScopedValues
119119

120120
using TensorKitSectors
121121
import TensorKitSectors: dim, BraidingStyle, FusionStyle, ,
122-
import TensorKitSectors: dual, type_repr
122+
import TensorKitSectors: dual, type_repr, fusiontensor
123123
import TensorKitSectors: twist
124124

125125
using Base: @boundscheck, @propagate_inbounds, @constprop,

src/fusiontrees/fusiontrees.jl

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -264,50 +264,50 @@ end
264264
fusiontreedict(I) = FusionStyle(I) isa UniqueFusion ? SingletonDict : FusionTreeDict
265265

266266
# converting to actual array
267-
function Base.convert(A::Type{<:AbstractArray}, f::FusionTree{I, 0}) where {I}
268-
X = convert(A, fusiontensor(unit(I), unit(I), unit(I)))[1, 1, :]
269-
return X
270-
end
271-
function Base.convert(A::Type{<:AbstractArray}, f::FusionTree{I, 1}) where {I}
267+
Base.convert(A::Type{<:AbstractArray}, f::FusionTree) = convert(A, fusiontensor(f))
268+
# TODO: is this piracy?
269+
Base.convert(A::Type{<:AbstractArray}, (f₁, f₂)::FusionTreePair) =
270+
convert(A, fusiontensor((f₁, f₂)))
271+
272+
fusiontensor(::FusionTree{I, 0}) where {I} = fusiontensor(unit(I), unit(I), unit(I))[1, 1, :]
273+
function fusiontensor(f::FusionTree{I, 1}) where {I}
272274
c = f.coupled
273275
if f.isdual[1]
274276
sqrtdc = sqrtdim(c)
275-
Zcbartranspose = sqrtdc * convert(A, fusiontensor(dual(c), c, unit(c)))[:, :, 1, 1]
277+
Zcbartranspose = sqrtdc * fusiontensor(dual(c), c, unit(c))[:, :, 1, 1]
276278
X = conj!(Zcbartranspose) # we want Zcbar^†
277279
else
278-
X = convert(A, fusiontensor(c, unit(c), c))[:, 1, :, 1, 1]
280+
X = fusiontensor(c, unit(c), c)[:, 1, :, 1, 1]
279281
end
280282
return X
281283
end
282-
283-
function Base.convert(A::Type{<:AbstractArray}, f::FusionTree{I, 2}) where {I}
284+
function fusiontensor(f::FusionTree{I, 2}) where {I}
284285
a, b = f.uncoupled
285286
isduala, isdualb = f.isdual
286287
c = f.coupled
287288
μ = (FusionStyle(I) isa GenericFusion) ? f.vertices[1] : 1
288-
C = convert(A, fusiontensor(a, b, c))[:, :, :, μ]
289+
C = fusiontensor(a, b, c)[:, :, :, μ]
289290
X = C
290291
if isduala
291-
Za = convert(A, FusionTree((a,), a, (isduala,), ()))
292+
Za = fusiontensor(FusionTree((a,), a, (isduala,), ()))
292293
@tensor X[a′, b, c] := Za[a′, a] * X[a, b, c]
293294
end
294295
if isdualb
295-
Zb = convert(A, FusionTree((b,), b, (isdualb,), ()))
296+
Zb = fusiontensor(FusionTree((b,), b, (isdualb,), ()))
296297
@tensor X[a, b′, c] := Zb[b′, b] * X[a, b, c]
297298
end
298299
return X
299300
end
300-
301-
function Base.convert(A::Type{<:AbstractArray}, f::FusionTree{I, N}) where {I, N}
301+
function fusiontensor(f::FusionTree{I, N}) where {I, N}
302302
tailout = (f.innerlines[1], TupleTools.tail2(f.uncoupled)...)
303303
isdualout = (false, TupleTools.tail2(f.isdual)...)
304304
ftail = FusionTree(tailout, f.coupled, isdualout, Base.tail(f.innerlines), Base.tail(f.vertices))
305-
Ctail = convert(A, ftail)
305+
Ctail = fusiontensor(ftail)
306306
f₁ = FusionTree(
307307
(f.uncoupled[1], f.uncoupled[2]), f.innerlines[1],
308308
(f.isdual[1], f.isdual[2]), (), (f.vertices[1],)
309309
)
310-
C1 = convert(A, f₁)
310+
C1 = fusiontensor(f₁)
311311
dtail = size(Ctail)
312312
d1 = size(C1)
313313
X = similar(C1, (d1[1], d1[2], Base.tail(dtail)...))
@@ -320,20 +320,19 @@ function Base.convert(A::Type{<:AbstractArray}, f::FusionTree{I, N}) where {I, N
320320
)
321321
end
322322

323-
# TODO: is this piracy?
324-
function Base.convert(A::Type{<:AbstractArray}, (f₁, f₂)::FusionTreePair{I}) where {I}
325-
F₁ = convert(A, f₁)
326-
F₂ = convert(A, f₂)
323+
function fusiontensor((f₁, f₂)::FusionTreePair)
324+
F₁ = fusiontensor(f₁)
325+
F₂ = fusiontensor(f₂)
327326
sz1 = size(F₁)
328327
sz2 = size(F₂)
329328
d1 = TupleTools.front(sz1)
330329
d2 = TupleTools.front(sz2)
331-
332330
return reshape(
333331
reshape(F₁, TupleTools.prod(d1), sz1[end]) *
334332
reshape(F₂, TupleTools.prod(d2), sz2[end])', (d1..., d2...)
335333
)
336334
end
335+
fusiontensor(src::FusionTreeBlock) = sum(fusiontensor, fusiontrees(src))
337336

338337
# Show methods
339338
function Base.show(io::IO, t::FusionTree{I}) where {I <: Sector}

0 commit comments

Comments
 (0)