11package org .apache .sysds .runtime .einsum ;
22
3+ import org .apache .sysds .runtime .instructions .cp .EinsumCPInstruction ;
4+ import org .apache .sysds .runtime .matrix .data .LibMatrixMult ;
5+
36import java .util .ArrayList ;
47import java .util .Arrays ;
58import java .util .HashMap ;
@@ -17,21 +20,24 @@ public class EOpNodeEinsumFuse extends EOpNode {
1720 public static final int AX_index =7 ;
1821 public static final int AZ_index =8 ;
1922 public enum EinsumRewriteType {
20- // inputops__output 'X' = simplySumDim
21- AB_BA_B_XB_BX_A_XA_AX__AB ,
22- AB_BA_B_XB_BX_A_XA_AX__B ,
23- AB_BA_B_XB_BX_A_XA_AX__A ,
24- AB_BA_B_XB_BX_A_XA_AX__ ,
25-
26- AB_BA_B_XB_BX_A_XA_AX_AZ__Z
27- }
28- public enum EinsumRewriteType_v2 { // option 2 without X dims
23+ // B -> row*row, A -> row*scalar
2924 AB_BA_B_A__AB ,
3025 AB_BA_B_A__B ,
3126 AB_BA_B_A__A ,
27+ AB_BA_B_A__ ,
3228
33- AB_BA_B_A_AZ__Z
29+ // scalar from row(AB).dot(B) multiplied by row(AZ)
30+ AB_BA_B_A_AZ__Z ,
31+
32+ // AC: last step is outer matrix multiplication using vector C
33+ AB_BA_B_A_AZ__BZ ,
34+ AB_BA_B_A_AZ__ZB ,
35+
36+ // // outer matrix multiplication using vector C and vector Z
37+ // AB_BA_B_A_AZ_AC__ZC,
38+ // AB_BA_B_A_AZ_AC__CZ,
3439 }
40+
3541 public final EinsumRewriteType einsumRewriteType ;
3642 public final List <List <EOpNode >> operands ;
3743
@@ -41,29 +47,17 @@ private EOpNodeEinsumFuse(Character c1, Character c2, EinsumRewriteType einsumRe
4147 this .operands = Arrays .asList (operands );
4248 }
4349
44- public static EOpNodeEinsumFuse match (ArrayList <EOpNode > operands , Character outChar1 , Character outChar2 ,/*, Set<Character> simplySummableChars,*/ ArrayList <EOpNode > ret , HashMap <Character , Integer > charToOccurences ){
50+ public static EOpNodeEinsumFuse match (ArrayList <EOpNode > operands , Character outChar1 , Character outChar2 ,/*, Set<Character> simplySummableChars,*/ ArrayList <EOpNode > ret , HashMap <Character , Integer > charToOccurences , HashMap < Character , Integer > charToSize ){
4551 //precompute
4652 HashSet <String > matricesChars = new HashSet <>();
4753 HashMap <String , ArrayList <EOpNode >> charsToMatrices = new HashMap <>();
48- HashMap <Character , Integer > charsToNumberOfOperands = new HashMap <>();
4954
5055 for (EOpNode operand1 : operands ) {
5156 String k ;
52- //todo remove and use input map charToOccurences
53- if (charsToNumberOfOperands .containsKey (operand1 .c1 )) {
54- charsToNumberOfOperands .put (operand1 .c1 , charsToNumberOfOperands .get (operand1 .c1 ) + 1 );
55- } else {
56- charsToNumberOfOperands .put (operand1 .c1 , 1 );
57- }
5857
5958 if (operand1 .c2 != null ) {
6059 k = operand1 .c1 .toString () + operand1 .c2 ;
6160 matricesChars .add (k );
62- if (charsToNumberOfOperands .containsKey (operand1 .c2 )) {
63- charsToNumberOfOperands .put (operand1 .c2 , charsToNumberOfOperands .get (operand1 .c2 ) + 1 );
64- } else {
65- charsToNumberOfOperands .put (operand1 .c2 , 1 );
66- }
6761 } else {
6862 k = operand1 .c1 .toString ();
6963 }
@@ -82,13 +76,14 @@ public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character out
8276 ArrayList <EOpNode > BXs = new ArrayList <>();
8377 ArrayList <EOpNode > XBs = new ArrayList <>();
8478 ArrayList <EOpNode > AZs = new ArrayList <>();
79+ // ArrayList<EOpNode> ACs = new ArrayList<>();
80+ ArrayList <EOpNode > Zs = new ArrayList <>();
8581 boolean pass = false ;
8682
8783 String AB = null ;
8884 String BA = null ;
8985 boolean doSumA =false ;
9086 boolean doSumB =false ;
91-
9287 for (String ABcandidate : matricesChars ) {
9388 char a = ABcandidate .charAt (0 );
9489 char b = ABcandidate .charAt (1 );
@@ -99,9 +94,10 @@ public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character out
9994 BXs = new ArrayList <>();
10095 XBs = new ArrayList <>();
10196 AZs = new ArrayList <>();
102-
97+ Character z = null ;
10398 pass =true ;
104-
99+ int AZsCounter = 0 ;
100+ HashSet <String > AZCandidates = new HashSet <>();
105101
106102 for (String chars : charsToMatrices .keySet ()) {
107103 if (chars .equals (ABcandidate ) || chars .equals (BA )) {
@@ -118,111 +114,74 @@ public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character out
118114 continue ;
119115 //always ok
120116 }else {
121- if (a ==chars .charAt (1 ) && b ==chars .charAt (0 )){
117+ if (a ==chars .charAt (1 ) && b ==chars .charAt (0 )){ //BA
122118// ABsCounter++;
123- //BA
124119 continue ;
125120 }
126121 if (chars .charAt (0 )==a ){
127- if (charsToNumberOfOperands .get (chars .charAt (1 ))==1 ){
128- if (chars .charAt (1 )!= outChar1 && chars .charAt (1 ) != outChar2 ) {
129- AXs .addAll (charsToMatrices .get (chars ));
130- // AsCounter++;
131- continue ;
132- }else {
133- if (AZs .size ()==0 ){
134- AZs = charsToMatrices .get (chars );
135- continue ;
136- }
137- pass = false ;
138- break ;
139- }
140- }else {
141- //dont allow for now, in theory AZ,Z or AZ,AZ would also work, but for now do them separately
142- pass = false ;
143- break ;
144- }
122+ //AZ
123+ AZsCounter ++;
124+ AZCandidates .add (chars );
145125 }
146126 else if (chars .charAt (0 )==b ){
147- if (charsToNumberOfOperands .get (chars .charAt (1 ))==1 ){
148- if (chars .charAt (1 )!= outChar1 && chars .charAt (1 ) != outChar2 ) {
149- BXs .addAll (charsToMatrices .get (chars ));
150- // BsCounter++;
151- continue ;
152- }else {
153- pass = false ; // no BZ, maybe experiment later
154- break ;
155- }
156- }else {
157- pass = false ;
158- break ;
159- }
127+ // BZ, todo, maybe transpose ab into ba
128+ pass = false ;
129+ break ;
160130 }
161131 else if (chars .charAt (1 )==a ){
162- if (charsToNumberOfOperands .get (chars .charAt (0 ))==1 ){
163- if (chars .charAt (0 )!= outChar1 && chars .charAt (0 ) != outChar2 ) {
164- XAs .addAll (charsToMatrices .get (chars ));
165- // AsCounter++;
166- continue ;
167- }else {
168- pass = false ;
169- break ;
170- }
171- }else {
132+ //ZA, maybe it's small enough that it can be transposed? but then not impactful, as the bigger A is, the more sense it makes to fuse AZ?
172133 pass = false ;
173134 break ;
174- }
175135 }
176136 else if (chars .charAt (1 )==b ){
177- if (charsToNumberOfOperands .get (chars .charAt (0 ))==1 ){
178- if (chars .charAt (0 )!= outChar1 && chars .charAt (0 ) != outChar2 ) {
179- XBs .addAll (charsToMatrices .get (chars ));
180- // BsCounter++;
181- continue ;
182- }else {
183- pass = false ;
184- break ;
185- }
186- }else {
187- pass = false ;
188- break ;
189- }
137+ // ZB
138+ pass = false ;
139+ break ;
190140 }
191141 }
192142 }
193143 if (pass ){
144+
194145 AB = ABcandidate ;
195146 String A = "" +a ;
196147 String B = "" +b ;
197148 int ABsCounter = charsToMatrices .get (ABcandidate ).size ()+(charsToMatrices .containsKey (BA ) ? charsToMatrices .get (BA ).size () : 0 );
198- int AsCounter = (charsToMatrices .containsKey (A ) ? charsToMatrices .get (A ).size () : 0 ) +AXs .size ()+XAs .size ();
199- int BsCounter = (charsToMatrices .containsKey (B ) ? charsToMatrices .get (B ).size () : 0 )+BXs .size ()+XBs .size ();
149+ // int AZsCounter = AZs.size();
150+ int AsCounter = (charsToMatrices .containsKey (A ) ? charsToMatrices .get (A ).size () : 0 );
151+ int BsCounter = (charsToMatrices .containsKey (B ) ? charsToMatrices .get (B ).size () : 0 );
200152 if (AsCounter ==0 && BsCounter ==0 && ABsCounter <2 ){
201153 pass =false ;
202154 continue ;
203155 }
204- int usedAsCount = AsCounter +ABsCounter ;
205156 int usedBsCount = BsCounter +ABsCounter ;
206- doSumA = charToOccurences .get (a )==usedAsCount && (outChar1 == null || a !=outChar1 ) && (outChar2 == null || a !=outChar2 );
207157 doSumB = charToOccurences .get (b )==usedBsCount && (outChar1 == null || b !=outChar1 ) && (outChar2 == null || b !=outChar2 );
208- if (AZs .size ()!=0 ) { // invalidate AZ fusion
209- if (outChar1 != null ) {
210- if (a == outChar1 || b == outChar1 ) {
211- pass =false ;
212- continue ;
213- }
214- }
215- if (outChar2 != null ) {
216- if (a == outChar2 || b == outChar2 ) {
217- pass =false ;
218- continue ;
219- }
158+
159+ if (AZCandidates .size ()==1 ){
160+ // if(!doSumB){
161+ // pass=false;
162+ // continue;
163+ // }
164+ int usedAsCount = AsCounter +ABsCounter +AZsCounter ;
165+ doSumA = charToOccurences .get (a )==usedAsCount && (outChar1 == null || a !=outChar1 ) && (outChar2 == null || a !=outChar2 );
166+ if (!doSumA ){ // cant do AZ
167+ break ;// just do AB,B,A ->AB / A
168+ }else {
169+ AZs = charsToMatrices .get (AZCandidates .iterator ().next ());
170+ break ;//ok
220171 }
221- if (!doSumA || !doSumB ){
222- pass =false ;
223- continue ;
172+ } else if (AZCandidates .size ()>=2 ) {
173+ doSumA = false ;
174+ if (doSumB ){
175+ pass =true ;
176+ break ; // can do it, it will create AB,B,A -> A, that will be consumed by some AZ later
224177 }
178+ pass =false ;
179+ continue ;
180+
225181 }
182+ int usedAsCount = AsCounter +ABsCounter ;
183+ doSumA = charToOccurences .get (a )==usedAsCount && (outChar1 == null || a !=outChar1 ) && (outChar2 == null || a !=outChar2 );
184+
226185 break ;
227186 }
228187 }
@@ -232,31 +191,62 @@ else if(chars.charAt(1)==b){
232191 }
233192 String B = AB .substring (1 ,2 );
234193 String A = AB .substring (0 ,1 );
194+ char a = A .charAt (0 );
195+ char b = B .charAt (0 );
235196 Character c1 = null ;
236197 Character c2 = null ;
237- EinsumRewriteType t ;
198+ EinsumRewriteType t = null ;
238199
239- if (AZs .size ()!=0 ){
240- c1 =AZs .get (0 ).c2 ;
241- t =EinsumRewriteType .AB_BA_B_XB_BX_A_XA_AX_AZ__Z ;
242- }
243- else if (doSumA ){
200+ if (!AZs .isEmpty ()){
201+ // Character azC1 = AZs.get(0).c1;
202+ Character azC2 = AZs .get (0 ).c2 ;
203+ // c1 = AZs.get(0).c2;
244204 if (doSumB ) {
245- t = EinsumRewriteType .AB_BA_B_XB_BX_A_XA_AX__ ;
205+ t = EinsumRewriteType .AB_BA_B_A_AZ__Z ;
206+ c1 = azC2 ;
207+
246208 }
247- else {
248- t = EinsumRewriteType .AB_BA_B_XB_BX_A_XA_AX__B ;
249- c1 = AB .charAt (1 );
209+ else if (EinsumCPInstruction .FUSE_OUTER_MULTIPLY ) {
210+ if (LibMatrixMult .isSkinnyRightHandSide (charToSize .get (AB .charAt (0 )), charToSize .get (AB .charAt (1 )), charToSize .get (azC2 ), charToSize .get (AB .charAt (1 )),false )||
211+ LibMatrixMult .isSkinnyRightHandSide (charToSize .get (AB .charAt (0 )), azC2 , charToSize .get (AB .charAt (1 )), charToSize .get (azC2 ),false )) {
212+ // ideally this can be changed later by parent,depending on need
213+ if (outChar1 == azC2 && outChar2 == b ) {
214+ t = EinsumRewriteType .AB_BA_B_A_AZ__ZB ;
215+ c1 = azC2 ;
216+ c2 = b ;
217+ } else if (outChar2 == azC2 && outChar1 == b ) {
218+ t = EinsumRewriteType .AB_BA_B_A_AZ__BZ ;
219+ c1 = b ;
220+ c2 = azC2 ;
221+ } else {
222+ t = EinsumRewriteType .AB_BA_B_A_AZ__ZB ;
223+ c1 = azC2 ;
224+ c2 = b ;
225+ }
226+
227+ }
228+ }
229+
230+ if (charsToMatrices .containsKey (azC2 .toString ())) {
231+ Zs = charsToMatrices .get (azC2 .toString ());
250232 }
251233 }
252- else if (doSumB ){
253- t = EinsumRewriteType .AB_BA_B_XB_BX_A_XA_AX__A ;
254- c1 = AB .charAt (0 );
255- }
256- else {
257- t = EinsumRewriteType .AB_BA_B_XB_BX_A_XA_AX__AB ;
258- c1 = AB .charAt (0 );
259- c2 = AB .charAt (1 );
234+ if (t ==null ) {
235+ if (doSumA ) {
236+ if (doSumB ) {
237+ t = EinsumRewriteType .AB_BA_B_A__ ;
238+ } else {
239+ t = EinsumRewriteType .AB_BA_B_A__B ;
240+ c1 = AB .charAt (1 );
241+ }
242+ } else if (doSumB ) {
243+ t = EinsumRewriteType .AB_BA_B_A__A ;
244+ c1 = AB .charAt (0 );
245+ } else {
246+ t = EinsumRewriteType .AB_BA_B_A__AB ;
247+ c1 = AB .charAt (0 );
248+ c2 = AB .charAt (1 );
249+ }
260250 }
261251 if (c1 != null ){
262252 charToOccurences .put (c1 , charToOccurences .get (c1 )+1 );
@@ -280,6 +270,7 @@ else if(doSumB){
280270 usedOperands .addAll (XAs );
281271 usedOperands .addAll (AXs );
282272 usedOperands .addAll (AZs );
273+ usedOperands .addAll (Zs );
283274
284275 for (EOpNode n : operands ){
285276 if (!usedOperands .contains (n )){
@@ -302,7 +293,8 @@ else if(doSumB){
302293 As ,
303294 XAs ,
304295 AXs ,
305- AZs
296+ AZs ,
297+ Zs
306298 );
307299 ret .add (e );
308300 return e ;
0 commit comments