4343import org .apache .sysds .utils .Explain ;
4444
4545import java .util .*;
46- import java .util .function .Predicate ;
4746import java .util .stream .Collectors ;
4847
4948import static org .apache .sysds .api .DMLScript .EXPLAIN ;
5049import static org .apache .sysds .hops .rewrite .RewriteMatrixMultChainOptimization .mmChainDP ;
5150
5251public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
5352 public static final boolean FORCE_CELL_TPL = false ;
54- public static final boolean FUSED = true ;
53+ // public static final boolean FUSED = true;
5554 public static final boolean FUSE_OUTER_MULTIPLY = true ;
5655
5756
@@ -132,10 +131,10 @@ public void processInstruction(ExecutionContext ec) {
132131 ArrayList <MatrixBlock > remainingMatrices ;
133132
134133 if (!FORCE_CELL_TPL ) {
135- if (true ){
134+ if (true ){ // new way: search for fusions and matrix-multiplications chain in a loop
136135 plan = generatePlanFusionAndMM (eOpNodes , eOpNodesScalars , einc .charToDimensionSize , characterToOccurences , einc .outChar1 , einc .outChar2 );
137136 }else { // old way: try to do fusion first and then rest in binary fashion cost based
138- if (FUSED ) {
137+ if (true /* FUSED*/ ) {
139138 ret = new ArrayList <>();
140139 EOpNodeFuse fuse = EOpNodeFuse .match (eOpNodes , einc .outChar1 , einc .outChar2 ,
141140 einc .charToDimensionSize , characterToOccurences , ret );
@@ -152,10 +151,9 @@ public void processInstruction(ExecutionContext ec) {
152151 eOpNodes = ret ;
153152 }
154153 }
155-
156154 }
157155
158- Pair <Integer , List <EOpNode >> costAndPlan = generatePlanBinaryCostBased (0 , eOpNodes , einc .charToDimensionSize , characterToOccurences ,
156+ Pair <Integer , List <EOpNode >> costAndPlan = generateBinaryPlanCostBased (0 , eOpNodes , einc .charToDimensionSize , characterToOccurences ,
159157 einc .outChar1 , einc .outChar2 );
160158 plan = costAndPlan .getRight ();
161159 }
@@ -191,6 +189,7 @@ public void processInstruction(ExecutionContext ec) {
191189 plan .set (0 , new EOpNodeBinary (plan .get (0 ).c1 , plan .get (1 ).c1 , plan .get (0 ), plan .get (1 ), EBinaryOperand .A_B ));
192190 if (plan .get (0 ).c1 == einc .outChar2 && plan .get (1 ).c1 == einc .outChar1 )
193191 plan .set (0 , new EOpNodeBinary (plan .get (1 ).c1 , plan .get (0 ).c1 , plan .get (1 ), plan .get (0 ), EBinaryOperand .A_B ));
192+ plan .remove (1 );
194193 }
195194 if (EXPLAIN != Explain .ExplainType .NONE )
196195 System .out .println ("Einsum plan:" );
@@ -224,6 +223,7 @@ else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){
224223 }
225224 }else if (einc .outChar1 != null ){
226225 if (resNode .c1 == einc .outChar1 && resNode .c2 == null ){
226+ ensureMatrixBlockColumnVector (remainingMatrices .get (0 ));
227227 ec .setMatrixOutput (output .getName (), remainingMatrices .get (0 ));
228228 }else {
229229 if (LOG .isTraceEnabled ()) LOG .trace ("Einsum expected: " +resultString + ", got: " +resNode .c1 +resNode .c2 );
@@ -255,6 +255,9 @@ else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){
255255
256256 MatrixBlock res = computeCellSummation (mbs , chars , resultString , einc .charToDimensionSize , summingChars , einc .outRows , einc .outCols );
257257
258+ if (einc .outChar2 == null )
259+ ensureMatrixBlockColumnVector (res );
260+
258261 if (einc .outRows == 1 && einc .outCols == 1 )
259262 ec .setScalarOutput (output .getName (), new DoubleObject (res .get (0 , 0 )));
260263 else ec .setMatrixOutput (output .getName (), res );
@@ -422,7 +425,6 @@ private static List<EOpNode> generatePlanFusionAndMM(ArrayList<EOpNode> eOpNodes
422425 lastNumOfOperands = eOpNodes .size ();
423426
424427 EOpNodeFuse fuse = null ;
425-
426428 do {
427429 ret = new ArrayList <>();
428430 fuse = EOpNodeFuse .match (eOpNodes , outChar1 , outChar2 , charToSizeMap , charToOccurences , ret );
@@ -445,7 +447,6 @@ private static List<EOpNode> generatePlanFusionAndMM(ArrayList<EOpNode> eOpNodes
445447 ret .add (bin );
446448 }
447449 eOpNodes = ret ;
448-
449450 }
450451
451452 return eOpNodes ;
@@ -596,15 +597,15 @@ private static ArrayList<List<EOpNode>> findMatrixMultiplicationChains(ArrayList
596597 }
597598
598599 // old way
599- private Pair <Integer , List <EOpNode >> generatePlanBinaryCostBased (int cost , ArrayList <EOpNode > operands , HashMap <Character , Integer > charToSizeMap , HashMap <Character , Integer > charToOccurences , Character outChar1 , Character outChar2 ) {
600+ private Pair <Integer , List <EOpNode >> generateBinaryPlanCostBased (int cost , ArrayList <EOpNode > operands , HashMap <Character , Integer > charToSizeMap , HashMap <Character , Integer > charToOccurences , Character outChar1 , Character outChar2 ) {
600601 Integer minCost = cost ;
601602 List <EOpNode > minNodes = operands ;
602603
603604 if (operands .size () == 2 ){
604605 boolean swap = (operands .get (0 ).c2 == null && operands .get (1 ).c2 != null ) || operands .get (0 ).c1 == null ;
605606 EOpNode n1 = operands .get (!swap ? 0 : 1 );
606607 EOpNode n2 = operands .get (!swap ? 1 : 0 );
607- Triple <Integer , EBinaryOperand , Pair <Character , Character >> t = TryCombineAndCost (n1 , n2 , charToSizeMap , charToOccurences , outChar1 , outChar2 );
608+ Triple <Integer , EBinaryOperand , Pair <Character , Character >> t = EOpNodeBinary . TryCombineAndCost (n1 , n2 , charToSizeMap , charToOccurences , outChar1 , outChar2 );
608609 if (t != null ) {
609610 EOpNodeBinary newNode = new EOpNodeBinary (t .getRight ().getLeft (), t .getRight ().getRight (), n1 , n2 , t .getMiddle ());
610611 int thisCost = cost + t .getLeft ();
@@ -625,7 +626,7 @@ else if (operands.size() == 1){
625626 EOpNode n2 = operands .get (!swap ? j : i );
626627
627628
628- Triple <Integer , EBinaryOperand , Pair <Character , Character >> t = TryCombineAndCost (n1 , n2 , charToSizeMap , charToOccurences , outChar1 , outChar2 );
629+ Triple <Integer , EBinaryOperand , Pair <Character , Character >> t = EOpNodeBinary . TryCombineAndCost (n1 , n2 , charToSizeMap , charToOccurences , outChar1 , outChar2 );
629630 if (t != null ){
630631 EOpNodeBinary newNode = new EOpNodeBinary (t .getRight ().getLeft (), t .getRight ().getRight (), n1 , n2 , t .getMiddle ());
631632 int thisCost = cost + t .getLeft ();
@@ -644,7 +645,7 @@ else if (operands.size() == 1){
644645 }
645646 newOperands .add (newNode );
646647
647- Pair <Integer , List <EOpNode >> furtherPlan = generatePlanBinaryCostBased (thisCost , newOperands ,charToSizeMap , charToOccurences , outChar1 , outChar2 );
648+ Pair <Integer , List <EOpNode >> furtherPlan = generateBinaryPlanCostBased (thisCost , newOperands ,charToSizeMap , charToOccurences , outChar1 , outChar2 );
648649 if (furtherPlan .getRight ().size () < (minNodes .size ()) || furtherPlan .getLeft () < minCost ){
649650 minCost = furtherPlan .getLeft ();
650651 minNodes = furtherPlan .getRight ();
@@ -663,114 +664,6 @@ else if (operands.size() == 1){
663664 return Pair .of (minCost , minNodes );
664665 }
665666
666- private static Triple <Integer , EBinaryOperand , Pair <Character , Character >> TryCombineAndCost (EOpNode n1 , EOpNode n2 , HashMap <Character , Integer > charToSizeMap , HashMap <Character , Integer > charToOccurences , Character outChar1 , Character outChar2 ){
667- Predicate <Character > cannotBeSummed = (c ) ->
668- c == outChar1 || c == outChar2 || charToOccurences .get (c ) > 2 ;
669-
670- if (n1 .c1 == null ) {
671- // n2.c1 also has to be null
672- return Triple .of (1 , EBinaryOperand .scalar_scalar , Pair .of (null , null ));
673- }
674-
675- if (n2 .c1 == null ) {
676- if (n1 .c2 == null )
677- return Triple .of (charToSizeMap .get (n1 .c1 ), EBinaryOperand .A_scalar , Pair .of (n1 .c1 , null ));
678- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ), EBinaryOperand .AB_scalar , Pair .of (n1 .c1 , n1 .c2 ));
679- }
680-
681- if (n1 .c1 == n2 .c1 ){
682- if (n1 .c2 != null ){
683- if ( n1 .c2 == n2 .c2 ){
684- if ( cannotBeSummed .test (n1 .c1 )){
685- if (cannotBeSummed .test (n1 .c2 )){
686- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ), EBinaryOperand .AB_AB , Pair .of (n1 .c1 , n1 .c2 ));
687- }
688- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ), EBinaryOperand .Ba_Ba , Pair .of (n1 .c1 , null ));
689- }
690-
691- if (cannotBeSummed .test (n1 .c2 )){
692- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ), EBinaryOperand .aB_aB , Pair .of (n1 .c2 , null ));
693- }
694-
695- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ), EBinaryOperand .ab_ab , Pair .of (null , null ));
696-
697- }
698-
699- else if (n2 .c2 == null ){
700- if (cannotBeSummed .test (n1 .c1 )){
701- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 )*2 , EBinaryOperand .AB_A , Pair .of (n1 .c1 , n1 .c2 ));
702- }
703- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 )*2 , EBinaryOperand .aB_a , Pair .of (n1 .c2 , null )); // in theory (null, n1.c2)
704- }
705- else if (n1 .c1 ==outChar1 || n1 .c1 ==outChar2 || charToOccurences .get (n1 .c1 ) > 2 ){
706- return null ;// AB,AC
707- }
708- else {
709- return Triple .of ((charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ))+(charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 )*charToSizeMap .get (n2 .c2 )), EBinaryOperand .aB_aC , Pair .of (n1 .c2 , n2 .c2 )); // or n2.c2, n1.c2
710- }
711- }else { // n1.c2 = null -> c2.c2 = null
712- if (n1 .c1 ==outChar1 || n1 .c1 ==outChar2 || charToOccurences .get (n1 .c1 ) > 2 ){
713- return Triple .of (charToSizeMap .get (n1 .c1 ), EBinaryOperand .A_A , Pair .of (n1 .c1 , null ));
714- }
715- return Triple .of (charToSizeMap .get (n1 .c1 ), EBinaryOperand .a_a , Pair .of (null , null ));
716- }
717-
718-
719- }else { // n1.c1 != n2.c1
720- if (n1 .c2 == null ) {
721- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n2 .c1 ), EBinaryOperand .A_B , Pair .of (n1 .c1 , n2 .c1 ));
722- }
723- else if (n2 .c2 == null ) { // ab,c
724- if (n1 .c2 == n2 .c1 ) {
725- if (cannotBeSummed .test (n1 .c2 )){
726- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n2 .c1 ), EBinaryOperand .BA_A , Pair .of (n1 .c1 , n1 .c2 ));
727- }
728- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n2 .c1 ), EBinaryOperand .Ba_a , Pair .of (n1 .c1 , null ));
729- }
730- return null ; // AB,C
731- }
732- else if (n1 .c2 == n2 .c1 ) {
733- if (n1 .c1 == n2 .c2 ){ // ab,ba
734- if (cannotBeSummed .test (n1 .c1 )){
735- if (cannotBeSummed .test (n1 .c2 )){
736- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ), EBinaryOperand .AB_BA , Pair .of (n1 .c1 , n1 .c2 ));
737- }
738- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ), EBinaryOperand .Ba_aB , Pair .of (n1 .c1 , null ));
739- }
740- if (cannotBeSummed .test (n1 .c2 )){
741- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ), EBinaryOperand .aB_Ba , Pair .of (n1 .c2 , null ));
742- }
743- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ), EBinaryOperand .ab_ba , Pair .of (null , null ));
744- }
745- if (cannotBeSummed .test (n1 .c2 )){
746- return null ; // AB_B
747- }else {
748- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 )*charToSizeMap .get (n2 .c2 ), EBinaryOperand .Ba_aC , Pair .of (n1 .c1 , n2 .c2 ));
749- // if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
750- // return null; // AB_B
751- // }
752- // return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
753- }
754- }
755- if (n1 .c1 == n2 .c2 ) {
756- if (cannotBeSummed .test (n1 .c1 )){
757- return null ; // AB_B
758- }
759- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 )*charToSizeMap .get (n2 .c1 ), EBinaryOperand .aB_Ca , Pair .of (n2 .c1 , n1 .c2 )); // * its just reorder of mmult
760- }
761- else if (n1 .c2 == n2 .c2 ) {
762- if (n1 .c2 ==outChar1 || n1 .c2 ==outChar2 || charToOccurences .get (n1 .c2 ) > 2 ){
763- return null ; // BA_CA
764- }else {
765- return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ) +(charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 )*charToSizeMap .get (n2 .c1 )), EBinaryOperand .Ba_Ca , Pair .of (n1 .c1 , n2 .c1 )); // or n2.c1, n1.c1
766- }
767- }
768- else { // we have something like ab,cd
769- return null ;
770- }
771- }
772- }
773-
774667 private ArrayList <MatrixBlock > executePlan (List <EOpNode > plan , ArrayList <MatrixBlock > inputs ) {
775668 ArrayList <MatrixBlock > res = new ArrayList <>(plan .size ());
776669 for (EOpNode p : plan ){
0 commit comments