Skip to content

Commit 11df61e

Browse files
authored
Upgrade TensorAlgebra overloads (#89)
1 parent 25cd023 commit 11df61e

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.15"
4+
version = "0.5.16"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -29,7 +29,7 @@ LRUCache = "1.6"
2929
LinearAlgebra = "1.10"
3030
Random = "1.10"
3131
Strided = "2.3"
32-
TensorAlgebra = "0.5.1"
32+
TensorAlgebra = "0.5.2"
3333
TensorKitSectors = "0.1, 0.2"
3434
TensorProducts = "0.1.7"
3535
TypeParameterAccessors = "0.4"

src/fusiontensor/tensor_algebra_interface.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using BlockArrays: Block
44
using GradedArrays: space_isequal
55
using LinearAlgebra: mul!
66
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, blockedperm,
7-
genperm, matricize, unmatricize
7+
blockedtrivialperm, genperm, matricize, unmatricize
88

99
function TensorAlgebra.output_axes(
1010
::typeof(contract),
@@ -52,15 +52,15 @@ function TensorAlgebra.unmatricize(
5252
return FusionTensor(data_matrix(m), codomain_axes, domain_axes)
5353
end
5454

55-
function TensorAlgebra.permuteblockeddims(
55+
function TensorAlgebra.bipermutedims(
5656
ft::FusionTensor,
5757
codomain_perm::Tuple{Vararg{Int}},
5858
domain_perm::Tuple{Vararg{Int}},
5959
)
6060
return permutedims(ft, permmortar((codomain_perm, domain_perm)))
6161
end
6262

63-
function TensorAlgebra.permuteblockeddims!(
63+
function TensorAlgebra.bipermutedims!(
6464
a_dest::FusionTensor,
6565
a_src::FusionTensor,
6666
codomain_perm::Tuple{Vararg{Int}},
@@ -69,7 +69,8 @@ function TensorAlgebra.permuteblockeddims!(
6969
return permutedims!(a_dest, a_src, permmortar((codomain_perm, domain_perm)))
7070
end
7171

72-
# TODO define custom broadcast rules
72+
# TODO: Define custom broadcast rules for FusionTensors so that we can delete
73+
# this method.
7374
function TensorAlgebra.unmatricizeadd!(
7475
style::FusionTensorFusionStyle,
7576
a_dest::AbstractArray,
@@ -119,11 +120,11 @@ for f in MATRIX_FUNCTIONS
119120
@eval begin
120121
function TensorAlgebra.$f(
121122
a::FusionTensor,
122-
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
123+
codomain_length::Val, domain_length::Val;
123124
kwargs...,
124125
)
125-
a_mat = matricize(a, codomain_perm, domain_perm)
126-
biperm = permmortar((codomain_perm, domain_perm))
126+
a_mat = matricize(a, codomain_length, domain_length)
127+
biperm = blockedtrivialperm((codomain_length, domain_length))
127128
permuted_axes = axes(a)[biperm]
128129
checkspaces_dual(codomain(permuted_axes), domain(permuted_axes))
129130
fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...))

0 commit comments

Comments
 (0)