4646
4747public class EOpNodeBinary extends EOpNode {
4848
49- public enum EBinaryOperand { // upper case: char has to remain , lower case: to be summed
49+ public enum EBinaryOperand { // upper case: char remains , lower case: summed (reduced) dimension
5050 ////// mm: //////
5151 Ba_aC , // -> BC
5252 aB_Ca , // -> CB
5353 Ba_Ca , // -> BC
5454 aB_aC , // -> BC
5555
56- ////// elementwisemult and sums //////
56+ ////// element-wise multiplications and sums //////
5757 aB_aB ,// elemwise and colsum -> B
5858 Ab_Ab , // elemwise and rowsum ->A
5959 Ab_bA , // elemwise, either colsum or rowsum -> A
@@ -169,18 +169,12 @@ public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) {
169169 }
170170
171171 @ Override
172- public String [] recursivePrintString () {
173- String [] left = this .left .recursivePrintString ();
174- String [] right = this .right .recursivePrintString ();
175- String [] res = new String [left .length + right .length +1 ];
176- res [0 ] = this .getClass ().getSimpleName ()+" (" + operand .toString ()+") " +this .toString ();
177- for (int i =0 ; i <left .length ; i ++) {
178- res [i +1 ] = (i ==0 ? "┌─ " : " " ) +left [i ];
179- }
180- for (int i =0 ; i <right .length ; i ++) {
181- res [left .length +i +1 ] = (i ==0 ? "└─ " : " " ) +right [i ];
182- }
183- return res ;
172+ public List <EOpNode > getChildren () {
173+ return List .of (this .left , this .right );
174+ }
175+ @ Override
176+ public String toString () {
177+ return this .getClass ().getSimpleName ()+" (" + operand .toString ()+") " +getOutputString ();
184178 }
185179
186180 @ Override
@@ -193,8 +187,6 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
193187
194188 MatrixBlock res ;
195189
196- if (LOG .isTraceEnabled ()) LOG .trace ("computing binary " +bin .left +"," +bin .right +"->" +bin );
197-
198190 switch (bin .operand ){
199191 case AB_AB -> {
200192 res = MatrixBlock .naryOperations (new SimpleOperator (Multiply .getMultiplyFnObject ()), new MatrixBlock []{left , right },new ScalarObject []{}, new MatrixBlock ());
@@ -212,22 +204,28 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
212204 res .getDenseBlockValues ()[0 ] = LibMatrixMult .dotProduct (left .getDenseBlockValues (), right .getDenseBlockValues (), 0 ,0 , left .getNumRows ());
213205 }
214206 case Ab_Ab -> {
215- res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__A , List .of (left , right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),null ,numThreads );
207+ res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__A , List .of (left , right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),
208+ null , numThreads );
216209 }
217210 case aB_aB -> {
218- res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_A__B , List .of (left , right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),null ,numThreads );
211+ res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_A__B , List .of (left , right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),
212+ null , numThreads );
219213 }
220214 case ab_ab -> {
221- res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__ , List .of (left , right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),null ,numThreads );
215+ res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__ , List .of (left , right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),
216+ null , numThreads );
222217 }
223218 case ab_ba -> {
224- res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__ , List .of (left ), List .of (right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),null ,numThreads );
219+ res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__ , List .of (left ), List .of (right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),
220+ null , numThreads );
225221 }
226222 case Ab_bA -> {
227- res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__A , List .of (left ), List .of (right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),null ,numThreads );
223+ res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__A , List .of (left ), List .of (right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),
224+ null , numThreads );
228225 }
229226 case aB_Ba -> {
230- res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_A__B , List .of (left ), List .of (right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),null ,numThreads );
227+ res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_A__B , List .of (left ), List .of (right ), new ArrayList <>(), new ArrayList <>(), new ArrayList <>(),
228+ null , numThreads );
231229 }
232230 case AB_BA -> {
233231 ReorgOperator transpose = new ReorgOperator (SwapIndex .getSwapIndexFnObject (), numThreads );
@@ -271,14 +269,16 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
271269 res = left .binaryOperations (new BinaryOperator (Multiply .getMultiplyFnObject ()), right );
272270 }
273271 case Ab_b -> {
274- res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__A , List .of (left ), new ArrayList <>(), List .of (right ), new ArrayList <>(), new ArrayList <>(),null ,numThreads );
272+ res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__A , List .of (left ), new ArrayList <>(), List .of (right ), new ArrayList <>(), new ArrayList <>(),
273+ null , numThreads );
275274 }
276275 case AB_A -> {
277276 ensureMatrixBlockColumnVector (right );
278277 res = left .binaryOperations (new BinaryOperator (Multiply .getMultiplyFnObject ()), right );
279278 }
280279 case aB_a -> {
281- res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_A__B , List .of (left ), new ArrayList <>(), new ArrayList <>(), List .of (right ), new ArrayList <>(),null ,numThreads );
280+ res = EOpNodeFuse .compute (EOpNodeFuse .EinsumRewriteType .AB_BA_A__B , List .of (left ), new ArrayList <>(), new ArrayList <>(), List .of (right ), new ArrayList <>(),
281+ null , numThreads );
282282 }
283283 case A_B -> {
284284 ensureMatrixBlockColumnVector (left );
@@ -427,10 +427,6 @@ else if (n1.c2 == n2.c1) {
427427 return null ; // AB_B
428428 }else {
429429 return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 )*charToSizeMap .get (n2 .c2 ), EBinaryOperand .Ba_aC , Pair .of (n1 .c1 , n2 .c2 ));
430- // if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
431- // return null; // AB_B
432- // }
433- // return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
434430 }
435431 }
436432 if (n1 .c1 == n2 .c2 ) {
0 commit comments