@@ -4,9 +4,29 @@ abstract type Algorithm end
44
55Algorithm (alg:: Algorithm ) = alg
66
7- struct Matricize <: Algorithm end
7+ struct Matricize{Style} <: Algorithm
8+ fusion_style:: Style
9+ end
10+ Matricize () = Matricize (ReshapeFusion ())
811
9- default_contract_alg () = Matricize ()
12+ function default_contract_alg (a1:: AbstractArray , labels1, a2:: AbstractArray , labels2)
13+ style1 = FusionStyle (a1)
14+ style2 = FusionStyle (a2)
15+ style1 == style2 || error (" Styles must match." )
16+ return Matricize (style1)
17+ end
18+ function default_contractadd!_alg (
19+ a_dest:: AbstractArray , labels_dest,
20+ a1:: AbstractArray , labels1,
21+ a2:: AbstractArray , labels2,
22+ α:: Number , β:: Number ,
23+ )
24+ style_dest = FusionStyle (a_dest)
25+ style1 = FusionStyle (a1)
26+ style2 = FusionStyle (a2)
27+ style_dest == style1 == style2 || error (" Styles must match." )
28+ return Matricize (style_dest)
29+ end
1030
1131# Required interface if not using
1232# matricized contraction.
@@ -29,7 +49,7 @@ function contract(
2949 labels1,
3050 a2:: AbstractArray ,
3151 labels2;
32- alg = default_contract_alg (),
52+ alg = default_contract_alg (a1, labels1, a2, labels2 ),
3353 kwargs... ,
3454 )
3555 return contract (Algorithm (alg), a1, labels1, a2, labels2; kwargs... )
@@ -48,7 +68,7 @@ function contract(
4868 labels1,
4969 a2:: AbstractArray ,
5070 labels2;
51- alg = default_contract_alg (),
71+ alg = default_contract_alg (a1, labels1, a2, labels2 ),
5272 kwargs... ,
5373 )
5474 return contract (Algorithm (alg), labels_dest, a1, labels1, a2, labels2; kwargs... )
@@ -75,7 +95,7 @@ function contractadd!(
7595 labels2,
7696 α:: Number ,
7797 β:: Number ;
78- alg = default_contract_alg ( ),
98+ alg = default_contractadd!_alg (a_dest, labels_dest, a1, labels1, a2, labels2, α, β ),
7999 kwargs... ,
80100 )
81101 contractadd! (
0 commit comments