|
1 | 1 | module KroneckerArraysTensorAlgebraExt |
2 | 2 |
|
3 | | -using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, kroneckerfactors |
| 3 | +using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, CartesianProductUnitRange, |
| 4 | + ⊗, cartesianrange, kroneckerfactors, kroneckerfactortypes |
4 | 5 | using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, |
5 | | - FusionStyle, matricize, unmatricize |
| 6 | + FusionStyle, matricize, tensor_product_axis, unmatricize |
6 | 7 |
|
7 | 8 | struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle |
8 | 9 | a::A |
|
11 | 12 | KroneckerArrays.kroneckerfactors(style::KroneckerFusion) = (style.a, style.b) |
12 | 13 | KroneckerArrays.kroneckerfactortypes(::Type{KroneckerFusion{A, B}}) where {A, B} = (A, B) |
13 | 14 |
|
14 | | -function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) |
15 | | - return KroneckerFusion(FusionStyle.(kroneckerfactors(a))...) |
| 15 | +function TensorAlgebra.FusionStyle(A::Type{<:AbstractKroneckerArray}) |
| 16 | + return KroneckerFusion(FusionStyle.(kroneckerfactortypes(A))...) |
16 | 17 | end |
| 18 | +function TensorAlgebra.FusionStyle(A::Type{<:CartesianProductUnitRange}) |
| 19 | + return KroneckerFusion(FusionStyle.(kroneckerfactortypes(A))...) |
| 20 | +end |
| 21 | + |
| 22 | +function TensorAlgebra.tensor_product_axis( |
| 23 | + style::KroneckerFusion, r1::AbstractUnitRange, r2::AbstractUnitRange |
| 24 | + ) |
| 25 | + style_a, style_b = kroneckerfactors(style) |
| 26 | + r1a, r1b = kroneckerfactors(r1) |
| 27 | + r2a, r2b = kroneckerfactors(r2) |
| 28 | + ra = tensor_product_axis(style_a, r1a, r2a) |
| 29 | + rb = tensor_product_axis(style_b, r1b, r2b) |
| 30 | + return cartesianrange(ra, rb) |
| 31 | +end |
| 32 | + |
17 | 33 | function matricize_kronecker( |
18 | | - style::FusionStyle, a::AbstractArray, length1::Val, length2::Val |
| 34 | + style::FusionStyle, a::AbstractArray, length_codomain::Val |
19 | 35 | ) |
20 | | - m1 = matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), length1, length2) |
21 | | - m2 = matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), length1, length2) |
| 36 | + m1 = matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), length_codomain) |
| 37 | + m2 = matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), length_codomain) |
22 | 38 | return m1 ⊗ m2 |
23 | 39 | end |
24 | 40 | function TensorAlgebra.matricize( |
25 | | - style::KroneckerFusion, a::AbstractArray, length1::Val, length2::Val |
| 41 | + style::KroneckerFusion, a::AbstractArray, length_codomain::Val |
26 | 42 | ) |
27 | | - return matricize_kronecker(style, a, length1, length2) |
| 43 | + return matricize_kronecker(style, a, length_codomain) |
28 | 44 | end |
| 45 | + |
29 | 46 | function unmatricize_kronecker( |
30 | 47 | style::FusionStyle, |
31 | 48 | m::AbstractMatrix, |
32 | | - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, |
33 | | - domain_axes::Tuple{Vararg{AbstractUnitRange}}, |
| 49 | + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, |
| 50 | + axes_domain::Tuple{Vararg{AbstractUnitRange}}, |
34 | 51 | ) |
35 | 52 | style1, style2 = kroneckerfactors(style) |
36 | 53 | m1, m2 = kroneckerfactors(m) |
37 | | - codomain1 = kroneckerfactors.(codomain_axes, 1) |
38 | | - codomain2 = kroneckerfactors.(codomain_axes, 2) |
39 | | - domain1 = kroneckerfactors.(domain_axes, 1) |
40 | | - domain2 = kroneckerfactors.(domain_axes, 2) |
| 54 | + codomain1 = kroneckerfactors.(axes_codomain, 1) |
| 55 | + codomain2 = kroneckerfactors.(axes_codomain, 2) |
| 56 | + domain1 = kroneckerfactors.(axes_domain, 1) |
| 57 | + domain2 = kroneckerfactors.(axes_domain, 2) |
41 | 58 | a1 = unmatricize(style1, m1, codomain1, domain1) |
42 | 59 | a2 = unmatricize(style2, m2, codomain2, domain2) |
43 | 60 | return a1 ⊗ a2 |
|
0 commit comments