|
1 | 1 | module KroneckerArraysTensorAlgebraExt |
2 | 2 |
|
3 | 3 | using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, kroneckerfactors |
4 | | -using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, FusionStyle, |
5 | | - matricize, unmatricize |
| 4 | +using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, |
| 5 | + FusionStyle, matricize, unmatricize |
6 | 6 |
|
7 | 7 | struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle |
8 | 8 | a::A |
|
11 | 11 | KroneckerArrays.kroneckerfactors(style::KroneckerFusion) = (style.a, style.b) |
12 | 12 | KroneckerArrays.kroneckerfactortypes(::Type{KroneckerFusion{A, B}}) where {A, B} = (A, B) |
13 | 13 |
|
14 | | -TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) = KroneckerFusion(FusionStyle.(kroneckerfactors(a))...) |
| 14 | +function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) |
| 15 | + return KroneckerFusion(FusionStyle.(kroneckerfactors(a))...) |
| 16 | +end |
15 | 17 | function matricize_kronecker( |
16 | | - style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} |
| 18 | + style::FusionStyle, a::AbstractArray, length1::Val, length2::Val |
17 | 19 | ) |
18 | | - return matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), biperm) ⊗ |
19 | | - matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), biperm) |
| 20 | + m1 = matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), length1, length2) |
| 21 | + m2 = matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), length1, length2) |
| 22 | + return m1 ⊗ m2 |
20 | 23 | end |
21 | 24 | function TensorAlgebra.matricize( |
22 | | - style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} |
| 25 | + style::KroneckerFusion, a::AbstractArray, length1::Val, length2::Val |
23 | 26 | ) |
24 | | - return matricize_kronecker(style, a, biperm) |
| 27 | + return matricize_kronecker(style, a, length1, length2) |
25 | 28 | end |
26 | | -# Fix ambiguity error. |
27 | | -# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this. |
28 | | -using TensorAlgebra: BlockedTrivialPermutation, unmatricize |
29 | | -function TensorAlgebra.matricize( |
30 | | - style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2} |
| 29 | +function unmatricize_kronecker( |
| 30 | + style::FusionStyle, |
| 31 | + m::AbstractMatrix, |
| 32 | + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, |
| 33 | + domain_axes::Tuple{Vararg{AbstractUnitRange}}, |
31 | 34 | ) |
32 | | - return matricize_kronecker(style, a, biperm) |
| 35 | + style1, style2 = kroneckerfactors(style) |
| 36 | + m1, m2 = kroneckerfactors(m) |
| 37 | + codomain1 = kroneckerfactors.(codomain_axes, 1) |
| 38 | + codomain2 = kroneckerfactors.(codomain_axes, 2) |
| 39 | + domain1 = kroneckerfactors.(domain_axes, 1) |
| 40 | + domain2 = kroneckerfactors.(domain_axes, 2) |
| 41 | + a1 = unmatricize(style1, m1, codomain1, domain1) |
| 42 | + a2 = unmatricize(style2, m2, codomain2, domain2) |
| 43 | + return a1 ⊗ a2 |
33 | 44 | end |
34 | | -function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax) |
35 | | - return unmatricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), kroneckerfactors.(ax, 1)) ⊗ |
36 | | - unmatricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), kroneckerfactors.(ax, 2)) |
37 | | -end |
38 | | -function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax) |
39 | | - return unmatricize_kronecker(style, a, ax) |
| 45 | +function TensorAlgebra.unmatricize( |
| 46 | + style::KroneckerFusion, |
| 47 | + m::AbstractMatrix, |
| 48 | + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, |
| 49 | + domain_axes::Tuple{Vararg{AbstractUnitRange}}, |
| 50 | + ) |
| 51 | + return unmatricize_kronecker(style, m, codomain_axes, domain_axes) |
40 | 52 | end |
41 | 53 |
|
42 | 54 | end |
0 commit comments