|
1 | 1 | # This file defines TensorAlgebra interface for a FusionTensor |
2 | 2 |
|
3 | | -using LinearAlgebra: mul! |
4 | | - |
5 | 3 | using BlockArrays: Block |
6 | | - |
7 | 4 | using GradedArrays: space_isequal |
8 | | -using TensorAlgebra: |
9 | | - TensorAlgebra, |
10 | | - AbstractBlockPermutation, |
11 | | - BlockedTrivialPermutation, |
12 | | - BlockedTuple, |
13 | | - FusionStyle, |
14 | | - Matricize, |
15 | | - blockedperm, |
16 | | - genperm, |
17 | | - matricize, |
18 | | - unmatricize |
19 | | - |
20 | | -const MATRIX_FUNCTIONS = [ |
21 | | - :exp, |
22 | | - :cis, |
23 | | - :log, |
24 | | - :sqrt, |
25 | | - :cbrt, |
26 | | - :cos, |
27 | | - :sin, |
28 | | - :tan, |
29 | | - :csc, |
30 | | - :sec, |
31 | | - :cot, |
32 | | - :cosh, |
33 | | - :sinh, |
34 | | - :tanh, |
35 | | - :csch, |
36 | | - :sech, |
37 | | - :coth, |
38 | | - :acos, |
39 | | - :asin, |
40 | | - :atan, |
41 | | - :acsc, |
42 | | - :asec, |
43 | | - :acot, |
44 | | - :acosh, |
45 | | - :asinh, |
46 | | - :atanh, |
47 | | - :acsch, |
48 | | - :asech, |
49 | | - :acoth, |
50 | | -] |
| 5 | +using LinearAlgebra: mul! |
| 6 | +using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, blockedperm, |
| 7 | + genperm, matricize, unmatricize |
51 | 8 |
|
52 | 9 | function TensorAlgebra.output_axes( |
53 | 10 | ::typeof(contract), |
@@ -75,43 +32,98 @@ struct FusionTensorFusionStyle <: FusionStyle end |
75 | 32 |
|
76 | 33 | TensorAlgebra.FusionStyle(::Type{<:FusionTensor}) = FusionTensorFusionStyle() |
77 | 34 |
|
| 35 | +unval(::Val{x}) where {x} = x |
| 36 | + |
78 | 37 | function TensorAlgebra.matricize( |
79 | | - ::FusionTensorFusionStyle, ft::AbstractArray, biperm::BlockedTrivialPermutation{2} |
| 38 | + ::FusionTensorFusionStyle, ft::AbstractArray, |
| 39 | + codomain_length::Val, domain_length::Val |
80 | 40 | ) |
81 | | - blocklengths(biperm) == blocklengths(axes(ft)) || |
| 41 | + blocklengths(axes(ft)) == unval.((codomain_length, domain_length)) || |
82 | 42 | throw(ArgumentError("Invalid trivial biperm")) |
83 | 43 | return FusionTensor(data_matrix(ft), (codomain_axis(ft),), (domain_axis(ft),)) |
84 | 44 | end |
85 | 45 |
|
86 | | -function TensorAlgebra.unmatricize(::FusionTensorFusionStyle, m, blocked_axes) |
87 | | - return FusionTensor(data_matrix(m), blocked_axes) |
| 46 | +function TensorAlgebra.unmatricize( |
| 47 | + ::FusionTensorFusionStyle, |
| 48 | + m::AbstractMatrix, |
| 49 | + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, |
| 50 | + domain_axes::Tuple{Vararg{AbstractUnitRange}}, |
| 51 | + ) |
| 52 | + return FusionTensor(data_matrix(m), codomain_axes, domain_axes) |
88 | 53 | end |
89 | 54 |
|
90 | 55 | function TensorAlgebra.permuteblockeddims( |
91 | | - ft::FusionTensor, biperm::AbstractBlockPermutation |
| 56 | + ft::FusionTensor, |
| 57 | + codomain_perm::Tuple{Vararg{Int}}, |
| 58 | + domain_perm::Tuple{Vararg{Int}}, |
92 | 59 | ) |
93 | | - return permutedims(ft, biperm) |
| 60 | + return permutedims(ft, permmortar((codomain_perm, domain_perm))) |
94 | 61 | end |
95 | 62 |
|
96 | 63 | function TensorAlgebra.permuteblockeddims!( |
97 | | - a::FusionTensor, b::FusionTensor, biperm::AbstractBlockPermutation |
| 64 | + a_dest::FusionTensor, |
| 65 | + a_src::FusionTensor, |
| 66 | + codomain_perm::Tuple{Vararg{Int}}, |
| 67 | + domain_perm::Tuple{Vararg{Int}}, |
98 | 68 | ) |
99 | | - return permutedims!(a, b, biperm) |
| 69 | + return permutedims!(a_dest, a_src, permmortar((codomain_perm, domain_perm))) |
100 | 70 | end |
101 | 71 |
|
102 | 72 | # TODO define custom broadcast rules |
103 | | -function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbiperm, α, β) |
104 | | - a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm) |
| 73 | +function TensorAlgebra.unmatricizeadd!( |
| 74 | + style::FusionTensorFusionStyle, |
| 75 | + a_dest::AbstractArray, |
| 76 | + a_dest_mat::AbstractMatrix, |
| 77 | + codomain_perm::Tuple{Vararg{Int}}, |
| 78 | + domain_perm::Tuple{Vararg{Int}}, |
| 79 | + α::Number, β::Number, |
| 80 | + ) |
| 81 | + a12 = unmatricize(a_dest_mat, axes(a_dest), codomain_perm, domain_perm) |
105 | 82 | data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest) |
106 | 83 | return a_dest |
107 | 84 | end |
108 | 85 |
|
| 86 | +const MATRIX_FUNCTIONS = [ |
| 87 | + :exp, |
| 88 | + :cis, |
| 89 | + :log, |
| 90 | + :sqrt, |
| 91 | + :cbrt, |
| 92 | + :cos, |
| 93 | + :sin, |
| 94 | + :tan, |
| 95 | + :csc, |
| 96 | + :sec, |
| 97 | + :cot, |
| 98 | + :cosh, |
| 99 | + :sinh, |
| 100 | + :tanh, |
| 101 | + :csch, |
| 102 | + :sech, |
| 103 | + :coth, |
| 104 | + :acos, |
| 105 | + :asin, |
| 106 | + :atan, |
| 107 | + :acsc, |
| 108 | + :asec, |
| 109 | + :acot, |
| 110 | + :acosh, |
| 111 | + :asinh, |
| 112 | + :atanh, |
| 113 | + :acsch, |
| 114 | + :asech, |
| 115 | + :acoth, |
| 116 | +] |
| 117 | + |
109 | 118 | for f in MATRIX_FUNCTIONS |
110 | 119 | @eval begin |
111 | 120 | function TensorAlgebra.$f( |
112 | | - a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs... |
| 121 | + a::FusionTensor, |
| 122 | + codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; |
| 123 | + kwargs..., |
113 | 124 | ) |
114 | | - a_mat = matricize(a, biperm) |
| 125 | + a_mat = matricize(a, codomain_perm, domain_perm) |
| 126 | + biperm = permmortar((codomain_perm, domain_perm)) |
115 | 127 | permuted_axes = axes(a)[biperm] |
116 | 128 | checkspaces_dual(codomain(permuted_axes), domain(permuted_axes)) |
117 | 129 | fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...)) |
|
0 commit comments