Skip to content

Commit 652235f

Browse files
authored
define Base.permutedims! (#64)
1 parent 0f5ee4c commit 652235f

File tree

6 files changed

+121
-39
lines changed

6 files changed

+121
-39
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.3"
4+
version = "0.5.4"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/fusiontensor/base_interface.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ Base.imag(ft::FusionTensor) = set_data_matrix(ft, imag(data_matrix(ft)))
7777

7878
Base.permutedims(ft::FusionTensor, args...) = fusiontensor_permutedims(ft, args...)
7979

80+
function Base.permutedims!(ftdst::FusionTensor, ftsrc::FusionTensor, args...)
81+
return fusiontensor_permutedims!(ftdst, ftsrc, args...)
82+
end
83+
8084
Base.real(ft::FusionTensor{<:Real}) = ft # same object
8185
Base.real(ft::FusionTensor) = set_data_matrix(ft, real(data_matrix(ft)))
8286

@@ -103,13 +107,18 @@ end
103107
function Base.similar(::FusionTensor, ::Type, ::Tuple{})
104108
throw(MethodError(similar, (Tuple{},)))
105109
end
106-
107110
function Base.similar(
108111
ft::FusionTensor, ::Type{T}, new_axes::Tuple{<:Tuple,<:Tuple}
109112
) where {T}
110113
return similar(ft, T, tuplemortar(new_axes))
111114
end
112-
function Base.similar(::FusionTensor, ::Type{T}, new_axes::BlockedTuple{2}) where {T}
115+
function Base.similar(ft::FusionTensor, ::Type{T}, new_axes::BlockedTuple{2}) where {T}
116+
return similar(ft, T, FusionTensorAxes(new_axes))
117+
end
118+
function Base.similar(ft::FusionTensor, new_axes::FusionTensorAxes)
119+
return similar(ft, eltype(ft), new_axes)
120+
end
121+
function Base.similar(::FusionTensor, ::Type{T}, new_axes::FusionTensorAxes) where {T}
113122
return FusionTensor{T}(undef, new_axes)
114123
end
115124

src/fusiontensor/fusiontensor.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@ function GradedArrays.sector_type(::Type{FT}) where {FT<:FusionTensor}
249249
return sector_type(type_parameters(FT, 3))
250250
end
251251

252+
SymmetryStyle(::Type{FT}) where {FT<:FusionTensor} = SymmetryStyle(sector_type(FT))
253+
252254
# ============================== FusionTensor interface ==================================
253255

254256
# misc access

src/fusiontensor/fusiontensoraxes.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using GradedArrays:
33
GradedArrays,
44
AbstractGradedUnitRange,
55
AbstractSector,
6+
SymmetryStyle,
67
TrivialSector,
78
dual,
89
sector_type,
@@ -110,6 +111,10 @@ function GradedArrays.sector_type(::Type{FTA}) where {BT,FTA<:FusionTensorAxes{B
110111
return sector_type(type_parameters(type_parameters(BT, 3), 1))
111112
end
112113

114+
function GradedArrays.SymmetryStyle(::Type{FTA}) where {FTA<:FusionTensorAxes}
115+
return SymmetryStyle(sector_type(FTA))
116+
end
117+
113118
function GradedArrays.checkspaces(
114119
::Type{Bool}, left::FusionTensorAxes, right::FusionTensorAxes
115120
)

src/permutedims/permutedims.jl

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,54 +3,75 @@
33
using BlockArrays: blocklengths
44
using Strided: Strided, @strided
55

6-
using TensorAlgebra: BlockedPermutation, permmortar, blockpermute
6+
using GradedArrays: AbelianStyle, NotAbelianStyle, SymmetryStyle, checkspaces
7+
using TensorAlgebra: AbstractBlockPermutation, permmortar
78

8-
function naive_permutedims(ft, biperm::BlockedPermutation{2})
9-
@assert ndims(ft) == length(biperm)
10-
11-
# naive permute: cast to dense, permutedims, cast to FusionTensor
12-
arr = Array(ft)
13-
permuted_arr = permutedims(arr, Tuple(biperm))
14-
permuted = to_fusiontensor(permuted_arr, blocks(axes(ft)[biperm])...)
15-
return permuted
9+
# permutedims with 1 tuple of 2 separate tuples
10+
function fusiontensor_permutedims(ft, new_leg_dims::Tuple{Tuple,Tuple})
11+
return fusiontensor_permutedims(ft, new_leg_dims...)
1612
end
1713

18-
# permutedims with 1 tuple of 2 separate tuples
19-
function fusiontensor_permutedims(ft, new_leg_indices::Tuple{Tuple,Tuple})
20-
return fusiontensor_permutedims(ft, new_leg_indices...)
14+
function fusiontensor_permutedims!(ftdst, ftsrc, new_leg_dims::Tuple{Tuple,Tuple})
15+
return fusiontensor_permutedims!(ftdst, ftsrc, new_leg_dims...)
2116
end
2217

2318
# permutedims with 2 separate tuples
24-
function fusiontensor_permutedims(
25-
ft, new_codomain_indices::Tuple, new_domain_indices::Tuple
26-
)
27-
biperm = permmortar((new_codomain_indices, new_domain_indices))
19+
function fusiontensor_permutedims(ft, new_codomain_dims::Tuple, new_domain_dims::Tuple)
20+
biperm = permmortar((new_codomain_dims, new_domain_dims))
2821
return fusiontensor_permutedims(ft, biperm)
2922
end
3023

31-
function fusiontensor_permutedims(ft, biperm::BlockedPermutation{2})
24+
function fusiontensor_permutedims!(
25+
ftdst, ftsrc, new_codomain_dims::Tuple, new_domain_dims::Tuple
26+
)
27+
biperm = permmortar((new_codomain_dims, new_domain_dims))
28+
return fusiontensor_permutedims!(ftdst, ftsrc, biperm)
29+
end
30+
31+
# permutedims with BlockedPermutation
32+
function fusiontensor_permutedims(ft, biperm::AbstractBlockPermutation{2})
3233
ndims(ft) == length(biperm) || throw(ArgumentError("Invalid permutation length"))
34+
ftdst = similar(ft, axes(ft)[biperm])
35+
fusiontensor_permutedims!(ftdst, ft, biperm)
36+
return ftdst
37+
end
38+
39+
function fusiontensor_permutedims!(ftdst, ftsrc, biperm::AbstractBlockPermutation{2})
40+
ndims(ftsrc) == length(biperm) || throw(ArgumentError("Invalid permutation length"))
41+
blocklengths(axes(ftdst)) == blocklengths(biperm) ||
42+
throw(ArgumentError("Destination tensor does not match bipermutation"))
43+
checkspaces(axes(ftdst), axes(ftsrc)[biperm])
3344

34-
# early return for identity operation. Do not copy. Also handle tricky 0-dim case.
35-
if ndims_codomain(ft) == first(blocklengths(biperm)) # compile time
36-
if Tuple(biperm) == ntuple(identity, ndims(ft))
37-
return ft
45+
# early return for identity operation. Also handle tricky 0-dim case.
46+
if ndims_codomain(ftdst) == ndims_codomain(ftsrc) # compile time
47+
if Tuple(biperm) == ntuple(identity, ndims(ftdst))
48+
copy!(data_matrix(ftdst), data_matrix(ftsrc))
49+
return ftdst
3850
end
3951
end
52+
return permute_data!(SymmetryStyle(ftdst), ftdst, ftsrc, Tuple(biperm))
53+
end
4054

41-
new_ft = FusionTensor{eltype(ft)}(undef, axes(ft)[biperm])
42-
fusiontensor_permutedims!(new_ft, ft, Tuple(biperm))
43-
return new_ft
55+
# =============================== Internal =============================================
56+
function permute_data!(::AbelianStyle, ftdst, ftsrc, flatperm)
57+
# abelian case: all unitary blocks are 1x1 identity matrices
58+
# compute_unitary is only called to get block positions
59+
unitary = compute_unitary(ftdst, ftsrc, flatperm)
60+
for ((old_trees, new_trees), _) in unitary
61+
new_block = view(ftdst, new_trees...)
62+
old_block = view(ftsrc, old_trees...)
63+
@strided new_block .= permutedims(old_block, flatperm)
64+
end
65+
return ftdst
4466
end
4567

46-
function fusiontensor_permutedims!(
47-
new_ft::FusionTensor{T,N}, old_ft::FusionTensor{T,N}, flatperm::NTuple{N,Integer}
48-
) where {T,N}
49-
foreach(m -> fill!(m, zero(T)), eachstoredblock(data_matrix(new_ft)))
50-
unitary = compute_unitary(new_ft, old_ft, flatperm)
68+
function permute_data!(::NotAbelianStyle, ftdst, ftsrc, flatperm)
69+
foreach(m -> fill!(m, zero(eltype(ftdst))), eachstoredblock(data_matrix(ftdst)))
70+
unitary = compute_unitary(ftdst, ftsrc, flatperm)
5171
for ((old_trees, new_trees), coeff) in unitary
52-
new_block = view(new_ft, new_trees...)
53-
old_block = view(old_ft, old_trees...)
72+
new_block = view(ftdst, new_trees...)
73+
old_block = view(ftsrc, old_trees...)
5474
@strided new_block .+= coeff .* permutedims(old_block, flatperm)
5575
end
76+
return ftdst
5677
end

test/test_permutedims.jl

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
using Test: @test, @testset, @test_broken, @test_throws
2+
using BlockArrays: blocks
23

34
using FusionTensors:
45
FusionTensor,
56
FusionTensorAxes,
67
data_matrix,
78
codomain_axis,
89
domain_axis,
9-
naive_permutedims,
1010
ndims_domain,
1111
ndims_codomain,
1212
to_fusiontensor
@@ -15,27 +15,47 @@ using TensorAlgebra: permmortar, tuplemortar
1515

1616
include("setup.jl")
1717

18+
function naive_permutedims(ft, biperm)
19+
@assert ndims(ft) == length(biperm)
20+
21+
# naive permute: cast to dense, permutedims, cast to FusionTensor
22+
arr = Array(ft)
23+
permuted_arr = permutedims(arr, Tuple(biperm))
24+
permuted = to_fusiontensor(permuted_arr, blocks(axes(ft)[biperm])...)
25+
return permuted
26+
end
27+
1828
@testset "Abelian permutedims" begin
1929
@testset "dummy" begin
2030
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
2131
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
2232
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
2333
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
34+
ftaxes1 = FusionTensorAxes((g1, g2), (dual(g3), dual(g4)))
2435

2536
for elt in (Float64, ComplexF64)
26-
ft1 = FusionTensor{elt}(undef, (g1, g2), dual.((g3, g4)))
37+
ft1 = randn(elt, ftaxes1)
2738
@test isnothing(check_sanity(ft1))
2839

2940
# test permutedims interface
3041
ft2 = permutedims(ft1, (1, 2), (3, 4)) # trivial with 2 tuples
31-
@test ft2 === ft1 # same object
42+
@test ft2 ft1
43+
@test ft2 !== ft1
44+
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
45+
@test data_matrix(ft2) == data_matrix(ft1) # check copy
3246

3347
ft2 = permutedims(ft1, ((1, 2), (3, 4))) # trivial with tuple of 2 tuples
34-
@test ft2 === ft1 # same object
48+
@test ft2 ft1
49+
@test ft2 !== ft1
50+
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
51+
@test data_matrix(ft2) == data_matrix(ft1) # check copy
3552

3653
biperm = permmortar(((1, 2), (3, 4)))
3754
ft2 = permutedims(ft1, biperm) # trivial with biperm
38-
@test ft2 === ft1 # same object
55+
@test ft2 ft1
56+
@test ft2 !== ft1
57+
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
58+
@test data_matrix(ft2) == data_matrix(ft1) # check copy
3959

4060
ft3 = permutedims(ft1, (4,), (1, 2, 3))
4161
@test ft3 !== ft1
@@ -49,8 +69,33 @@ include("setup.jl")
4969
@test space_isequal(domain_axis(ft1), domain_axis(ft4))
5070
@test ft4 ft1
5171

72+
# test permutedims! interface
73+
ft2 = randn(elt, axes(ft1))
74+
permutedims!(ft2, ft1, (1, 2), (3, 4))
75+
@test ft2 ft1
76+
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
77+
@test data_matrix(ft2) == data_matrix(ft1) # check copy
78+
79+
ft2 = randn(elt, axes(ft1))
80+
permutedims!(ft2, ft1, ((1, 2), (3, 4)))
81+
@test ft2 ft1
82+
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
83+
@test data_matrix(ft2) == data_matrix(ft1) # check copy
84+
85+
ft2 = randn(elt, axes(ft1))
86+
permutedims!(ft2, ft1, biperm)
87+
@test ft2 ft1
88+
@test data_matrix(ft2) !== data_matrix(ft1) # check copy
89+
@test data_matrix(ft2) == data_matrix(ft1) # check copy
90+
91+
# test clean errors
92+
ft2 = randn(elt, axes(ft1))
5293
@test_throws MethodError permutedims(ft1, (2, 3, 4, 1))
5394
@test_throws ArgumentError permutedims(ft1, (2, 3), (5, 4, 1))
95+
@test_throws MethodError permutedims!(ft2, ft1, (2, 3, 4, 1))
96+
@test_throws ArgumentError permutedims!(ft2, ft1, (2, 3), (5, 4, 1))
97+
@test_throws ArgumentError permutedims!(ft2, ft1, (1, 2, 3), (4,))
98+
@test_throws ArgumentError permutedims!(ft2, ft1, (1, 2), (4, 3))
5499
end
55100
end
56101

0 commit comments

Comments
 (0)