3030import org .apache .sysds .runtime .functionobjects .ReduceRow ;
3131import org .apache .sysds .runtime .functionobjects .SwapIndex ;
3232import org .apache .sysds .runtime .instructions .cp .DoubleObject ;
33+ import org .apache .sysds .runtime .instructions .cp .EinsumCPInstruction ;
3334import org .apache .sysds .runtime .instructions .cp .ScalarObject ;
3435import org .apache .sysds .runtime .matrix .data .LibMatrixMult ;
3536import org .apache .sysds .runtime .matrix .data .MatrixBlock ;
4849
4950public class EOpNodeBinary extends EOpNode {
5051
51-
5252 public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed
53- ////// summations: //////
54- aB_a ,// -> B
55- Ba_a , // -> B
56- Ba_aC , // mmult -> BC
57- aB_Ca ,
53+ ////// mm: //////
54+ Ba_aC , // -> BC
55+ aB_Ca , // -> CB
5856 Ba_Ca , // -> BC
59- aB_aC , // outer mult, possibly with transposing first -> BC
60- a_a ,// dot ->
57+ aB_aC , // -> BC
6158
62- ////// elementwisemult and sums, something like ij,ij->i //////
59+ ////// elementwisemult and sums //////
6360 aB_aB ,// elemwise and colsum -> B
6461 Ba_Ba , // elemwise and rowsum ->B
6562 Ba_aB , // elemwise, either colsum or rowsum -> B
6663 aB_Ba ,
64+ ab_ab ,//M-M sum all
65+ ab_ba , //M-M.T sum all
66+ aB_a ,// -> B
67+ Ba_a , // -> B
6768
6869 ////// elementwise, no summations: //////
6970 A_A ,// v-elemwise -> A
7071 AB_AB ,// M-M elemwise -> AB
7172 AB_BA , // M-M.T elemwise -> AB
7273 AB_A , // M-v colwise -> BA!?
7374 BA_A , // M-v rowwise -> BA
74- ab_ab ,//M-M sum all
75- ab_ba , //M-M.T sum all
75+
7676 ////// other //////
77+ a_a ,// dot ->
7778 A_B , // outer mult -> AB
7879 A_scalar , // v-scalar
7980 AB_scalar , // m-scalar
8081 scalar_scalar
8182 }
82- public EOpNode _left ;
83- public EOpNode _right ;
84- public EBinaryOperand _operand ;
83+ public EOpNode left ;
84+ public EOpNode right ;
85+ public EBinaryOperand operand ;
8586 private boolean transposeResult ;
86- public EOpNodeBinary (Character c1 , Character c2 , EOpNode left , EOpNode right , EBinaryOperand operand ){
87- super (c1 ,c2 );
88- this ._left = left ;
89- this ._right = right ;
90- this ._operand = operand ;
91- }
87+ public EOpNodeBinary (EOpNode left , EOpNode right , EBinaryOperand operand ){
88+ super (null ,null ,null , null );
89+ Character c1 , c2 ;
90+ Integer dim1 , dim2 ;
91+ switch (operand ){
92+ case Ba_aC -> {
93+ c1 =left .c1 ;
94+ c2 =right .c2 ;
95+ dim1 =left .dim1 ;
96+ dim2 =right .dim2 ;
97+ }
98+ case aB_Ca -> {
99+ c1 =left .c2 ;
100+ c2 =right .c1 ;
101+ dim1 =left .dim2 ;
102+ dim2 =right .dim1 ;
103+ }
104+ case Ba_Ca -> {
105+ c1 =left .c1 ;
106+ c2 =right .c1 ;
107+ dim1 =left .dim1 ;
108+ dim2 =right .dim1 ;
109+ }
110+ case aB_aC -> {
111+ c1 =left .c2 ;
112+ c2 =right .c2 ;
113+ dim1 =left .dim2 ;
114+ dim2 =right .dim2 ;
115+ }
116+ case aB_aB , aB_Ba , aB_a -> {
117+ c1 =left .c2 ;
118+ c2 =null ;
119+ dim1 =left .dim2 ;
120+ dim2 =null ;
121+ }
122+ case Ba_Ba , Ba_aB , Ba_a , A_A , A_scalar -> {
123+ c1 =left .c1 ;
124+ c2 =null ;
125+ dim1 =left .dim1 ;
126+ dim2 =null ;
127+ }
128+ case ab_ab , ab_ba , a_a , scalar_scalar -> {
129+ c1 =null ;
130+ c2 =null ;
131+ dim1 =null ;
132+ dim2 =null ;
133+ }
134+ case AB_AB , AB_BA , AB_A , BA_A , AB_scalar ->{
135+ c1 =left .c1 ;
136+ c2 =left .c2 ;
137+ dim1 =left .dim1 ;
138+ dim2 =left .dim2 ;
139+ }
140+ case A_B -> {
141+ c1 =left .c1 ;
142+ c2 =right .c1 ;
143+ dim1 =left .dim1 ;
144+ dim2 =right .dim1 ;
145+ }
146+ default -> throw new IllegalStateException ("EOpNodeBinary Unexpected type: " + operand );
147+ }
148+ // super(c1, c2, dim1, dim2); // unavailable in JDK < 22
149+ this .c1 = c1 ;
150+ this .c2 = c2 ;
151+ this .dim1 = dim1 ;
152+ this .dim2 = dim2 ;
153+ this .left = left ;
154+ this .right = right ;
155+ this .operand = operand ;
156+ }
157+
92158 public void setTransposeResult (boolean transposeResult ){
93159 this .transposeResult = transposeResult ;
94160 }
95161
96162 public static EOpNodeBinary combineMatrixMultiply (EOpNode left , EOpNode right ) {
97- if (left .c2 == right .c1 ) { return new EOpNodeBinary (left . c1 , right . c2 , left , right , EBinaryOperand .Ba_aC ); }
98- if (left .c2 == right .c2 ) { return new EOpNodeBinary (left . c1 , right . c1 , left , right , EBinaryOperand .Ba_Ca ); }
99- if (left .c1 == right .c1 ) { return new EOpNodeBinary (left . c2 , right . c2 , left , right , EBinaryOperand .aB_aC ); }
163+ if (left .c2 == right .c1 ) { return new EOpNodeBinary (left , right , EBinaryOperand .Ba_aC ); }
164+ if (left .c2 == right .c2 ) { return new EOpNodeBinary (left , right , EBinaryOperand .Ba_Ca ); }
165+ if (left .c1 == right .c1 ) { return new EOpNodeBinary (left , right , EBinaryOperand .aB_aC ); }
100166 if (left .c1 == right .c2 ) {
101- var res = new EOpNodeBinary (left . c2 , right . c1 , left , right , EBinaryOperand .aB_Ca );
167+ var res = new EOpNodeBinary (left , right , EBinaryOperand .aB_Ca );
102168 res .setTransposeResult (true );
103169 return res ;
104170 }
@@ -107,10 +173,10 @@ public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) {
107173
108174 @ Override
109175 public String [] recursivePrintString () {
110- String [] left = _left .recursivePrintString ();
111- String [] right = _right .recursivePrintString ();
176+ String [] left = this . left .recursivePrintString ();
177+ String [] right = this . right .recursivePrintString ();
112178 String [] res = new String [left .length + right .length +1 ];
113- res [0 ] = this .getClass ().getSimpleName ()+" (" +_operand .toString ()+") " +this .toString ();
179+ res [0 ] = this .getClass ().getSimpleName ()+" (" + operand .toString ()+") " +this .toString ();
114180 for (int i =0 ; i <left .length ; i ++) {
115181 res [i +1 ] = (i ==0 ? "┌─ " : " " ) +left [i ];
116182 }
@@ -123,16 +189,16 @@ public String[] recursivePrintString() {
123189 @ Override
124190 public MatrixBlock computeEOpNode (ArrayList <MatrixBlock > inputs , int numThreads , Log LOG ) {
125191 EOpNodeBinary bin = this ;
126- MatrixBlock left = _left .computeEOpNode (inputs , numThreads , LOG );
127- MatrixBlock right = _right .computeEOpNode (inputs , numThreads , LOG );
192+ MatrixBlock left = this . left .computeEOpNode (inputs , numThreads , LOG );
193+ MatrixBlock right = this . right .computeEOpNode (inputs , numThreads , LOG );
128194
129195 AggregateOperator agg = new AggregateOperator (0 , Plus .getPlusFnObject ());
130196
131197 MatrixBlock res ;
132198
133- if (LOG .isTraceEnabled ()) LOG .trace ("computing binary " +bin ._left +"," +bin ._right +"->" +bin );
199+ if (LOG .isTraceEnabled ()) LOG .trace ("computing binary " +bin .left +"," +bin .right +"->" +bin );
134200
135- switch (bin ._operand ){
201+ switch (bin .operand ){
136202 case AB_AB -> {
137203 res = MatrixBlock .naryOperations (new SimpleOperator (Multiply .getMultiplyFnObject ()), new MatrixBlock []{left , right },new ScalarObject []{}, new MatrixBlock ());
138204 }
@@ -255,7 +321,7 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
255321 return new MatrixBlock (left .get (0 ,0 )*right .get (0 ,0 ));
256322 }
257323 default -> {
258- throw new IllegalArgumentException ("Unexpected value: " + bin ._operand .toString ());
324+ throw new IllegalArgumentException ("Unexpected value: " + bin .operand .toString ());
259325 }
260326
261327 }
@@ -267,25 +333,47 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
267333 }
268334
269335 @ Override
270- public void reorderChildren (Character outChar1 , Character outChar2 ) {
271- if (this ._operand ==EBinaryOperand .aB_aC ){
272- if (this ._right .c2 == outChar1 ) {
273- var tmp = _left ;
274- _left = _right ;
275- _right = tmp ;
276- var tmp2 = c1 ;
277- c1 = c2 ;
278- c2 = tmp2 ;
336+ public EOpNode reorderChildrenAndOptimize (EOpNode parent , Character outChar1 , Character outChar2 ) {
337+ if (this .operand ==EBinaryOperand .aB_aC ){
338+ if (this .right .c2 == outChar1 ) { // result is CB so Swap aB and aC
339+ var tmpLeft = left ; left = right ; right = tmpLeft ;
340+ var tmpC1 = c1 ; c1 = c2 ; c2 = tmpC1 ;
341+ var tmpDim1 = dim1 ; dim1 = dim2 ; dim2 = tmpDim1 ;
279342 }
280- _left .reorderChildren (_left .c2 , _left .c1 );
281- // check if change happened:
282- if (_left .c2 == _right .c1 ) {
283- this ._operand = EBinaryOperand .Ba_aC ;
343+ if (EinsumCPInstruction .FUSE_OUTER_MULTIPLY && left instanceof EOpNodeFuse fuse && fuse .einsumRewriteType == EOpNodeFuse .EinsumRewriteType .AB_BA_B_A__AB &&
344+ LibMatrixMult .isSkinnyRightHandSide (left .dim1 , left .dim2 , right .dim1 , right .dim2 , true )) {
345+ fuse .operands .get (4 ).add (right );
346+ fuse .einsumRewriteType = EOpNodeFuse .EinsumRewriteType .AB_BA_B_A_AZ__BZ ;
347+ fuse .c1 = fuse .c2 ;
348+ fuse .c2 = right .c2 ;
349+ return fuse ;
350+ }
351+
352+ left = left .reorderChildrenAndOptimize (this , left .c2 , left .c1 ); // maybe can be reordered
353+ if (left .c2 == right .c1 ) { // check if change happened:
354+ this .operand = EBinaryOperand .Ba_aC ;
284355 }
285- }
356+ right = right .reorderChildrenAndOptimize (this , right .c1 , right .c2 );
357+ }else if (this .operand ==EBinaryOperand .Ba_Ca ){
358+ if (this .right .c1 == outChar1 ) { // result is CB so Swap Ba and Ca
359+ var tmpLeft = left ; left = right ; right = tmpLeft ;
360+ var tmpC1 = c1 ; c1 = c2 ; c2 = tmpC1 ;
361+ var tmpDim1 = dim1 ; dim1 = dim2 ; dim2 = tmpDim1 ;
362+ }
363+
364+ right = right .reorderChildrenAndOptimize (this , right .c2 , right .c1 ); // maybe can be reordered
365+ if (left .c2 == right .c1 ) { // check if change happened:
366+ this .operand = EBinaryOperand .Ba_aC ;
367+ }
368+ left = left .reorderChildrenAndOptimize (this , left .c1 , left .c2 );
369+ }else {
370+ left = left .reorderChildrenAndOptimize (this , left .c1 , left .c2 ); // just recurse
371+ right = right .reorderChildrenAndOptimize (this , right .c1 , right .c2 );
372+ }
373+ return this ;
286374 }
287375
288- // used in old method
376+ // used in the old approach
289377 public static Triple <Integer , EBinaryOperand , Pair <Character , Character >> TryCombineAndCost (EOpNode n1 , EOpNode n2 , HashMap <Character , Integer > charToSizeMap , HashMap <Character , Integer > charToOccurences , Character outChar1 , Character outChar2 ){
290378 Predicate <Character > cannotBeSummed = (c ) ->
291379 c == outChar1 || c == outChar2 || charToOccurences .get (c ) > 2 ;
@@ -388,7 +476,7 @@ else if (n1.c2 == n2.c2) {
388476 return Triple .of (charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 ) +(charToSizeMap .get (n1 .c1 )*charToSizeMap .get (n1 .c2 )*charToSizeMap .get (n2 .c1 )), EBinaryOperand .Ba_Ca , Pair .of (n1 .c1 , n2 .c1 )); // or n2.c1, n1.c1
389477 }
390478 }
391- else { // something like ab,cd
479+ else { // something like AB,CD
392480 return null ;
393481 }
394482 }
0 commit comments