@@ -122,10 +122,18 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
122122 ℂ[SU2Irrep](0 => 1 , 1 => 1 ),
123123 ℂ[SU2Irrep](1 // 2 => 1 , 1 => 1 )' ,
124124 ℂ[SU2Irrep](1 // 2 => 2 ),
125- ℂ[SU2Irrep](0 => 1 , 1 // 2 => 1 , 3 // 2 => 1 )' ))
125+ ℂ[SU2Irrep](0 => 1 , 1 // 2 => 1 , 3 // 2 => 1 )' ),
126+ (ℂ[FibonacciAnyon](:I => 1 , :τ => 1 ),
127+ ℂ[FibonacciAnyon](:I => 1 , :τ => 2 )' ,
128+ ℂ[FibonacciAnyon](:I => 3 , :τ => 2 )' ,
129+ ℂ[FibonacciAnyon](:I => 2 , :τ => 3 ),
130+ ℂ[FibonacciAnyon](:I => 2 , :τ => 2 )))
126131
127132@timedtestset " Automatic Differentiation with spacetype $(TensorKit. type_repr (eltype (V))) " verbose = true for V in
128133 Vlist
134+ eltypes = isreal (sectortype (eltype (V))) ? (Float64, ComplexF64) : (ComplexF64,)
135+ symmetricbraiding = BraidingStyle (sectortype (eltype (V))) isa SymmetricBraiding
136+
129137 @timedtestset " Basic utility" begin
130138 T1 = randn (Float64, V[1 ] ⊗ V[2 ] ← V[3 ] ⊗ V[4 ])
131139 T2 = randn (ComplexF64, V[1 ] ⊗ V[2 ] ← V[3 ] ⊗ V[4 ])
@@ -137,14 +145,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
137145 test_rrule (copy, T1)
138146 test_rrule (copy, T2)
139147 test_rrule (TensorKit. copy_oftype, T1, ComplexF64)
140- test_rrule (TensorKit. permutedcopy_oftype, T1, ComplexF64, ((3 , 1 ), (2 , 4 )))
148+ if symmetricbraiding
149+ test_rrule (TensorKit. permutedcopy_oftype, T1, ComplexF64, ((3 , 1 ), (2 , 4 )))
141150
142- test_rrule (convert, Array, T1)
143- test_rrule (TensorMap, convert (Array, T1), codomain (T1), domain (T1);
144- fkwargs= (; tol= Inf ))
151+ test_rrule (convert, Array, T1)
152+ test_rrule (TensorMap, convert (Array, T1), codomain (T1), domain (T1);
153+ fkwargs= (; tol= Inf ))
154+ end
145155 end
146156
147- @timedtestset " Basic Linear Algebra with scalartype $T " for T in (Float64, ComplexF64)
157+ @timedtestset " Basic Linear Algebra with scalartype $T " for T in eltypes
148158 A = randn (T, V[1 ] ⊗ V[2 ] ← V[3 ] ⊗ V[4 ] ⊗ V[5 ])
149159 B = randn (T, space (A))
150160
@@ -162,14 +172,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
162172 C = randn (T, domain (A), codomain (A))
163173 test_rrule (* , A, C)
164174
165- test_rrule (permute, A, ((1 , 3 , 2 ), (5 , 4 )))
175+ symmetricbraiding && test_rrule (permute, A, ((1 , 3 , 2 ), (5 , 4 )))
176+ test_rrule (twist, A, 1 )
177+ test_rrule (twist, A, [1 , 3 ])
166178
167179 D = randn (T, V[1 ] ⊗ V[2 ] ← V[3 ])
168180 E = randn (T, V[4 ] ← V[5 ])
169- test_rrule (⊗ , D, E)
181+ symmetricbraiding && test_rrule (⊗ , D, E)
170182 end
171183
172- @timedtestset " Linear Algebra part II with scalartype $T " for T in (Float64, ComplexF64)
184+ @timedtestset " Linear Algebra part II with scalartype $T " for T in eltypes
173185 for i in 1 : 3
174186 E = randn (T, ⊗ (V[1 : i]. .. ) ← ⊗ (V[1 : i]. .. ))
175187 test_rrule (LinearAlgebra. tr, E)
@@ -184,97 +196,100 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
184196 test_rrule (LinearAlgebra. dot, A, B)
185197 end
186198
187- @timedtestset " TensorOperations with scalartype $T " for T in (Float64, ComplexF64)
188- atol = precision (T)
189- rtol = precision (T)
190-
191- @timedtestset " tensortrace!" begin
192- for _ in 1 : 5
193- k1 = rand (0 : 3 )
194- k2 = k1 == 3 ? 1 : rand (1 : 2 )
195- V1 = map (v -> rand (Bool) ? v' : v, rand (V, k1))
196- V2 = map (v -> rand (Bool) ? v' : v, rand (V, k2))
197-
198- (_p, _q) = randindextuple (k1 + 2 * k2, k1)
199- p = _repartition (_p, rand (0 : k1))
200- q = _repartition (_q, k2)
201- ip = _repartition (invperm (linearize ((_p, _q))), rand (0 : (k1 + 2 * k2)))
202- A = randn (T, permute (prod (V1) ⊗ prod (V2) ← prod (V2), ip))
199+ symmetricbraiding &&
200+ @timedtestset " TensorOperations with scalartype $T " for T in eltypes
201+ atol = precision (T)
202+ rtol = precision (T)
203203
204- α = randn (T)
205- β = randn (T)
206- for conjA in (false , true )
207- C = randn! (TensorOperations. tensoralloc_add (T, A, p, conjA, Val (false )))
208- test_rrule (tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
204+ @timedtestset " tensortrace!" begin
205+ for _ in 1 : 5
206+ k1 = rand (0 : 3 )
207+ k2 = k1 == 3 ? 1 : rand (1 : 2 )
208+ V1 = map (v -> rand (Bool) ? v' : v, rand (V, k1))
209+ V2 = map (v -> rand (Bool) ? v' : v, rand (V, k2))
210+
211+ (_p, _q) = randindextuple (k1 + 2 * k2, k1)
212+ p = _repartition (_p, rand (0 : k1))
213+ q = _repartition (_q, k2)
214+ ip = _repartition (invperm (linearize ((_p, _q))), rand (0 : (k1 + 2 * k2)))
215+ A = randn (T, permute (prod (V1) ⊗ prod (V2) ← prod (V2), ip))
216+
217+ α = randn (T)
218+ β = randn (T)
219+ for conjA in (false , true )
220+ C = randn! (TensorOperations. tensoralloc_add (T, A, p, conjA,
221+ Val (false )))
222+ test_rrule (tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
223+ end
209224 end
210225 end
211- end
212226
213- @timedtestset " tensoradd!" begin
214- A = randn (T, V[1 ] ⊗ V[2 ] ⊗ V[3 ] ← V[4 ] ⊗ V[5 ])
215- α = randn (T)
216- β = randn (T)
217-
218- # repeat a couple times to get some distribution of arrows
219- for _ in 1 : 5
220- p = randindextuple (length (V))
227+ @timedtestset " tensoradd!" begin
228+ A = randn (T, V[1 ] ⊗ V[2 ] ⊗ V[3 ] ← V[4 ] ⊗ V[5 ])
229+ α = randn (T)
230+ β = randn (T)
221231
222- C1 = randn! (TensorOperations. tensoralloc_add (T, A, p, false , Val (false )))
223- test_rrule (tensoradd!, C1, A, p, false , α, β; atol, rtol)
232+ # repeat a couple times to get some distribution of arrows
233+ for _ in 1 : 5
234+ p = randindextuple (length (V))
224235
225- C2 = randn! (TensorOperations. tensoralloc_add (T, A, p, true , Val (false )))
226- test_rrule (tensoradd!, C2, A, p, true , α, β; atol, rtol)
236+ C1 = randn! (TensorOperations. tensoralloc_add (T, A, p, false ,
237+ Val (false )))
238+ test_rrule (tensoradd!, C1, A, p, false , α, β; atol, rtol)
227239
228- A = rand (Bool) ? C1 : C2
229- end
230- end
240+ C2 = randn! (TensorOperations. tensoralloc_add (T, A, p, true , Val (false )))
241+ test_rrule (tensoradd!, C2, A, p, true , α, β; atol, rtol)
231242
232- @timedtestset " tensorcontract!" begin
233- for _ in 1 : 5
234- d = 0
235- local V1, V2, V3
236- # retry a couple times to make sure there are at least some nonzero elements
237- for _ in 1 : 10
238- k1 = rand (0 : 3 )
239- k2 = rand (0 : 2 )
240- k3 = rand (0 : 2 )
241- V1 = prod (v -> rand (Bool) ? v' : v, rand (V, k1); init= one (V[1 ]))
242- V2 = prod (v -> rand (Bool) ? v' : v, rand (V, k2); init= one (V[1 ]))
243- V3 = prod (v -> rand (Bool) ? v' : v, rand (V, k3); init= one (V[1 ]))
244- d = min (dim (V1 ← V2), dim (V1' ← V2), dim (V2 ← V3), dim (V2' ← V3))
245- d > 0 && break
243+ A = rand (Bool) ? C1 : C2
246244 end
247- ipA = randindextuple (length (V1) + length (V2))
248- pA = _repartition (invperm (linearize (ipA)), length (V1))
249- ipB = randindextuple (length (V2) + length (V3))
250- pB = _repartition (invperm (linearize (ipB)), length (V2))
251- pAB = randindextuple (length (V1) + length (V3))
245+ end
252246
253- α = randn (T)
254- β = randn (T)
255- V2_conj = prod (conj, V2; init= one (V[1 ]))
256-
257- for conjA in (false , true ), conjB in (false , true )
258- A = randn (T, permute (V1 ← (conjA ? V2_conj : V2), ipA))
259- B = randn (T, permute ((conjB ? V2_conj : V2) ← V3, ipB))
260- C = randn! (TensorOperations. tensoralloc_contract (T, A, pA,
261- conjA,
262- B, pB, conjB, pAB,
263- Val (false )))
264- test_rrule (tensorcontract!, C,
265- A, pA, conjA, B, pB, conjB, pAB,
266- α, β; atol, rtol)
247+ @timedtestset " tensorcontract!" begin
248+ for _ in 1 : 5
249+ d = 0
250+ local V1, V2, V3
251+ # retry a couple times to make sure there are at least some nonzero elements
252+ for _ in 1 : 10
253+ k1 = rand (0 : 3 )
254+ k2 = rand (0 : 2 )
255+ k3 = rand (0 : 2 )
256+ V1 = prod (v -> rand (Bool) ? v' : v, rand (V, k1); init= one (V[1 ]))
257+ V2 = prod (v -> rand (Bool) ? v' : v, rand (V, k2); init= one (V[1 ]))
258+ V3 = prod (v -> rand (Bool) ? v' : v, rand (V, k3); init= one (V[1 ]))
259+ d = min (dim (V1 ← V2), dim (V1' ← V2), dim (V2 ← V3), dim (V2' ← V3))
260+ d > 0 && break
261+ end
262+ ipA = randindextuple (length (V1) + length (V2))
263+ pA = _repartition (invperm (linearize (ipA)), length (V1))
264+ ipB = randindextuple (length (V2) + length (V3))
265+ pB = _repartition (invperm (linearize (ipB)), length (V2))
266+ pAB = randindextuple (length (V1) + length (V3))
267+
268+ α = randn (T)
269+ β = randn (T)
270+ V2_conj = prod (conj, V2; init= one (V[1 ]))
271+
272+ for conjA in (false , true ), conjB in (false , true )
273+ A = randn (T, permute (V1 ← (conjA ? V2_conj : V2), ipA))
274+ B = randn (T, permute ((conjB ? V2_conj : V2) ← V3, ipB))
275+ C = randn! (TensorOperations. tensoralloc_contract (T, A, pA,
276+ conjA,
277+ B, pB, conjB, pAB,
278+ Val (false )))
279+ test_rrule (tensorcontract!, C,
280+ A, pA, conjA, B, pB, conjB, pAB,
281+ α, β; atol, rtol)
282+ end
267283 end
268284 end
269- end
270285
271- @timedtestset " tensorscalar" begin
272- A = randn (T, ProductSpace {typeof(V[1]),0} ())
273- test_rrule (tensorscalar, A)
286+ @timedtestset " tensorscalar" begin
287+ A = randn (T, ProductSpace {typeof(V[1]),0} ())
288+ test_rrule (tensorscalar, A)
289+ end
274290 end
275- end
276291
277- @timedtestset " Factorizations with scalartype $T " for T in (Float64, ComplexF64)
292+ @timedtestset " Factorizations with scalartype $T " for T in eltypes
278293 A = randn (T, V[1 ] ⊗ V[2 ] ← V[3 ] ⊗ V[4 ] ⊗ V[5 ])
279294 B = randn (T, space (A)' )
280295 C = randn (T, V[1 ] ⊗ V[2 ] ← V[1 ] ⊗ V[2 ])
@@ -367,13 +382,13 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
367382
368383 c, = TensorKit. MatrixAlgebra. _argmax (x -> sqrt (dim (x[1 ])) * maximum (diag (x[2 ])),
369384 blocks (S))
370- U, S, V, ϵ = tsvd (C; trunc= truncdim (2 * dim (c)))
385+ trunc = truncdim (round (Int, 2 * dim (c)))
386+ U, S, V, ϵ = tsvd (C; trunc)
371387 ΔU = randn (scalartype (U), space (U))
372388 ΔS = randn (scalartype (S), space (S))
373389 ΔV = randn (scalartype (V), space (V))
374390 T <: Complex && remove_svdgauge_depence! (ΔU, ΔV, U, S, V)
375- test_rrule (tsvd, C; atol, output_tangent= (ΔU, ΔS, ΔV, 0.0 ),
376- fkwargs= (; trunc= truncdim (2 * dim (c))))
391+ test_rrule (tsvd, C; atol, output_tangent= (ΔU, ΔS, ΔV, 0.0 ), fkwargs= (; trunc))
377392 end
378393
379394 let D = LinearAlgebra. eigvals (C)
0 commit comments