Skip to content

Commit 8f0a0c0

Browse files
authored
Add rrule for twist (#217)
* Add rrule for `twist` * test rrule for `twist` * Add anyonic AD test
1 parent 6352ec6 commit 8f0a0c0

File tree

2 files changed

+109
-88
lines changed

2 files changed

+109
-88
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
7777
return adjoint(A), adjoint_pullback
7878
end
7979

80+
function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool=false)
81+
tA = twist(A, is; inv)
82+
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv=!inv), NoTangent()
83+
return tA, twist_pullback
84+
end
85+
8086
function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
8187
dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd)
8288
return dot(a, b), dot_pullback

test/ad.jl

Lines changed: 103 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,18 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
122122
ℂ[SU2Irrep](0 => 1, 1 => 1),
123123
ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)',
124124
ℂ[SU2Irrep](1 // 2 => 2),
125-
ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'))
125+
ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'),
126+
(ℂ[FibonacciAnyon](:I => 1, :τ => 1),
127+
ℂ[FibonacciAnyon](:I => 1, :τ => 2)',
128+
ℂ[FibonacciAnyon](:I => 3, :τ => 2)',
129+
ℂ[FibonacciAnyon](:I => 2, :τ => 3),
130+
ℂ[FibonacciAnyon](:I => 2, :τ => 2)))
126131

127132
@timedtestset "Automatic Differentiation with spacetype $(TensorKit.type_repr(eltype(V)))" verbose = true for V in
128133
Vlist
134+
eltypes = isreal(sectortype(eltype(V))) ? (Float64, ComplexF64) : (ComplexF64,)
135+
symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding
136+
129137
@timedtestset "Basic utility" begin
130138
T1 = randn(Float64, V[1] ⊗ V[2] ← V[3] ⊗ V[4])
131139
T2 = randn(ComplexF64, V[1] ⊗ V[2] ← V[3] ⊗ V[4])
@@ -137,14 +145,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
137145
test_rrule(copy, T1)
138146
test_rrule(copy, T2)
139147
test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
140-
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
148+
if symmetricbraiding
149+
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))
141150

142-
test_rrule(convert, Array, T1)
143-
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
144-
fkwargs=(; tol=Inf))
151+
test_rrule(convert, Array, T1)
152+
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
153+
fkwargs=(; tol=Inf))
154+
end
145155
end
146156

147-
@timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64)
157+
@timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes
148158
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
149159
B = randn(T, space(A))
150160

@@ -162,14 +172,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
162172
C = randn(T, domain(A), codomain(A))
163173
test_rrule(*, A, C)
164174

165-
test_rrule(permute, A, ((1, 3, 2), (5, 4)))
175+
symmetricbraiding && test_rrule(permute, A, ((1, 3, 2), (5, 4)))
176+
test_rrule(twist, A, 1)
177+
test_rrule(twist, A, [1, 3])
166178

167179
D = randn(T, V[1] ⊗ V[2] ← V[3])
168180
E = randn(T, V[4] ← V[5])
169-
test_rrule(⊗, D, E)
181+
symmetricbraiding && test_rrule(⊗, D, E)
170182
end
171183

172-
@timedtestset "Linear Algebra part II with scalartype $T" for T in (Float64, ComplexF64)
184+
@timedtestset "Linear Algebra part II with scalartype $T" for T in eltypes
173185
for i in 1:3
174186
E = randn(T, ⊗(V[1:i]...) ← ⊗(V[1:i]...))
175187
test_rrule(LinearAlgebra.tr, E)
@@ -184,97 +196,100 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
184196
test_rrule(LinearAlgebra.dot, A, B)
185197
end
186198

187-
@timedtestset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64)
188-
atol = precision(T)
189-
rtol = precision(T)
190-
191-
@timedtestset "tensortrace!" begin
192-
for _ in 1:5
193-
k1 = rand(0:3)
194-
k2 = k1 == 3 ? 1 : rand(1:2)
195-
V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
196-
V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))
197-
198-
(_p, _q) = randindextuple(k1 + 2 * k2, k1)
199-
p = _repartition(_p, rand(0:k1))
200-
q = _repartition(_q, k2)
201-
ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2)))
202-
A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip))
199+
symmetricbraiding &&
200+
@timedtestset "TensorOperations with scalartype $T" for T in eltypes
201+
atol = precision(T)
202+
rtol = precision(T)
203203

204-
α = randn(T)
205-
β = randn(T)
206-
for conjA in (false, true)
207-
C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA, Val(false)))
208-
test_rrule(tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
204+
@timedtestset "tensortrace!" begin
205+
for _ in 1:5
206+
k1 = rand(0:3)
207+
k2 = k1 == 3 ? 1 : rand(1:2)
208+
V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
209+
V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))
210+
211+
(_p, _q) = randindextuple(k1 + 2 * k2, k1)
212+
p = _repartition(_p, rand(0:k1))
213+
q = _repartition(_q, k2)
214+
ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2)))
215+
A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip))
216+
217+
α = randn(T)
218+
β = randn(T)
219+
for conjA in (false, true)
220+
C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA,
221+
Val(false)))
222+
test_rrule(tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
223+
end
209224
end
210225
end
211-
end
212226

213-
@timedtestset "tensoradd!" begin
214-
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
215-
α = randn(T)
216-
β = randn(T)
217-
218-
# repeat a couple times to get some distribution of arrows
219-
for _ in 1:5
220-
p = randindextuple(length(V))
227+
@timedtestset "tensoradd!" begin
228+
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
229+
α = randn(T)
230+
β = randn(T)
221231

222-
C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false)))
223-
test_rrule(tensoradd!, C1, A, p, false, α, β; atol, rtol)
232+
# repeat a couple times to get some distribution of arrows
233+
for _ in 1:5
234+
p = randindextuple(length(V))
224235

225-
C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false)))
226-
test_rrule(tensoradd!, C2, A, p, true, α, β; atol, rtol)
236+
C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false,
237+
Val(false)))
238+
test_rrule(tensoradd!, C1, A, p, false, α, β; atol, rtol)
227239

228-
A = rand(Bool) ? C1 : C2
229-
end
230-
end
240+
C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false)))
241+
test_rrule(tensoradd!, C2, A, p, true, α, β; atol, rtol)
231242

232-
@timedtestset "tensorcontract!" begin
233-
for _ in 1:5
234-
d = 0
235-
local V1, V2, V3
236-
# retry a couple times to make sure there are at least some nonzero elements
237-
for _ in 1:10
238-
k1 = rand(0:3)
239-
k2 = rand(0:2)
240-
k3 = rand(0:2)
241-
V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1]))
242-
V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1]))
243-
V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1]))
244-
d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3))
245-
d > 0 && break
243+
A = rand(Bool) ? C1 : C2
246244
end
247-
ipA = randindextuple(length(V1) + length(V2))
248-
pA = _repartition(invperm(linearize(ipA)), length(V1))
249-
ipB = randindextuple(length(V2) + length(V3))
250-
pB = _repartition(invperm(linearize(ipB)), length(V2))
251-
pAB = randindextuple(length(V1) + length(V3))
245+
end
252246

253-
α = randn(T)
254-
β = randn(T)
255-
V2_conj = prod(conj, V2; init=one(V[1]))
256-
257-
for conjA in (false, true), conjB in (false, true)
258-
A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA))
259-
B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB))
260-
C = randn!(TensorOperations.tensoralloc_contract(T, A, pA,
261-
conjA,
262-
B, pB, conjB, pAB,
263-
Val(false)))
264-
test_rrule(tensorcontract!, C,
265-
A, pA, conjA, B, pB, conjB, pAB,
266-
α, β; atol, rtol)
247+
@timedtestset "tensorcontract!" begin
248+
for _ in 1:5
249+
d = 0
250+
local V1, V2, V3
251+
# retry a couple times to make sure there are at least some nonzero elements
252+
for _ in 1:10
253+
k1 = rand(0:3)
254+
k2 = rand(0:2)
255+
k3 = rand(0:2)
256+
V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1]))
257+
V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1]))
258+
V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1]))
259+
d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3))
260+
d > 0 && break
261+
end
262+
ipA = randindextuple(length(V1) + length(V2))
263+
pA = _repartition(invperm(linearize(ipA)), length(V1))
264+
ipB = randindextuple(length(V2) + length(V3))
265+
pB = _repartition(invperm(linearize(ipB)), length(V2))
266+
pAB = randindextuple(length(V1) + length(V3))
267+
268+
α = randn(T)
269+
β = randn(T)
270+
V2_conj = prod(conj, V2; init=one(V[1]))
271+
272+
for conjA in (false, true), conjB in (false, true)
273+
A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA))
274+
B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB))
275+
C = randn!(TensorOperations.tensoralloc_contract(T, A, pA,
276+
conjA,
277+
B, pB, conjB, pAB,
278+
Val(false)))
279+
test_rrule(tensorcontract!, C,
280+
A, pA, conjA, B, pB, conjB, pAB,
281+
α, β; atol, rtol)
282+
end
267283
end
268284
end
269-
end
270285

271-
@timedtestset "tensorscalar" begin
272-
A = randn(T, ProductSpace{typeof(V[1]),0}())
273-
test_rrule(tensorscalar, A)
286+
@timedtestset "tensorscalar" begin
287+
A = randn(T, ProductSpace{typeof(V[1]),0}())
288+
test_rrule(tensorscalar, A)
289+
end
274290
end
275-
end
276291

277-
@timedtestset "Factorizations with scalartype $T" for T in (Float64, ComplexF64)
292+
@timedtestset "Factorizations with scalartype $T" for T in eltypes
278293
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
279294
B = randn(T, space(A)')
280295
C = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])
@@ -367,13 +382,13 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
367382

368383
c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])),
369384
blocks(S))
370-
U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c)))
385+
trunc = truncdim(round(Int, 2 * dim(c)))
386+
U, S, V, ϵ = tsvd(C; trunc)
371387
ΔU = randn(scalartype(U), space(U))
372388
ΔS = randn(scalartype(S), space(S))
373389
ΔV = randn(scalartype(V), space(V))
374390
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
375-
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
376-
fkwargs=(; trunc=truncdim(2 * dim(c))))
391+
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0), fkwargs=(; trunc))
377392
end
378393

379394
let D = LinearAlgebra.eigvals(C)

0 commit comments

Comments
 (0)