Skip to content

Commit 365a049

Browse files
committed
define permutedims!
1 parent c8f3ad5 commit 365a049

File tree

5 files changed

+81
-30
lines changed

5 files changed

+81
-30
lines changed

src/fusiontensor/base_interface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ Base.imag(ft::FusionTensor) = set_data_matrix(ft, imag(data_matrix(ft)))
7777

7878
Base.permutedims(ft::FusionTensor, args...) = fusiontensor_permutedims(ft, args...)
7979

80+
function Base.permutedims!(ftdst::FusionTensor, ftsrc::FusionTensor, args...)
81+
return fusiontensor_permutedims!(ftdst, ftsrc, args...)
82+
end
83+
8084
Base.real(ft::FusionTensor{<:Real}) = ft # same object
8185
Base.real(ft::FusionTensor) = set_data_matrix(ft, real(data_matrix(ft)))
8286

src/fusiontensor/fusiontensor.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ function GradedArrays.sector_type(::Type{FT}) where {FT<:FusionTensor}
255255
return sector_type(type_parameters(FT, 3))
256256
end
257257

258+
SymmetryStyle(::Type{FT}) where {FT<:FusionTensor} = SymmetryStyle(sector_type(FT))
259+
258260
# ============================== FusionTensor interface ==================================
259261

260262
# misc access

src/fusiontensor/fusiontensoraxes.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using GradedArrays:
33
GradedArrays,
44
AbstractGradedUnitRange,
55
AbstractSector,
6+
SymmetryStyle,
67
TrivialSector,
78
dual,
89
sector_type,
@@ -110,6 +111,10 @@ function GradedArrays.sector_type(::Type{FTA}) where {BT,FTA<:FusionTensorAxes{B
110111
return sector_type(type_parameters(type_parameters(BT, 3), 1))
111112
end
112113

114+
function GradedArrays.SymmetryStyle(::Type{FTA}) where {FTA<:FusionTensorAxes}
115+
return SymmetryStyle(sector_type(FTA))
116+
end
117+
113118
function GradedArrays.checkspaces(
114119
::Type{Bool}, left::FusionTensorAxes, right::FusionTensorAxes
115120
)

src/permutedims/permutedims.jl

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,18 @@
33
using BlockArrays: blocklengths
44
using Strided: Strided, @strided
55

6-
using TensorAlgebra: BlockedPermutation, permmortar, blockpermute
7-
8-
function naive_permutedims(ft, biperm::BlockedPermutation{2})
9-
@assert ndims(ft) == length(biperm)
10-
11-
# naive permute: cast to dense, permutedims, cast to FusionTensor
12-
arr = Array(ft)
13-
permuted_arr = permutedims(arr, Tuple(biperm))
14-
permuted = to_fusiontensor(permuted_arr, blocks(axes(ft)[biperm])...)
15-
return permuted
16-
end
6+
using GradedArrays: AbelianStyle, NotAbelianStyle, SymmetryStyle, checkspaces
7+
using TensorAlgebra: AbstractBlockPermutation, permmortar
178

189
# permutedims with 1 tuple of 2 separate tuples
1910
function fusiontensor_permutedims(ft, new_leg_indices::Tuple{Tuple,Tuple})
2011
return fusiontensor_permutedims(ft, new_leg_indices...)
2112
end
2213

14+
function fusiontensor_permutedims!(ftdst, ftsrc, new_leg_indices::Tuple{Tuple,Tuple})
15+
return fusiontensor_permutedims!(ftdst, ftsrc, new_leg_indices...)
16+
end
17+
2318
# permutedims with 2 separate tuples
2419
function fusiontensor_permutedims(
2520
ft, new_codomain_indices::Tuple, new_domain_indices::Tuple
@@ -28,29 +23,55 @@ function fusiontensor_permutedims(
2823
return fusiontensor_permutedims(ft, biperm)
2924
end
3025

31-
function fusiontensor_permutedims(ft, biperm::BlockedPermutation{2})
26+
function fusiontensor_permutedims!(
27+
ftdst, ftsrc, new_codomain_indices::Tuple, new_domain_indices::Tuple
28+
)
29+
biperm = permmortar((new_codomain_indices, new_domain_indices))
30+
return fusiontensor_permutedims!(ftdst, ftsrc, biperm)
31+
end
32+
33+
# permutedims with BlockedPermutation
34+
function fusiontensor_permutedims(ft, biperm::AbstractBlockPermutation{2})
3235
ndims(ft) == length(biperm) || throw(ArgumentError("Invalid permutation length"))
36+
ftdst = FusionTensor{eltype(ft)}(undef, axes(ft)[biperm])
37+
fusiontensor_permutedims!(ftdst, ft, biperm)
38+
return ftdst
39+
end
40+
41+
function fusiontensor_permutedims!(ftdst, ftsrc, biperm::AbstractBlockPermutation{2})
42+
ndims(ftsrc) == length(biperm) || throw(ArgumentError("Invalid permutation length"))
43+
blocklengths(axes(ftdst)) == blocklengths(biperm) ||
44+
throw(ArgumentError("Destination tensor does not match bipermutation"))
45+
checkspaces(axes(ftdst), axes(ftsrc)[biperm])
3346

34-
# early return for identity operation. Do not copy. Also handle tricky 0-dim case.
35-
if ndims_codomain(ft) == first(blocklengths(biperm)) # compile time
36-
if Tuple(biperm) == ntuple(identity, ndims(ft))
37-
return ft
47+
# early return for identity operation. Also handle tricky 0-dim case.
48+
if ndims_codomain(ftdst) == ndims_codomain(ftsrc) # compile time
49+
if Tuple(biperm) == ntuple(identity, ndims(ftdst))
50+
copy!(data_matrix(ftdst), data_matrix(ftsrc))
51+
return nothing
3852
end
3953
end
54+
return permute_data!(SymmetryStyle(ftdst), ftdst, ftsrc, Tuple(biperm))
55+
end
4056

41-
new_ft = FusionTensor{eltype(ft)}(undef, axes(ft)[biperm])
42-
fusiontensor_permutedims!(new_ft, ft, Tuple(biperm))
43-
return new_ft
57+
# =============================== Internal =============================================
58+
function permute_data!(::AbelianStyle, ftdst, ftsrc, flatperm)
59+
# abelian case: all unitary blocks are 1x1 identity matrices
60+
# compute_unitary is only called to get block positions
61+
unitary = compute_unitary(ftdst, ftsrc, flatperm)
62+
for ((old_trees, new_trees), _) in unitary
63+
new_block = view(ftdst, new_trees...)
64+
old_block = view(ftsrc, old_trees...)
65+
@strided new_block .= permutedims(old_block, flatperm)
66+
end
4467
end
4568

46-
function fusiontensor_permutedims!(
47-
new_ft::FusionTensor{T,N}, old_ft::FusionTensor{T,N}, flatperm::NTuple{N,Integer}
48-
) where {T,N}
49-
foreach(m -> fill!(m, zero(T)), eachstoredblock(data_matrix(new_ft)))
50-
unitary = compute_unitary(new_ft, old_ft, flatperm)
69+
function permute_data!(::NotAbelianStyle, ftdst, ftsrc, flatperm)
70+
foreach(m -> fill!(m, zero(eltype(ftdst))), eachstoredblock(data_matrix(ftdst)))
71+
unitary = compute_unitary(ftdst, ftsrc, flatperm)
5172
for ((old_trees, new_trees), coeff) in unitary
52-
new_block = view(new_ft, new_trees...)
53-
old_block = view(old_ft, old_trees...)
73+
new_block = view(ftdst, new_trees...)
74+
old_block = view(ftsrc, old_trees...)
5475
@strided new_block .+= coeff .* permutedims(old_block, flatperm)
5576
end
5677
end

test/test_permutedims.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
using Test: @test, @testset, @test_broken, @test_throws
2+
using BlockArrays: blocks
23

34
using FusionTensors:
45
FusionTensor,
56
FusionTensorAxes,
67
data_matrix,
78
codomain_axis,
89
domain_axis,
9-
naive_permutedims,
1010
ndims_domain,
1111
ndims_codomain,
1212
to_fusiontensor
@@ -15,6 +15,16 @@ using TensorAlgebra: permmortar, tuplemortar
1515

1616
include("setup.jl")
1717

18+
function naive_permutedims(ft, biperm)
19+
@assert ndims(ft) == length(biperm)
20+
21+
# naive permute: cast to dense, permutedims, cast to FusionTensor
22+
arr = Array(ft)
23+
permuted_arr = permutedims(arr, Tuple(biperm))
24+
permuted = to_fusiontensor(permuted_arr, blocks(axes(ft)[biperm])...)
25+
return permuted
26+
end
27+
1828
@testset "Abelian permutedims" begin
1929
@testset "dummy" begin
2030
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
@@ -28,14 +38,23 @@ include("setup.jl")
2838

2939
# test permutedims interface
3040
ft2 = permutedims(ft1, (1, 2), (3, 4)) # trivial with 2 tuples
31-
@test ft2 === ft1 # same object
41+
@test ft2 ft1
42+
@test ft2 !== ft1
43+
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
44+
@test data_matrix(ft2) == data_matrix(ft1) # check copy
3245

3346
ft2 = permutedims(ft1, ((1, 2), (3, 4))) # trivial with tuple of 2 tuples
34-
@test ft2 === ft1 # same object
47+
@test ft2 ft1
48+
@test ft2 !== ft1
49+
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
50+
@test data_matrix(ft2) == data_matrix(ft1) # check copy
3551

3652
biperm = permmortar(((1, 2), (3, 4)))
3753
ft2 = permutedims(ft1, biperm) # trivial with biperm
38-
@test ft2 === ft1 # same object
54+
@test ft2 ft1
55+
@test ft2 !== ft1
56+
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
57+
@test data_matrix(ft2) == data_matrix(ft1) # check copy
3958

4059
ft3 = permutedims(ft1, (4,), (1, 2, 3))
4160
@test ft3 !== ft1

0 commit comments

Comments
 (0)