33using BlockArrays: blocklengths
44using Strided: Strided, @strided
55
6- using TensorAlgebra: BlockedPermutation, permmortar, blockpermute
7-
8- function naive_permutedims (ft, biperm:: BlockedPermutation{2} )
9- @assert ndims (ft) == length (biperm)
10-
11- # naive permute: cast to dense, permutedims, cast to FusionTensor
12- arr = Array (ft)
13- permuted_arr = permutedims (arr, Tuple (biperm))
14- permuted = to_fusiontensor (permuted_arr, blocks (axes (ft)[biperm])... )
15- return permuted
16- end
6+ using GradedArrays: AbelianStyle, NotAbelianStyle, SymmetryStyle, checkspaces
7+ using TensorAlgebra: AbstractBlockPermutation, permmortar
178
189# permutedims with 1 tuple of 2 separate tuples
1910function fusiontensor_permutedims (ft, new_leg_indices:: Tuple{Tuple,Tuple} )
2011 return fusiontensor_permutedims (ft, new_leg_indices... )
2112end
2213
14+ function fusiontensor_permutedims! (ftdst, ftsrc, new_leg_indices:: Tuple{Tuple,Tuple} )
15+ return fusiontensor_permutedims! (ftdst, ftsrc, new_leg_indices... )
16+ end
17+
2318# permutedims with 2 separate tuples
2419function fusiontensor_permutedims (
2520 ft, new_codomain_indices:: Tuple , new_domain_indices:: Tuple
@@ -28,29 +23,55 @@ function fusiontensor_permutedims(
2823 return fusiontensor_permutedims (ft, biperm)
2924end
3025
31- function fusiontensor_permutedims (ft, biperm:: BlockedPermutation{2} )
26+ function fusiontensor_permutedims! (
27+ ftdst, ftsrc, new_codomain_indices:: Tuple , new_domain_indices:: Tuple
28+ )
29+ biperm = permmortar ((new_codomain_indices, new_domain_indices))
30+ return fusiontensor_permutedims! (ftdst, ftsrc, biperm)
31+ end
32+
33+ # permutedims with BlockedPermutation
34+ function fusiontensor_permutedims (ft, biperm:: AbstractBlockPermutation{2} )
3235 ndims (ft) == length (biperm) || throw (ArgumentError (" Invalid permutation length" ))
36+ ftdst = FusionTensor {eltype(ft)} (undef, axes (ft)[biperm])
37+ fusiontensor_permutedims! (ftdst, ft, biperm)
38+ return ftdst
39+ end
40+
41+ function fusiontensor_permutedims! (ftdst, ftsrc, biperm:: AbstractBlockPermutation{2} )
42+ ndims (ftsrc) == length (biperm) || throw (ArgumentError (" Invalid permutation length" ))
43+ blocklengths (axes (ftdst)) == blocklengths (biperm) ||
44+ throw (ArgumentError (" Destination tensor does not match bipermutation" ))
45+ checkspaces (axes (ftdst), axes (ftsrc)[biperm])
3346
34- # early return for identity operation. Do not copy. Also handle tricky 0-dim case.
35- if ndims_codomain (ft) == first (blocklengths (biperm)) # compile time
36- if Tuple (biperm) == ntuple (identity, ndims (ft))
37- return ft
47+ # early return for identity operation. Also handle tricky 0-dim case.
48+ if ndims_codomain (ftdst) == ndims_codomain (ftsrc) # compile time
49+ if Tuple (biperm) == ntuple (identity, ndims (ftdst))
50+ copy! (data_matrix (ftdst), data_matrix (ftsrc))
51+ return nothing
3852 end
3953 end
54+ return permute_data! (SymmetryStyle (ftdst), ftdst, ftsrc, Tuple (biperm))
55+ end
4056
41- new_ft = FusionTensor {eltype(ft)} (undef, axes (ft)[biperm])
42- fusiontensor_permutedims! (new_ft, ft, Tuple (biperm))
43- return new_ft
57+ # =============================== Internal =============================================
58+ function permute_data! (:: AbelianStyle , ftdst, ftsrc, flatperm)
59+ # abelian case: all unitary blocks are 1x1 identity matrices
60+ # compute_unitary is only called to get block positions
61+ unitary = compute_unitary (ftdst, ftsrc, flatperm)
62+ for ((old_trees, new_trees), _) in unitary
63+ new_block = view (ftdst, new_trees... )
64+ old_block = view (ftsrc, old_trees... )
65+ @strided new_block .= permutedims (old_block, flatperm)
66+ end
4467end
4568
46- function fusiontensor_permutedims! (
47- new_ft:: FusionTensor{T,N} , old_ft:: FusionTensor{T,N} , flatperm:: NTuple{N,Integer}
48- ) where {T,N}
49- foreach (m -> fill! (m, zero (T)), eachstoredblock (data_matrix (new_ft)))
50- unitary = compute_unitary (new_ft, old_ft, flatperm)
69+ function permute_data! (:: NotAbelianStyle , ftdst, ftsrc, flatperm)
70+ foreach (m -> fill! (m, zero (eltype (ftdst))), eachstoredblock (data_matrix (ftdst)))
71+ unitary = compute_unitary (ftdst, ftsrc, flatperm)
5172 for ((old_trees, new_trees), coeff) in unitary
52- new_block = view (new_ft , new_trees... )
53- old_block = view (old_ft , old_trees... )
73+ new_block = view (ftdst , new_trees... )
74+ old_block = view (ftsrc , old_trees... )
5475 @strided new_block .+ = coeff .* permutedims (old_block, flatperm)
5576 end
5677end
0 commit comments