@@ -4,7 +4,7 @@ using BlockArrays: Block
44using GradedArrays: space_isequal
55using LinearAlgebra: mul!
66using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, blockedperm,
7- genperm, matricize, unmatricize
7+ blockedtrivialperm, genperm, matricize, unmatricize
88
99function TensorAlgebra. output_axes (
1010 :: typeof (contract),
@@ -52,15 +52,15 @@ function TensorAlgebra.unmatricize(
5252 return FusionTensor (data_matrix (m), codomain_axes, domain_axes)
5353end
5454
55- function TensorAlgebra. permuteblockeddims (
55+ function TensorAlgebra. bipermutedims (
5656 ft:: FusionTensor ,
5757 codomain_perm:: Tuple{Vararg{Int}} ,
5858 domain_perm:: Tuple{Vararg{Int}} ,
5959 )
6060 return permutedims (ft, permmortar ((codomain_perm, domain_perm)))
6161end
6262
63- function TensorAlgebra. permuteblockeddims ! (
63+ function TensorAlgebra. bipermutedims ! (
6464 a_dest:: FusionTensor ,
6565 a_src:: FusionTensor ,
6666 codomain_perm:: Tuple{Vararg{Int}} ,
@@ -69,7 +69,8 @@ function TensorAlgebra.permuteblockeddims!(
6969 return permutedims! (a_dest, a_src, permmortar ((codomain_perm, domain_perm)))
7070end
7171
72- # TODO define custom broadcast rules
72+ # TODO : Define custom broadcast rules for FusionTensors so that we can delete
73+ # this method.
7374function TensorAlgebra. unmatricizeadd! (
7475 style:: FusionTensorFusionStyle ,
7576 a_dest:: AbstractArray ,
@@ -119,11 +120,11 @@ for f in MATRIX_FUNCTIONS
119120 @eval begin
120121 function TensorAlgebra. $f (
121122 a:: FusionTensor ,
122- codomain_perm :: Tuple{Vararg{Int}} , domain_perm :: Tuple{Vararg{Int}} ;
123+ codomain_length :: Val , domain_length :: Val ;
123124 kwargs... ,
124125 )
125- a_mat = matricize (a, codomain_perm, domain_perm )
126- biperm = permmortar ((codomain_perm, domain_perm ))
126+ a_mat = matricize (a, codomain_length, domain_length )
127+ biperm = blockedtrivialperm ((codomain_length, domain_length ))
127128 permuted_axes = axes (a)[biperm]
128129 checkspaces_dual (codomain (permuted_axes), domain (permuted_axes))
129130 fa_mat = set_data_matrix (a_mat, Base.$ f (data_matrix (a_mat); kwargs... ))
0 commit comments