3
3
using BlockArrays: blocklengths
4
4
using Strided: Strided, @strided
5
5
6
- using TensorAlgebra: BlockedPermutation, permmortar, blockpermute
7
-
8
- function naive_permutedims (ft, biperm:: BlockedPermutation{2} )
9
- @assert ndims (ft) == length (biperm)
10
-
11
- # naive permute: cast to dense, permutedims, cast to FusionTensor
12
- arr = Array (ft)
13
- permuted_arr = permutedims (arr, Tuple (biperm))
14
- permuted = to_fusiontensor (permuted_arr, blocks (axes (ft)[biperm])... )
15
- return permuted
16
- end
6
+ using GradedArrays: AbelianStyle, NotAbelianStyle, SymmetryStyle, checkspaces
7
+ using TensorAlgebra: AbstractBlockPermutation, permmortar
17
8
18
9
# permutedims with 1 tuple of 2 separate tuples
19
10
function fusiontensor_permutedims (ft, new_leg_indices:: Tuple{Tuple,Tuple} )
20
11
return fusiontensor_permutedims (ft, new_leg_indices... )
21
12
end
22
13
14
+ function fusiontensor_permutedims! (ftdst, ftsrc, new_leg_indices:: Tuple{Tuple,Tuple} )
15
+ return fusiontensor_permutedims! (ftdst, ftsrc, new_leg_indices... )
16
+ end
17
+
23
18
# permutedims with 2 separate tuples
24
19
function fusiontensor_permutedims (
25
20
ft, new_codomain_indices:: Tuple , new_domain_indices:: Tuple
@@ -28,29 +23,55 @@ function fusiontensor_permutedims(
28
23
return fusiontensor_permutedims (ft, biperm)
29
24
end
30
25
31
- function fusiontensor_permutedims (ft, biperm:: BlockedPermutation{2} )
26
+ function fusiontensor_permutedims! (
27
+ ftdst, ftsrc, new_codomain_indices:: Tuple , new_domain_indices:: Tuple
28
+ )
29
+ biperm = permmortar ((new_codomain_indices, new_domain_indices))
30
+ return fusiontensor_permutedims! (ftdst, ftsrc, biperm)
31
+ end
32
+
33
+ # permutedims with BlockedPermutation
34
+ function fusiontensor_permutedims (ft, biperm:: AbstractBlockPermutation{2} )
32
35
ndims (ft) == length (biperm) || throw (ArgumentError (" Invalid permutation length" ))
36
+ ftdst = FusionTensor {eltype(ft)} (undef, axes (ft)[biperm])
37
+ fusiontensor_permutedims! (ftdst, ft, biperm)
38
+ return ftdst
39
+ end
40
+
41
+ function fusiontensor_permutedims! (ftdst, ftsrc, biperm:: AbstractBlockPermutation{2} )
42
+ ndims (ftsrc) == length (biperm) || throw (ArgumentError (" Invalid permutation length" ))
43
+ blocklengths (axes (ftdst)) == blocklengths (biperm) ||
44
+ throw (ArgumentError (" Destination tensor does not match bipermutation" ))
45
+ checkspaces (axes (ftdst), axes (ftsrc)[biperm])
33
46
34
- # early return for identity operation. Do not copy. Also handle tricky 0-dim case.
35
- if ndims_codomain (ft) == first (blocklengths (biperm)) # compile time
36
- if Tuple (biperm) == ntuple (identity, ndims (ft))
37
- return ft
47
+ # early return for identity operation. Also handle tricky 0-dim case.
48
+ if ndims_codomain (ftdst) == ndims_codomain (ftsrc) # compile time
49
+ if Tuple (biperm) == ntuple (identity, ndims (ftdst))
50
+ copy! (data_matrix (ftdst), data_matrix (ftsrc))
51
+ return nothing
38
52
end
39
53
end
54
+ return permute_data! (SymmetryStyle (ftdst), ftdst, ftsrc, Tuple (biperm))
55
+ end
40
56
41
- new_ft = FusionTensor {eltype(ft)} (undef, axes (ft)[biperm])
42
- fusiontensor_permutedims! (new_ft, ft, Tuple (biperm))
43
- return new_ft
57
+ # =============================== Internal =============================================
58
+ function permute_data! (:: AbelianStyle , ftdst, ftsrc, flatperm)
59
+ # abelian case: all unitary blocks are 1x1 identity matrices
60
+ # compute_unitary is only called to get block positions
61
+ unitary = compute_unitary (ftdst, ftsrc, flatperm)
62
+ for ((old_trees, new_trees), _) in unitary
63
+ new_block = view (ftdst, new_trees... )
64
+ old_block = view (ftsrc, old_trees... )
65
+ @strided new_block .= permutedims (old_block, flatperm)
66
+ end
44
67
end
45
68
46
- function fusiontensor_permutedims! (
47
- new_ft:: FusionTensor{T,N} , old_ft:: FusionTensor{T,N} , flatperm:: NTuple{N,Integer}
48
- ) where {T,N}
49
- foreach (m -> fill! (m, zero (T)), eachstoredblock (data_matrix (new_ft)))
50
- unitary = compute_unitary (new_ft, old_ft, flatperm)
69
+ function permute_data! (:: NotAbelianStyle , ftdst, ftsrc, flatperm)
70
+ foreach (m -> fill! (m, zero (eltype (ftdst))), eachstoredblock (data_matrix (ftdst)))
71
+ unitary = compute_unitary (ftdst, ftsrc, flatperm)
51
72
for ((old_trees, new_trees), coeff) in unitary
52
- new_block = view (new_ft , new_trees... )
53
- old_block = view (old_ft , old_trees... )
73
+ new_block = view (ftdst , new_trees... )
74
+ old_block = view (ftsrc , old_trees... )
54
75
@strided new_block .+ = coeff .* permutedims (old_block, flatperm)
55
76
end
56
77
end
0 commit comments