Skip to content

Commit 903ceec

Browse files
fix in einsum codegen and added no-codegen fused op
1 parent e4db310 commit 903ceec

File tree

4 files changed

+535
-174
lines changed

4 files changed

+535
-174
lines changed

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java

Lines changed: 113 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package org.apache.sysds.runtime.einsum;
22

3+
import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction;
4+
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
5+
36
import java.util.ArrayList;
47
import java.util.Arrays;
58
import java.util.HashMap;
@@ -17,21 +20,24 @@ public class EOpNodeEinsumFuse extends EOpNode {
1720
public static final int AX_index=7;
1821
public static final int AZ_index=8;
1922
public enum EinsumRewriteType{
20-
// inputops__output 'X' = simplySumDim
21-
AB_BA_B_XB_BX_A_XA_AX__AB,
22-
AB_BA_B_XB_BX_A_XA_AX__B,
23-
AB_BA_B_XB_BX_A_XA_AX__A,
24-
AB_BA_B_XB_BX_A_XA_AX__,
25-
26-
AB_BA_B_XB_BX_A_XA_AX_AZ__Z
27-
}
28-
public enum EinsumRewriteType_v2{ // option 2 without X dims
23+
// B -> row*row, A -> row*scalar
2924
AB_BA_B_A__AB,
3025
AB_BA_B_A__B,
3126
AB_BA_B_A__A,
27+
AB_BA_B_A__,
3228

33-
AB_BA_B_A_AZ__Z
29+
// scalar from row(AB).dot(B) multiplied by row(AZ)
30+
AB_BA_B_A_AZ__Z,
31+
32+
// AC: last step is outer matrix multiplication using vector C
33+
AB_BA_B_A_AZ__BZ,
34+
AB_BA_B_A_AZ__ZB,
35+
36+
// // outer matrix multiplication using vector C and vector Z
37+
// AB_BA_B_A_AZ_AC__ZC,
38+
// AB_BA_B_A_AZ_AC__CZ,
3439
}
40+
3541
public final EinsumRewriteType einsumRewriteType;
3642
public final List<List<EOpNode>> operands;
3743

@@ -41,29 +47,17 @@ private EOpNodeEinsumFuse(Character c1, Character c2, EinsumRewriteType einsumRe
4147
this.operands = Arrays.asList(operands);
4248
}
4349

44-
public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character outChar1, Character outChar2,/*, Set<Character> simplySummableChars,*/ ArrayList<EOpNode> ret, HashMap<Character, Integer> charToOccurences){
50+
public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character outChar1, Character outChar2,/*, Set<Character> simplySummableChars,*/ ArrayList<EOpNode> ret, HashMap<Character, Integer> charToOccurences, HashMap<Character, Integer> charToSize){
4551
//precompute
4652
HashSet<String> matricesChars = new HashSet<>();
4753
HashMap<String, ArrayList<EOpNode>> charsToMatrices = new HashMap<>();
48-
HashMap<Character, Integer> charsToNumberOfOperands = new HashMap<>();
4954

5055
for (EOpNode operand1 : operands) {
5156
String k;
52-
//todo remove and use input map charToOccurences
53-
if (charsToNumberOfOperands.containsKey(operand1.c1)) {
54-
charsToNumberOfOperands.put(operand1.c1, charsToNumberOfOperands.get(operand1.c1) + 1);
55-
} else {
56-
charsToNumberOfOperands.put(operand1.c1, 1);
57-
}
5857

5958
if (operand1.c2 != null) {
6059
k = operand1.c1.toString() + operand1.c2;
6160
matricesChars.add(k);
62-
if (charsToNumberOfOperands.containsKey(operand1.c2)) {
63-
charsToNumberOfOperands.put(operand1.c2, charsToNumberOfOperands.get(operand1.c2) + 1);
64-
} else {
65-
charsToNumberOfOperands.put(operand1.c2, 1);
66-
}
6761
} else {
6862
k = operand1.c1.toString();
6963
}
@@ -82,13 +76,14 @@ public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character out
8276
ArrayList<EOpNode> BXs = new ArrayList<>();
8377
ArrayList<EOpNode> XBs = new ArrayList<>();
8478
ArrayList<EOpNode> AZs = new ArrayList<>();
79+
// ArrayList<EOpNode> ACs = new ArrayList<>();
80+
ArrayList<EOpNode> Zs = new ArrayList<>();
8581
boolean pass = false;
8682

8783
String AB = null;
8884
String BA = null;
8985
boolean doSumA=false;
9086
boolean doSumB=false;
91-
9287
for (String ABcandidate : matricesChars) {
9388
char a = ABcandidate.charAt(0);
9489
char b = ABcandidate.charAt(1);
@@ -99,9 +94,10 @@ public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character out
9994
BXs = new ArrayList<>();
10095
XBs = new ArrayList<>();
10196
AZs = new ArrayList<>();
102-
97+
Character z = null;
10398
pass=true;
104-
99+
int AZsCounter = 0;
100+
HashSet<String> AZCandidates = new HashSet<>();
105101

106102
for (String chars : charsToMatrices.keySet()) {
107103
if (chars.equals(ABcandidate) || chars.equals(BA)) {
@@ -118,111 +114,74 @@ public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character out
118114
continue;
119115
//always ok
120116
}else{
121-
if(a==chars.charAt(1) && b==chars.charAt(0)){
117+
if(a==chars.charAt(1) && b==chars.charAt(0)){ //BA
122118
// ABsCounter++;
123-
//BA
124119
continue;
125120
}
126121
if(chars.charAt(0)==a){
127-
if(charsToNumberOfOperands.get(chars.charAt(1))==1){
128-
if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) {
129-
AXs.addAll(charsToMatrices.get(chars));
130-
// AsCounter++;
131-
continue;
132-
}else{
133-
if(AZs.size()==0){
134-
AZs = charsToMatrices.get(chars);
135-
continue;
136-
}
137-
pass = false;
138-
break;
139-
}
140-
}else{
141-
//dont allow for now, in theory AZ,Z or AZ,AZ would also work, but for now do them separately
142-
pass = false;
143-
break;
144-
}
122+
//AZ
123+
AZsCounter++;
124+
AZCandidates.add(chars);
145125
}
146126
else if(chars.charAt(0)==b){
147-
if(charsToNumberOfOperands.get(chars.charAt(1))==1){
148-
if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) {
149-
BXs.addAll(charsToMatrices.get(chars));
150-
// BsCounter++;
151-
continue;
152-
}else{
153-
pass = false; // no BZ, maybe experiment later
154-
break;
155-
}
156-
}else{
157-
pass = false;
158-
break;
159-
}
127+
// BZ, todo, maybe transpose ab into ba
128+
pass = false;
129+
break;
160130
}
161131
else if(chars.charAt(1)==a){
162-
if(charsToNumberOfOperands.get(chars.charAt(0))==1){
163-
if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) {
164-
XAs.addAll(charsToMatrices.get(chars));
165-
// AsCounter++;
166-
continue;
167-
}else{
168-
pass = false;
169-
break;
170-
}
171-
}else{
132+
//ZA, maybe its small enough that it can be tranposed? but then not impactful as the bigger A, the more sense to fuse AZ?
172133
pass = false;
173134
break;
174-
}
175135
}
176136
else if(chars.charAt(1)==b){
177-
if(charsToNumberOfOperands.get(chars.charAt(0))==1){
178-
if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) {
179-
XBs.addAll(charsToMatrices.get(chars));
180-
// BsCounter++;
181-
continue;
182-
}else{
183-
pass = false;
184-
break;
185-
}
186-
}else{
187-
pass = false;
188-
break;
189-
}
137+
// ZB
138+
pass = false;
139+
break;
190140
}
191141
}
192142
}
193143
if(pass){
144+
194145
AB = ABcandidate;
195146
String A = ""+a;
196147
String B = ""+b;
197148
int ABsCounter = charsToMatrices.get(ABcandidate).size()+(charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0);
198-
int AsCounter = (charsToMatrices.containsKey(A) ? charsToMatrices.get(A).size() : 0) +AXs.size()+XAs.size();
199-
int BsCounter = (charsToMatrices.containsKey(B) ? charsToMatrices.get(B).size() : 0)+BXs.size()+XBs.size();
149+
// int AZsCounter = AZs.size();
150+
int AsCounter = (charsToMatrices.containsKey(A) ? charsToMatrices.get(A).size() : 0);
151+
int BsCounter = (charsToMatrices.containsKey(B) ? charsToMatrices.get(B).size() : 0);
200152
if(AsCounter==0 && BsCounter==0 && ABsCounter<2){
201153
pass=false;
202154
continue;
203155
}
204-
int usedAsCount = AsCounter+ABsCounter;
205156
int usedBsCount = BsCounter+ABsCounter;
206-
doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2);
207157
doSumB = charToOccurences.get(b)==usedBsCount && (outChar1 == null || b!=outChar1) && (outChar2 == null || b!=outChar2);
208-
if(AZs.size()!=0) { // invalidate AZ fusion
209-
if (outChar1 != null) {
210-
if (a == outChar1 || b == outChar1) {
211-
pass=false;
212-
continue;
213-
}
214-
}
215-
if (outChar2 != null) {
216-
if (a == outChar2 || b == outChar2) {
217-
pass=false;
218-
continue;
219-
}
158+
159+
if(AZCandidates.size()==1){
160+
// if(!doSumB){
161+
// pass=false;
162+
// continue;
163+
// }
164+
int usedAsCount = AsCounter+ABsCounter+AZsCounter;
165+
doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2);
166+
if(!doSumA){ // cant do AZ
167+
break;// just do AB,B,A ->AB / A
168+
}else {
169+
AZs = charsToMatrices.get(AZCandidates.iterator().next());
170+
break;//ok
220171
}
221-
if(!doSumA || !doSumB){
222-
pass=false;
223-
continue;
172+
} else if (AZCandidates.size()>=2) {
173+
doSumA = false;
174+
if(doSumB){
175+
pass=true;
176+
break; // can do it, it will create AB,B,A -> A, that will be consumed by some AZ later
224177
}
178+
pass=false;
179+
continue;
180+
225181
}
182+
int usedAsCount = AsCounter+ABsCounter;
183+
doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2);
184+
226185
break;
227186
}
228187
}
@@ -232,31 +191,62 @@ else if(chars.charAt(1)==b){
232191
}
233192
String B = AB.substring(1,2);
234193
String A = AB.substring(0,1);
194+
char a = A.charAt(0);
195+
char b = B.charAt(0);
235196
Character c1 = null;
236197
Character c2 = null;
237-
EinsumRewriteType t;
198+
EinsumRewriteType t = null;
238199

239-
if(AZs.size()!=0){
240-
c1=AZs.get(0).c2;
241-
t=EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX_AZ__Z;
242-
}
243-
else if(doSumA){
200+
if(!AZs.isEmpty()){
201+
// Character azC1 = AZs.get(0).c1;
202+
Character azC2 = AZs.get(0).c2;
203+
// c1 = AZs.get(0).c2;
244204
if(doSumB) {
245-
t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__;
205+
t = EinsumRewriteType.AB_BA_B_A_AZ__Z;
206+
c1 = azC2;
207+
246208
}
247-
else {
248-
t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__B;
249-
c1 = AB.charAt(1);
209+
else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) {
210+
if(LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(azC2), charToSize.get(AB.charAt(1)),false)||
211+
LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), azC2, charToSize.get(AB.charAt(1)), charToSize.get(azC2),false)) {
212+
// ideally this can be changed later by parent,depending on need
213+
if (outChar1 == azC2 && outChar2 == b) {
214+
t = EinsumRewriteType.AB_BA_B_A_AZ__ZB;
215+
c1 = azC2;
216+
c2 = b;
217+
} else if (outChar2 == azC2 && outChar1 == b) {
218+
t = EinsumRewriteType.AB_BA_B_A_AZ__BZ;
219+
c1 = b;
220+
c2 = azC2;
221+
} else {
222+
t = EinsumRewriteType.AB_BA_B_A_AZ__ZB;
223+
c1 = azC2;
224+
c2 = b;
225+
}
226+
227+
}
228+
}
229+
230+
if(charsToMatrices.containsKey(azC2.toString())) {
231+
Zs = charsToMatrices.get(azC2.toString());
250232
}
251233
}
252-
else if(doSumB){
253-
t= EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__A;
254-
c1= AB.charAt(0);
255-
}
256-
else {
257-
t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB;
258-
c1 = AB.charAt(0);
259-
c2 = AB.charAt(1);
234+
if(t==null) {
235+
if (doSumA) {
236+
if (doSumB) {
237+
t = EinsumRewriteType.AB_BA_B_A__;
238+
} else {
239+
t = EinsumRewriteType.AB_BA_B_A__B;
240+
c1 = AB.charAt(1);
241+
}
242+
} else if (doSumB) {
243+
t = EinsumRewriteType.AB_BA_B_A__A;
244+
c1 = AB.charAt(0);
245+
} else {
246+
t = EinsumRewriteType.AB_BA_B_A__AB;
247+
c1 = AB.charAt(0);
248+
c2 = AB.charAt(1);
249+
}
260250
}
261251
if(c1 != null){
262252
charToOccurences.put(c1, charToOccurences.get(c1)+1);
@@ -280,6 +270,7 @@ else if(doSumB){
280270
usedOperands.addAll(XAs);
281271
usedOperands.addAll(AXs);
282272
usedOperands.addAll(AZs);
273+
usedOperands.addAll(Zs);
283274

284275
for(EOpNode n : operands){
285276
if(!usedOperands.contains(n)){
@@ -302,7 +293,8 @@ else if(doSumB){
302293
As,
303294
XAs,
304295
AXs,
305-
AZs
296+
AZs,
297+
Zs
306298
);
307299
ret.add(e);
308300
return e;

0 commit comments

Comments
 (0)