Skip to content

Commit 24535dd

Browse files
saving work before my ssd dies
1 parent 5dfa26f commit 24535dd

File tree

7 files changed

+940
-78
lines changed

7 files changed

+940
-78
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package org.apache.sysds.runtime.einsum;
2+
3+
/**
 * Base class for einsum operator nodes.
 *
 * <p>A node carries up to two index characters: {@link #c1} and the optional {@link #c2}.
 */
public abstract class EOpNode {
	/** First index character; {@link #toString()} treats {@code null} as "no indices". */
	public Character c1;
	/** Second index character; nullable. */
	public Character c2;

	/**
	 * Creates a node labelled with the given index characters.
	 *
	 * @param c1 first index character
	 * @param c2 second index character, may be {@code null}
	 */
	public EOpNode(Character c1, Character c2) {
		this.c1 = c1;
		this.c2 = c2;
	}

	/**
	 * Renders the index characters of this node.
	 *
	 * @return {@code "-"} when {@code c1} is {@code null}; otherwise {@code c1} followed by
	 *         {@code c2} when {@code c2} is present
	 */
	@Override
	public String toString() {
		if (c1 == null) {
			return "-";
		}
		String first = c1.toString();
		return (c2 == null) ? first : first + c2.toString();
	}
}
19+
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package org.apache.sysds.runtime.einsum;
2+
3+
public class EOpNodeBinary extends EOpNode {
4+
public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed
5+
////// summations: //////
6+
aB_a,// -> B
7+
Ba_a, // -> B
8+
Ba_aC, // mmult -> BC
9+
aB_Ca,
10+
Ba_Ca, // -> BC
11+
aB_aC, // outer mult, possibly with transposing first -> BC
12+
a_a,// dot ->
13+
14+
////// elementwisemult and sums, something like ij,ij->i //////
15+
aB_aB,// elemwise and colsum -> B
16+
Ba_Ba, // elemwise and rowsum ->B
17+
Ba_aB, // elemwise, either colsum or rowsum -> B
18+
aB_Ba,
19+
20+
////// elementwise, no summations: //////
21+
A_A,// v-elemwise -> A
22+
AB_AB,// M-M elemwise -> AB
23+
AB_BA, // M-M.T elemwise -> AB
24+
AB_A, // M-v colwise -> BA!?
25+
BA_A, // M-v rowwise -> BA
26+
ab_ab,//M-M sum all
27+
ab_ba, //M-M.T sum all
28+
////// other //////
29+
A_B, // outer mult -> AB
30+
A_scalar, // v-scalar
31+
AB_scalar, // m-scalar
32+
scalar_scalar
33+
}
34+
public EOpNode left;
35+
public EOpNode right;
36+
public EBinaryOperand operand;
37+
public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){
38+
super(c1,c2);
39+
this.left = left;
40+
this.right = right;
41+
this.operand = operand;
42+
}
43+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package org.apache.sysds.runtime.einsum;
2+
3+
public class EOpNodeData extends EOpNode {
4+
public int matrixIdx;
5+
public EOpNodeData(Character c1, Character c2, int matrixIdx){
6+
super(c1,c2);
7+
this.matrixIdx = matrixIdx;
8+
}
9+
}
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
package org.apache.sysds.runtime.einsum;
2+
3+
import java.util.ArrayList;
4+
import java.util.Arrays;
5+
import java.util.HashMap;
6+
import java.util.HashSet;
7+
import java.util.List;
8+
9+
public class EOpNodeEinsumFuse extends EOpNode {
10+
public static final int AB_index=0;
11+
public static final int BA_index=1;
12+
public static final int B_index=2;
13+
public static final int XB_index=3;
14+
public static final int BX_index=4;
15+
public static final int A_index=5;
16+
public static final int XA_index=6;
17+
public static final int AX_index=7;
18+
public static final int AZ_index=8;
19+
public enum EinsumRewriteType{
20+
// inputops__output 'X' = simplySumDim
21+
AB_BA_B_XB_BX_A_XA_AX__AB,
22+
AB_BA_B_XB_BX_A_XA_AX__B,
23+
AB_BA_B_XB_BX_A_XA_AX__A,
24+
AB_BA_B_XB_BX_A_XA_AX__,
25+
26+
AB_BA_B_XB_BX_A_XA_AX_AZ__Z
27+
}
28+
public enum EinsumRewriteType_v2{ // option 2 without X dims
29+
AB_BA_B_A__AB,
30+
AB_BA_B_A__B,
31+
AB_BA_B_A__A,
32+
33+
AB_BA_B_A_AZ__Z
34+
}
35+
public final EinsumRewriteType einsumRewriteType;
36+
public final List<List<EOpNode>> operands;
37+
38+
private EOpNodeEinsumFuse(Character c1, Character c2, EinsumRewriteType einsumRewriteType, List<EOpNode>... operands) {
39+
super(c1,c2);
40+
this.einsumRewriteType = einsumRewriteType;
41+
this.operands = Arrays.asList(operands);
42+
}
43+
44+
public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character outChar1, Character outChar2,/*, Set<Character> simplySummableChars,*/ ArrayList<EOpNode> ret, HashMap<Character, Integer> charToOccurences){
45+
//precompute
46+
HashSet<String> matricesChars = new HashSet<>();
47+
HashMap<String, ArrayList<EOpNode>> charsToMatrices = new HashMap<>();
48+
HashMap<Character, Integer> charsToNumberOfOperands = new HashMap<>();
49+
50+
for (EOpNode operand1 : operands) {
51+
String k;
52+
//todo remove and use input map charToOccurences
53+
if (charsToNumberOfOperands.containsKey(operand1.c1)) {
54+
charsToNumberOfOperands.put(operand1.c1, charsToNumberOfOperands.get(operand1.c1) + 1);
55+
} else {
56+
charsToNumberOfOperands.put(operand1.c1, 1);
57+
}
58+
59+
if (operand1.c2 != null) {
60+
k = operand1.c1.toString() + operand1.c2;
61+
matricesChars.add(k);
62+
if (charsToNumberOfOperands.containsKey(operand1.c2)) {
63+
charsToNumberOfOperands.put(operand1.c2, charsToNumberOfOperands.get(operand1.c2) + 1);
64+
} else {
65+
charsToNumberOfOperands.put(operand1.c2, 1);
66+
}
67+
} else {
68+
k = operand1.c1.toString();
69+
}
70+
71+
if (charsToMatrices.containsKey(k)) {
72+
charsToMatrices.get(k).add(operand1);
73+
} else {
74+
ArrayList<EOpNode> matrices = new ArrayList<>();
75+
matrices.add(operand1);
76+
charsToMatrices.put(k, matrices);
77+
}
78+
}
79+
80+
ArrayList<EOpNode> AXs = new ArrayList<>();
81+
ArrayList<EOpNode> XAs = new ArrayList<>();
82+
ArrayList<EOpNode> BXs = new ArrayList<>();
83+
ArrayList<EOpNode> XBs = new ArrayList<>();
84+
ArrayList<EOpNode> AZs = new ArrayList<>();
85+
boolean pass = false;
86+
87+
String AB = null;
88+
String BA = null;
89+
boolean doSumA=false;
90+
boolean doSumB=false;
91+
92+
for (String ABcandidate : matricesChars) {
93+
char a = ABcandidate.charAt(0);
94+
char b = ABcandidate.charAt(1);
95+
BA = "" + b + a;
96+
97+
AXs = new ArrayList<>();
98+
XAs = new ArrayList<>();
99+
BXs = new ArrayList<>();
100+
XBs = new ArrayList<>();
101+
AZs = new ArrayList<>();
102+
103+
pass=true;
104+
105+
106+
for (String chars : charsToMatrices.keySet()) {
107+
if (chars.equals(ABcandidate) || chars.equals(BA)) {
108+
// ABsCounter++;
109+
continue;
110+
}
111+
112+
if(chars.length()==1){
113+
if(chars.charAt(0)==a){
114+
// AsCounter++;
115+
}else if(chars.charAt(0)==b){
116+
// BsCounter++;
117+
}
118+
continue;
119+
//always ok
120+
}else{
121+
if(a==chars.charAt(1) && b==chars.charAt(0)){
122+
// ABsCounter++;
123+
//BA
124+
continue;
125+
}
126+
if(chars.charAt(0)==a){
127+
if(charsToNumberOfOperands.get(chars.charAt(1))==1){
128+
if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) {
129+
AXs.addAll(charsToMatrices.get(chars));
130+
// AsCounter++;
131+
continue;
132+
}else{
133+
if(AZs.size()==0){
134+
AZs = charsToMatrices.get(chars);
135+
continue;
136+
}
137+
pass = false;
138+
break;
139+
}
140+
}else{
141+
//dont allow for now, in theory AZ,Z or AZ,AZ would also work, but for now do them separately
142+
pass = false;
143+
break;
144+
}
145+
}
146+
else if(chars.charAt(0)==b){
147+
if(charsToNumberOfOperands.get(chars.charAt(1))==1){
148+
if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) {
149+
BXs.addAll(charsToMatrices.get(chars));
150+
// BsCounter++;
151+
continue;
152+
}else{
153+
pass = false; // no BZ, maybe experiment later
154+
break;
155+
}
156+
}else{
157+
pass = false;
158+
break;
159+
}
160+
}
161+
else if(chars.charAt(1)==a){
162+
if(charsToNumberOfOperands.get(chars.charAt(0))==1){
163+
if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) {
164+
XAs.addAll(charsToMatrices.get(chars));
165+
// AsCounter++;
166+
continue;
167+
}else{
168+
pass = false;
169+
break;
170+
}
171+
}else{
172+
pass = false;
173+
break;
174+
}
175+
}
176+
else if(chars.charAt(1)==b){
177+
if(charsToNumberOfOperands.get(chars.charAt(0))==1){
178+
if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) {
179+
XBs.addAll(charsToMatrices.get(chars));
180+
// BsCounter++;
181+
continue;
182+
}else{
183+
pass = false;
184+
break;
185+
}
186+
}else{
187+
pass = false;
188+
break;
189+
}
190+
}
191+
}
192+
}
193+
if(pass){
194+
AB = ABcandidate;
195+
String A = ""+a;
196+
String B = ""+b;
197+
int ABsCounter = charsToMatrices.get(ABcandidate).size()+(charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0);
198+
int AsCounter = (charsToMatrices.containsKey(A) ? charsToMatrices.get(A).size() : 0) +AXs.size()+XAs.size();
199+
int BsCounter = (charsToMatrices.containsKey(B) ? charsToMatrices.get(B).size() : 0)+BXs.size()+XBs.size();
200+
if(AsCounter==0 && BsCounter==0 && ABsCounter<2){
201+
pass=false;
202+
continue;
203+
}
204+
int usedAsCount = AsCounter+ABsCounter;
205+
int usedBsCount = BsCounter+ABsCounter;
206+
doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2);
207+
doSumB = charToOccurences.get(b)==usedBsCount && (outChar1 == null || b!=outChar1) && (outChar2 == null || b!=outChar2);
208+
if(AZs.size()!=0) { // invalidate AZ fusion
209+
if (outChar1 != null) {
210+
if (a == outChar1 || b == outChar1) {
211+
pass=false;
212+
continue;
213+
}
214+
}
215+
if (outChar2 != null) {
216+
if (a == outChar2 || b == outChar2) {
217+
pass=false;
218+
continue;
219+
}
220+
}
221+
if(!doSumA || !doSumB){
222+
pass=false;
223+
continue;
224+
}
225+
}
226+
break;
227+
}
228+
}
229+
230+
if(!pass){
231+
return null;
232+
}
233+
String B = AB.substring(1,2);
234+
String A = AB.substring(0,1);
235+
Character c1 = null;
236+
Character c2 = null;
237+
EinsumRewriteType t;
238+
239+
if(AZs.size()!=0){
240+
c1=AZs.get(0).c2;
241+
t=EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX_AZ__Z;
242+
}
243+
else if(doSumA){
244+
if(doSumB) {
245+
t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__;
246+
}
247+
else {
248+
t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__B;
249+
c1 = AB.charAt(1);
250+
}
251+
}
252+
else if(doSumB){
253+
t= EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__A;
254+
c1= AB.charAt(0);
255+
}
256+
else {
257+
t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB;
258+
c1 = AB.charAt(0);
259+
c2 = AB.charAt(1);
260+
}
261+
if(c1 != null){
262+
charToOccurences.put(c1, charToOccurences.get(c1)+1);
263+
}
264+
if(c2 != null){
265+
charToOccurences.put(c2, charToOccurences.get(c2)+1);
266+
}
267+
HashSet<EOpNode> usedOperands = new HashSet<>();
268+
269+
ArrayList<EOpNode> ABs=charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>();
270+
ArrayList<EOpNode> BAs=charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA) : new ArrayList<>();
271+
ArrayList<EOpNode> Bs=charsToMatrices.containsKey(B) ? charsToMatrices.get(B) : new ArrayList<>();
272+
ArrayList<EOpNode> As=charsToMatrices.containsKey(A) ? charsToMatrices.get(A) : new ArrayList<>();
273+
274+
usedOperands.addAll(ABs);
275+
usedOperands.addAll(BAs);
276+
usedOperands.addAll(Bs);
277+
usedOperands.addAll(As);
278+
usedOperands.addAll(XBs);
279+
usedOperands.addAll(BXs);
280+
usedOperands.addAll(XAs);
281+
usedOperands.addAll(AXs);
282+
usedOperands.addAll(AZs);
283+
284+
for(EOpNode n : operands){
285+
if(!usedOperands.contains(n)){
286+
ret.add(n);
287+
}else{
288+
if(charToOccurences != null){
289+
charToOccurences.put(n.c1, charToOccurences.get(n.c1)-1);
290+
if(charToOccurences.get(n.c2)!= null)
291+
charToOccurences.put(n.c2, charToOccurences.get(n.c2)-1);
292+
}
293+
}
294+
}
295+
296+
var e = new EOpNodeEinsumFuse(c1, c2, t,
297+
ABs,
298+
BAs,
299+
Bs,
300+
XBs,
301+
BXs,
302+
As,
303+
XAs,
304+
AXs,
305+
AZs
306+
);
307+
ret.add(e);
308+
return e;
309+
}
310+
}
311+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package org.apache.sysds.runtime.einsum;
2+
3+
public class EOpNodeFused extends EOpNode {
4+
public EOpNodeFused(Character c1, Character c2){
5+
super(c1,c2);
6+
7+
}
8+
}

0 commit comments

Comments
 (0)