Skip to content

Commit 25cd023

Browse files
authored
Upgrade to TensorAlgebra.jl v0.5 (#88)
1 parent 2a853ff commit 25cd023

File tree

4 files changed

+82
-67
lines changed

4 files changed

+82
-67
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.14"
4+
version = "0.5.15"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -29,7 +29,7 @@ LRUCache = "1.6"
2929
LinearAlgebra = "1.10"
3030
Random = "1.10"
3131
Strided = "2.3"
32-
TensorAlgebra = "0.4"
32+
TensorAlgebra = "0.5.1"
3333
TensorKitSectors = "0.1, 0.2"
3434
TensorProducts = "0.1.7"
3535
TypeParameterAccessors = "0.4"

src/fusiontensor/fusiontensoraxes.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ end
4444
# ==================================== Definitions =======================================
4545

4646
# TBD explicit axis type as type parameters?
47-
struct FusionTensorAxes{BT <: BlockedTuple{2}}
47+
struct FusionTensorAxes{BT <: AbstractBlockTuple{2}} <: AbstractBlockTuple{2}
4848
outer_axes::BT
4949

5050
function FusionTensorAxes{BT}(bt) where {BT}
@@ -75,19 +75,19 @@ TensorAlgebra.length_domain(fta::FusionTensorAxes) = length(domain(fta))
7575
# ================================== Base interface ======================================
7676

7777
for f in [
78-
:(broadcastable), :(Tuple), :(axes), :(firstindex), :(lastindex), :(iterate), :(length),
78+
:(broadcastable), :(Tuple), :(axes), :(firstindex), :(lastindex), :(length),
7979
]
8080
@eval Base.$f(fta::FusionTensorAxes) = Base.$f(BlockedTuple(fta))
8181
end
8282

83-
for f in [:(getindex), :(iterate)]
84-
@eval Base.$f(fta::FusionTensorAxes, i) = $f(BlockedTuple(fta), i)
85-
end
86-
83+
Base.getindex(fta::FusionTensorAxes, i::Int) = BlockedTuple(fta)[i]
8784
function Base.getindex(fta::FusionTensorAxes, bp::AbstractBlockPermutation)
8885
return FusionTensorAxes(BlockedTuple(fta)[bp])
8986
end
9087

88+
Base.iterate(fta::FusionTensorAxes) = iterate(BlockedTuple(fta))
89+
Base.iterate(fta::FusionTensorAxes, state::Int) = iterate(BlockedTuple(fta), state)
90+
9191
Base.copy(fta::FusionTensorAxes) = FusionTensorAxes(copy.(BlockedTuple(fta)))
9292

9393
Base.deepcopy(fta::FusionTensorAxes) = FusionTensorAxes(deepcopy.(BlockedTuple(fta)))

src/fusiontensor/tensor_algebra_interface.jl

Lines changed: 70 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,10 @@
11
# This file defines TensorAlgebra interface for a FusionTensor
22

3-
using LinearAlgebra: mul!
4-
53
using BlockArrays: Block
6-
74
using GradedArrays: space_isequal
8-
using TensorAlgebra:
9-
TensorAlgebra,
10-
AbstractBlockPermutation,
11-
BlockedTrivialPermutation,
12-
BlockedTuple,
13-
FusionStyle,
14-
Matricize,
15-
blockedperm,
16-
genperm,
17-
matricize,
18-
unmatricize
19-
20-
const MATRIX_FUNCTIONS = [
21-
:exp,
22-
:cis,
23-
:log,
24-
:sqrt,
25-
:cbrt,
26-
:cos,
27-
:sin,
28-
:tan,
29-
:csc,
30-
:sec,
31-
:cot,
32-
:cosh,
33-
:sinh,
34-
:tanh,
35-
:csch,
36-
:sech,
37-
:coth,
38-
:acos,
39-
:asin,
40-
:atan,
41-
:acsc,
42-
:asec,
43-
:acot,
44-
:acosh,
45-
:asinh,
46-
:atanh,
47-
:acsch,
48-
:asech,
49-
:acoth,
50-
]
5+
using LinearAlgebra: mul!
6+
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, blockedperm,
7+
genperm, matricize, unmatricize
518

529
function TensorAlgebra.output_axes(
5310
::typeof(contract),
@@ -75,43 +32,98 @@ struct FusionTensorFusionStyle <: FusionStyle end
7532

7633
TensorAlgebra.FusionStyle(::Type{<:FusionTensor}) = FusionTensorFusionStyle()
7734

35+
unval(::Val{x}) where {x} = x
36+
7837
function TensorAlgebra.matricize(
79-
::FusionTensorFusionStyle, ft::AbstractArray, biperm::BlockedTrivialPermutation{2}
38+
::FusionTensorFusionStyle, ft::AbstractArray,
39+
codomain_length::Val, domain_length::Val
8040
)
81-
blocklengths(biperm) == blocklengths(axes(ft)) ||
41+
blocklengths(axes(ft)) == unval.((codomain_length, domain_length)) ||
8242
throw(ArgumentError("Invalid trivial biperm"))
8343
return FusionTensor(data_matrix(ft), (codomain_axis(ft),), (domain_axis(ft),))
8444
end
8545

86-
function TensorAlgebra.unmatricize(::FusionTensorFusionStyle, m, blocked_axes)
87-
return FusionTensor(data_matrix(m), blocked_axes)
46+
function TensorAlgebra.unmatricize(
47+
::FusionTensorFusionStyle,
48+
m::AbstractMatrix,
49+
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
50+
domain_axes::Tuple{Vararg{AbstractUnitRange}},
51+
)
52+
return FusionTensor(data_matrix(m), codomain_axes, domain_axes)
8853
end
8954

9055
function TensorAlgebra.permuteblockeddims(
91-
ft::FusionTensor, biperm::AbstractBlockPermutation
56+
ft::FusionTensor,
57+
codomain_perm::Tuple{Vararg{Int}},
58+
domain_perm::Tuple{Vararg{Int}},
9259
)
93-
return permutedims(ft, biperm)
60+
return permutedims(ft, permmortar((codomain_perm, domain_perm)))
9461
end
9562

9663
function TensorAlgebra.permuteblockeddims!(
97-
a::FusionTensor, b::FusionTensor, biperm::AbstractBlockPermutation
64+
a_dest::FusionTensor,
65+
a_src::FusionTensor,
66+
codomain_perm::Tuple{Vararg{Int}},
67+
domain_perm::Tuple{Vararg{Int}},
9868
)
99-
return permutedims!(a, b, biperm)
69+
return permutedims!(a_dest, a_src, permmortar((codomain_perm, domain_perm)))
10070
end
10171

10272
# TODO define custom broadcast rules
103-
function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbiperm, α, β)
104-
a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm)
73+
function TensorAlgebra.unmatricizeadd!(
74+
style::FusionTensorFusionStyle,
75+
a_dest::AbstractArray,
76+
a_dest_mat::AbstractMatrix,
77+
codomain_perm::Tuple{Vararg{Int}},
78+
domain_perm::Tuple{Vararg{Int}},
79+
α::Number, β::Number,
80+
)
81+
a12 = unmatricize(a_dest_mat, axes(a_dest), codomain_perm, domain_perm)
10582
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
10683
return a_dest
10784
end
10885

86+
const MATRIX_FUNCTIONS = [
87+
:exp,
88+
:cis,
89+
:log,
90+
:sqrt,
91+
:cbrt,
92+
:cos,
93+
:sin,
94+
:tan,
95+
:csc,
96+
:sec,
97+
:cot,
98+
:cosh,
99+
:sinh,
100+
:tanh,
101+
:csch,
102+
:sech,
103+
:coth,
104+
:acos,
105+
:asin,
106+
:atan,
107+
:acsc,
108+
:asec,
109+
:acot,
110+
:acosh,
111+
:asinh,
112+
:atanh,
113+
:acsch,
114+
:asech,
115+
:acoth,
116+
]
117+
109118
for f in MATRIX_FUNCTIONS
110119
@eval begin
111120
function TensorAlgebra.$f(
112-
a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs...
121+
a::FusionTensor,
122+
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
123+
kwargs...,
113124
)
114-
a_mat = matricize(a, biperm)
125+
a_mat = matricize(a, codomain_perm, domain_perm)
126+
biperm = permmortar((codomain_perm, domain_perm))
115127
permuted_axes = axes(a)[biperm]
116128
checkspaces_dual(codomain(permuted_axes), domain(permuted_axes))
117129
fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...))

test/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1313
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
1414
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1515

16+
[sources]
17+
FusionTensors = {path = ".."}
18+
1619
[compat]
1720
Aqua = "0.8.11"
1821
BlockArrays = "1.6"
@@ -24,6 +27,6 @@ Random = "1.10"
2427
SUNRepresentations = "0.3.1"
2528
SafeTestsets = "0.1.0"
2629
Suppressor = "0.2.8"
27-
TensorAlgebra = "0.4"
30+
TensorAlgebra = "0.5"
2831
TensorProducts = "0.1"
2932
Test = "1.10.0"

0 commit comments

Comments
 (0)