Skip to content

Commit fa15514

Browse files
sanderdemeyerJutholkdvos
authored
add rrule for flip (#241)
* add rrule for flip * add inverse flip and ad rules * Update src/tensors/indexmanipulations.jl fix spelling mistake Co-authored-by: Lukas Devos <[email protected]> --------- Co-authored-by: Jutho Haegeman <[email protected]> Co-authored-by: Lukas Devos <[email protected]>
1 parent ed09f13 commit fa15514

File tree

5 files changed

+51
-11
lines changed

5 files changed

+51
-11
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bo
8383
return tA, twist_pullback
8484
end
8585

86+
function ChainRulesCore.rrule(::typeof(flip), A::AbstractTensorMap, is; inv::Bool=false)
87+
tA = flip(A, is; inv)
88+
flip_pullback(ΔA) = NoTangent(), flip(unthunk(ΔA), is; inv=!inv), NoTangent()
89+
return tA, flip_pullback
90+
end
91+
8692
function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
8793
dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd)
8894
return dot(a, b), dot_pullback

src/fusiontrees/manipulations.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,24 +243,46 @@ end
243243
# -> A-move (foldleft, foldright) is complicated, needs to be reexpressed in standard form
244244

245245
# flip a duality flag of a fusion tree
246-
function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, i::Int) where {I<:Sector,N₁,N₂}
246+
function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, i::Int;
247+
inv::Bool=false) where {I<:Sector,N₁,N₂}
247248
@assert 0 < i N₁ + N₂
248249
if i N₁
249250
a = f₁.uncoupled[i]
250-
fs = frobeniusschur(a) * twist(a)
251-
factor = f₁.isdual[i] ? fs : one(fs)
251+
χₐ = frobeniusschur(a)
252+
θₐ = twist(a)
253+
if !inv
254+
factor = f₁.isdual[i] ? χₐ * θₐ : one(θₐ)
255+
else
256+
factor = f₁.isdual[i] ? one(θₐ) : χₐ * conj(θₐ)
257+
end
252258
isdual′ = TupleTools.setindex(f₁.isdual, !f₁.isdual[i], i)
253259
f₁′ = FusionTree{I}(f₁.uncoupled, f₁.coupled, isdual′, f₁.innerlines, f₁.vertices)
254260
return SingletonDict((f₁′, f₂) => factor)
255261
else
256262
i -= N₁
257263
a = f₂.uncoupled[i]
258-
factor = f₂.isdual[i] ? frobeniusschur(a) : twist(a)
264+
χₐ = frobeniusschur(a)
265+
θₐ = twist(a)
266+
if !inv
267+
factor = f₂.isdual[i] ? χₐ * one(θₐ) : θₐ
268+
else
269+
factor = f₂.isdual[i] ? conj(θₐ) : χₐ * one(θₐ)
270+
end
259271
isdual′ = TupleTools.setindex(f₂.isdual, !f₂.isdual[i], i)
260272
f₂′ = FusionTree{I}(f₂.uncoupled, f₂.coupled, isdual′, f₂.innerlines, f₂.vertices)
261273
return SingletonDict((f₁, f₂′) => factor)
262274
end
263275
end
276+
function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, ind;
277+
inv::Bool=false) where {I<:Sector,N₁,N₂}
278+
f₁′, f₂′ = f₁, f₂
279+
factor = one(sectorscalartype(I))
280+
for i in ind
281+
(f₁′, f₂′), s = only(flip(f₁′, f₂′, i; inv))
282+
factor *= s
283+
end
284+
return SingletonDict((f₁′, f₂′) => factor)
285+
end
264286

265287
# change to N₁ - 1, N₂ + 1
266288
function bendright(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {I<:Sector,N₁,N₂}

src/tensors/indexmanipulations.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,19 @@
55
66
Return a new tensor that is isomorphic to `t` but where the arrows on the indices `i` that satisfy
77
`i ∈ I` are flipped, i.e. `space(t′, i) = flip(space(t, i))`.
8+
9+
!!! note
10+
The isomorphism that `flip` applies to each of the indices `i ∈ I` is such that flipping two indices
11+
that are afterwards contracted within an `@tensor` contraction will yield the same result as without
12+
flipping those indices first. However, `flip` is not involutory, i.e. `flip(flip(t, I), I) != t` in
13+
general. To obtain the original tensor, one can use the `inv` keyword, i.e. it holds that
14+
`flip(flip(t, I), I; inv=true) == t`.
815
"""
9-
function flip(t::AbstractTensorMap, I)
16+
function flip(t::AbstractTensorMap, I; inv::Bool=false)
1017
P = flip(space(t), I)
1118
t′ = similar(t, P)
1219
for (f₁, f₂) in fusiontrees(t)
13-
f₁′, f₂′ = f₁, f₂
14-
factor = one(scalartype(t))
15-
for i in I
16-
(f₁′, f₂′), s = only(flip(f₁′, f₂′, i))
17-
factor *= s
18-
end
20+
(f₁′, f₂′), factor = only(flip(f₁, f₂, I; inv))
1921
scale!(t′[f₁′, f₂′], t[f₁, f₂], factor)
2022
end
2123
return t′

test/ad.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
224224
test_rrule(twist, A, 1)
225225
test_rrule(twist, A, [1, 3])
226226

227+
test_rrule(flip, A, 1)
228+
test_rrule(flip, A, [1, 3, 4])
229+
227230
D = randn(T, V[1] V[2] V[3])
228231
E = randn(T, V[4] V[5])
229232
symmetricbraiding && test_rrule(, D, E)

test/tensors.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,13 @@ for V in spacelist
324324
@test HrA12array convert(Array, HrA12)
325325
end
326326
end
327+
@timedtestset "Index flipping: test flipping inverse" begin
328+
t = rand(ComplexF64, V1 V1' V1' V1)
329+
for i in 1:4
330+
@test t flip(flip(t, i), i; inv=true)
331+
@test t flip(flip(t, i; inv=true), i)
332+
end
333+
end
327334
@timedtestset "Index flipping: test via explicit flip" begin
328335
t = rand(ComplexF64, V1 V1' V1' V1)
329336
F1 = unitary(flip(V1), V1)

0 commit comments

Comments
 (0)