@@ -3,7 +3,7 @@ module KroneckerArraysTensorAlgebraExt
33using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, CartesianProductUnitRange,
44 ⊗ , cartesianrange, kroneckerfactors, kroneckerfactortypes
55using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation,
6- FusionStyle, matricize, tensor_product_axis, unmatricize
6+ FusionStyle, matricize, tensor_product_axis, trivial_axis, unmatricize
77
88struct KroneckerFusion{A <: FusionStyle , B <: FusionStyle } <: FusionStyle
99 a:: A
@@ -19,14 +19,57 @@ function TensorAlgebra.FusionStyle(A::Type{<:CartesianProductUnitRange})
1919 return KroneckerFusion (FusionStyle .(kroneckerfactortypes (A))... )
2020end
2121
22+ function TensorAlgebra. trivial_axis (
23+ style:: KroneckerFusion , side:: Val{:codomain} , a:: AbstractArray ,
24+ axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
25+ axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
26+ )
27+ return trivial_kronecker (style, side, a, axes_codomain, axes_domain)
28+ end
29+ function TensorAlgebra. trivial_axis (
30+ style:: KroneckerFusion , side:: Val{:domain} , a:: AbstractArray ,
31+ axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
32+ axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
33+ )
34+ return trivial_kronecker (style, side, a, axes_codomain, axes_domain)
35+ end
36+ function trivial_kronecker (
37+ style:: FusionStyle , side:: Val , a:: AbstractArray ,
38+ axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
39+ axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
40+ )
41+ style_a, style_b = kroneckerfactors (style)
42+ a_a, a_b = kroneckerfactors (a)
43+ axes_codomain_a = kroneckerfactors .(axes_codomain, 1 )
44+ axes_codomain_b = kroneckerfactors .(axes_codomain, 2 )
45+ axes_domain_a = kroneckerfactors .(axes_domain, 1 )
46+ axes_domain_b = kroneckerfactors .(axes_domain, 2 )
47+ ra = trivial_axis (style_a, side, a_a, axes_codomain_a, axes_domain_a)
48+ rb = trivial_axis (style_b, side, a_b, axes_codomain_b, axes_domain_b)
49+ return cartesianrange (ra, rb)
50+ end
51+
2252function TensorAlgebra. tensor_product_axis (
23- style:: KroneckerFusion , r1:: AbstractUnitRange , r2:: AbstractUnitRange
53+ style:: KroneckerFusion , side:: Val{:codomain} ,
54+ r1:: AbstractUnitRange , r2:: AbstractUnitRange ,
55+ )
56+ return tensor_product_kronecker (style, side, r1, r2)
57+ end
58+ function TensorAlgebra. tensor_product_axis (
59+ style:: KroneckerFusion , side:: Val{:domain} ,
60+ r1:: AbstractUnitRange , r2:: AbstractUnitRange ,
61+ )
62+ return tensor_product_kronecker (style, side, r1, r2)
63+ end
64+ function tensor_product_kronecker (
65+ style:: KroneckerFusion , side:: Val ,
66+ r1:: AbstractUnitRange , r2:: AbstractUnitRange ,
2467 )
2568 style_a, style_b = kroneckerfactors (style)
2669 r1a, r1b = kroneckerfactors (r1)
2770 r2a, r2b = kroneckerfactors (r2)
28- ra = tensor_product_axis (style_a, r1a, r2a)
29- rb = tensor_product_axis (style_b, r1b, r2b)
71+ ra = tensor_product_axis (style_a, side, r1a, r2a)
72+ rb = tensor_product_axis (style_b, side, r1b, r2b)
3073 return cartesianrange (ra, rb)
3174end
3275
@@ -44,8 +87,7 @@ function TensorAlgebra.matricize(
4487end
4588
4689function unmatricize_kronecker (
47- style:: FusionStyle ,
48- m:: AbstractMatrix ,
90+ style:: FusionStyle , m:: AbstractMatrix ,
4991 axes_codomain:: Tuple{Vararg{AbstractUnitRange}} ,
5092 axes_domain:: Tuple{Vararg{AbstractUnitRange}} ,
5193 )
@@ -60,8 +102,7 @@ function unmatricize_kronecker(
60102 return a1 ⊗ a2
61103end
62104function TensorAlgebra. unmatricize (
63- style:: KroneckerFusion ,
64- m:: AbstractMatrix ,
105+ style:: KroneckerFusion , m:: AbstractMatrix ,
65106 codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
66107 domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
67108 )
0 commit comments