Skip to content

Commit 80aa9f4

Browse files
include dimsize in EOpNode, implement the reordering of binary in the final plan
1 parent 1a6becf commit 80aa9f4

File tree

7 files changed

+217
-108
lines changed

7 files changed

+217
-108
lines changed

src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,20 @@
2121

2222
import org.apache.commons.logging.Log;
2323
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
24+
import scala.Int;
2425

2526
import java.util.ArrayList;
2627

2728
public abstract class EOpNode {
2829
public Character c1;
29-
public Character c2; // nullable
30-
public EOpNode(Character c1, Character c2){
30+
public Character c2;
31+
public Integer dim1;
32+
public Integer dim2;
33+
public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) {
3134
this.c1 = c1;
3235
this.c2 = c2;
36+
this.dim1 = dim1;
37+
this.dim2 = dim2;
3338
}
3439

3540
@Override
@@ -43,6 +48,6 @@ public String toString() {
4348

4449
public abstract MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThreads, Log LOG);
4550

46-
public abstract void reorderChildren(Character outChar1, Character outChar2);
51+
public abstract EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2);
4752
}
4853

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java

Lines changed: 136 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.sysds.runtime.functionobjects.ReduceRow;
3131
import org.apache.sysds.runtime.functionobjects.SwapIndex;
3232
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
33+
import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction;
3334
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
3435
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
3536
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -48,57 +49,122 @@
4849

4950
public class EOpNodeBinary extends EOpNode {
5051

51-
5252
public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed
53-
////// summations: //////
54-
aB_a,// -> B
55-
Ba_a, // -> B
56-
Ba_aC, // mmult -> BC
57-
aB_Ca,
53+
////// mm: //////
54+
Ba_aC, // -> BC
55+
aB_Ca, // -> CB
5856
Ba_Ca, // -> BC
59-
aB_aC, // outer mult, possibly with transposing first -> BC
60-
a_a,// dot ->
57+
aB_aC, // -> BC
6158

62-
////// elementwisemult and sums, something like ij,ij->i //////
59+
////// elementwisemult and sums //////
6360
aB_aB,// elemwise and colsum -> B
6461
Ba_Ba, // elemwise and rowsum ->B
6562
Ba_aB, // elemwise, either colsum or rowsum -> B
6663
aB_Ba,
64+
ab_ab,//M-M sum all
65+
ab_ba, //M-M.T sum all
66+
aB_a,// -> B
67+
Ba_a, // -> B
6768

6869
////// elementwise, no summations: //////
6970
A_A,// v-elemwise -> A
7071
AB_AB,// M-M elemwise -> AB
7172
AB_BA, // M-M.T elemwise -> AB
7273
AB_A, // M-v colwise -> BA!?
7374
BA_A, // M-v rowwise -> BA
74-
ab_ab,//M-M sum all
75-
ab_ba, //M-M.T sum all
75+
7676
////// other //////
77+
a_a,// dot ->
7778
A_B, // outer mult -> AB
7879
A_scalar, // v-scalar
7980
AB_scalar, // m-scalar
8081
scalar_scalar
8182
}
82-
public EOpNode _left;
83-
public EOpNode _right;
84-
public EBinaryOperand _operand;
83+
public EOpNode left;
84+
public EOpNode right;
85+
public EBinaryOperand operand;
8586
private boolean transposeResult;
86-
public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){
87-
super(c1,c2);
88-
this._left = left;
89-
this._right = right;
90-
this._operand = operand;
91-
}
87+
public EOpNodeBinary(EOpNode left, EOpNode right, EBinaryOperand operand){
88+
super(null,null,null, null);
89+
Character c1, c2;
90+
Integer dim1, dim2;
91+
switch(operand){
92+
case Ba_aC -> {
93+
c1=left.c1;
94+
c2=right.c2;
95+
dim1=left.dim1;
96+
dim2=right.dim2;
97+
}
98+
case aB_Ca -> {
99+
c1=left.c2;
100+
c2=right.c1;
101+
dim1=left.dim2;
102+
dim2=right.dim1;
103+
}
104+
case Ba_Ca -> {
105+
c1=left.c1;
106+
c2=right.c1;
107+
dim1=left.dim1;
108+
dim2=right.dim1;
109+
}
110+
case aB_aC -> {
111+
c1=left.c2;
112+
c2=right.c2;
113+
dim1=left.dim2;
114+
dim2=right.dim2;
115+
}
116+
case aB_aB, aB_Ba, aB_a -> {
117+
c1=left.c2;
118+
c2=null;
119+
dim1=left.dim2;
120+
dim2=null;
121+
}
122+
case Ba_Ba, Ba_aB, Ba_a, A_A, A_scalar -> {
123+
c1=left.c1;
124+
c2=null;
125+
dim1=left.dim1;
126+
dim2=null;
127+
}
128+
case ab_ab, ab_ba, a_a, scalar_scalar -> {
129+
c1=null;
130+
c2=null;
131+
dim1=null;
132+
dim2=null;
133+
}
134+
case AB_AB, AB_BA, AB_A, BA_A, AB_scalar ->{
135+
c1=left.c1;
136+
c2=left.c2;
137+
dim1=left.dim1;
138+
dim2=left.dim2;
139+
}
140+
case A_B -> {
141+
c1=left.c1;
142+
c2=right.c1;
143+
dim1=left.dim1;
144+
dim2=right.dim1;
145+
}
146+
default -> throw new IllegalStateException("EOpNodeBinary Unexpected type: " + operand);
147+
}
148+
// super(c1, c2, dim1, dim2); // unavailable in JDK < 22
149+
this.c1 = c1;
150+
this.c2 = c2;
151+
this.dim1 = dim1;
152+
this.dim2 = dim2;
153+
this.left = left;
154+
this.right = right;
155+
this.operand = operand;
156+
}
157+
92158
public void setTransposeResult(boolean transposeResult){
93159
this.transposeResult = transposeResult;
94160
}
95161

96162
public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) {
97-
if (left.c2 == right.c1) { return new EOpNodeBinary(left.c1, right.c2, left, right, EBinaryOperand.Ba_aC); }
98-
if (left.c2 == right.c2) { return new EOpNodeBinary(left.c1, right.c1, left, right, EBinaryOperand.Ba_Ca); }
99-
if (left.c1 == right.c1) { return new EOpNodeBinary(left.c2, right.c2, left, right, EBinaryOperand.aB_aC); }
163+
if (left.c2 == right.c1) { return new EOpNodeBinary(left, right, EBinaryOperand.Ba_aC); }
164+
if (left.c2 == right.c2) { return new EOpNodeBinary(left, right, EBinaryOperand.Ba_Ca); }
165+
if (left.c1 == right.c1) { return new EOpNodeBinary(left, right, EBinaryOperand.aB_aC); }
100166
if (left.c1 == right.c2) {
101-
var res = new EOpNodeBinary(left.c2, right.c1, left, right, EBinaryOperand.aB_Ca);
167+
var res = new EOpNodeBinary(left, right, EBinaryOperand.aB_Ca);
102168
res.setTransposeResult(true);
103169
return res;
104170
}
@@ -107,10 +173,10 @@ public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) {
107173

108174
@Override
109175
public String[] recursivePrintString() {
110-
String[] left = _left.recursivePrintString();
111-
String[] right = _right.recursivePrintString();
176+
String[] left = this.left.recursivePrintString();
177+
String[] right = this.right.recursivePrintString();
112178
String[] res = new String[left.length + right.length+1];
113-
res[0] = this.getClass().getSimpleName()+" ("+_operand.toString()+") "+this.toString();
179+
res[0] = this.getClass().getSimpleName()+" ("+ operand.toString()+") "+this.toString();
114180
for (int i=0; i<left.length; i++) {
115181
res[i+1] = (i==0 ? "┌─ " : " ") +left[i];
116182
}
@@ -123,16 +189,16 @@ public String[] recursivePrintString() {
123189
@Override
124190
public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads, Log LOG) {
125191
EOpNodeBinary bin = this;
126-
MatrixBlock left = _left.computeEOpNode(inputs, numThreads, LOG);
127-
MatrixBlock right = _right.computeEOpNode(inputs, numThreads, LOG);
192+
MatrixBlock left = this.left.computeEOpNode(inputs, numThreads, LOG);
193+
MatrixBlock right = this.right.computeEOpNode(inputs, numThreads, LOG);
128194

129195
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
130196

131197
MatrixBlock res;
132198

133-
if(LOG.isTraceEnabled()) LOG.trace("computing binary "+bin._left +","+bin._right +"->"+bin);
199+
if(LOG.isTraceEnabled()) LOG.trace("computing binary "+bin.left +","+bin.right +"->"+bin);
134200

135-
switch (bin._operand){
201+
switch (bin.operand){
136202
case AB_AB -> {
137203
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
138204
}
@@ -255,7 +321,7 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
255321
return new MatrixBlock(left.get(0,0)*right.get(0,0));
256322
}
257323
default -> {
258-
throw new IllegalArgumentException("Unexpected value: " + bin._operand.toString());
324+
throw new IllegalArgumentException("Unexpected value: " + bin.operand.toString());
259325
}
260326

261327
}
@@ -267,25 +333,47 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
267333
}
268334

269335
@Override
270-
public void reorderChildren(Character outChar1, Character outChar2) {
271-
if (this._operand==EBinaryOperand.aB_aC){
272-
if(this._right.c2 == outChar1) {
273-
var tmp = _left;
274-
_left = _right;
275-
_right = tmp;
276-
var tmp2 = c1;
277-
c1 = c2;
278-
c2 = tmp2;
336+
public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) {
337+
if (this.operand ==EBinaryOperand.aB_aC){
338+
if(this.right.c2 == outChar1) { // result is CB so Swap aB and aC
339+
var tmpLeft = left; left = right; right = tmpLeft;
340+
var tmpC1 = c1; c1 = c2; c2 = tmpC1;
341+
var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1;
279342
}
280-
_left.reorderChildren(_left.c2, _left.c1);
281-
// check if change happened:
282-
if(_left.c2 == _right.c1) {
283-
this._operand = EBinaryOperand.Ba_aC;
343+
if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left instanceof EOpNodeFuse fuse && fuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB &&
344+
LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, true)) {
345+
fuse.operands.get(4).add(right);
346+
fuse.einsumRewriteType = EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ;
347+
fuse.c1 = fuse.c2;
348+
fuse.c2 = right.c2;
349+
return fuse;
350+
}
351+
352+
left = left.reorderChildrenAndOptimize(this, left.c2, left.c1); // maybe can be reordered
353+
if(left.c2 == right.c1) { // check if change happened:
354+
this.operand = EBinaryOperand.Ba_aC;
284355
}
285-
}
356+
right = right.reorderChildrenAndOptimize(this, right.c1, right.c2);
357+
}else if (this.operand ==EBinaryOperand.Ba_Ca){
358+
if(this.right.c1 == outChar1) { // result is CB so Swap Ba and Ca
359+
var tmpLeft = left; left = right; right = tmpLeft;
360+
var tmpC1 = c1; c1 = c2; c2 = tmpC1;
361+
var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1;
362+
}
363+
364+
right = right.reorderChildrenAndOptimize(this, right.c2, right.c1); // maybe can be reordered
365+
if(left.c2 == right.c1) { // check if change happened:
366+
this.operand = EBinaryOperand.Ba_aC;
367+
}
368+
left = left.reorderChildrenAndOptimize(this, left.c1, left.c2);
369+
}else {
370+
left = left.reorderChildrenAndOptimize(this, left.c1, left.c2); // just recurse
371+
right = right.reorderChildrenAndOptimize(this, right.c1, right.c2);
372+
}
373+
return this;
286374
}
287375

288-
// used in old method
376+
// used in the old approach
289377
public static Triple<Integer, EBinaryOperand, Pair<Character, Character>> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character, Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character outChar1, Character outChar2){
290378
Predicate<Character> cannotBeSummed = (c) ->
291379
c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2;
@@ -388,7 +476,7 @@ else if (n1.c2 == n2.c2) {
388476
return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1
389477
}
390478
}
391-
else { // something like ab,cd
479+
else { // something like AB,CD
392480
return null;
393481
}
394482
}

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
public class EOpNodeData extends EOpNode {
2828
public int matrixIdx;
29-
public EOpNodeData(Character c1, Character c2, int matrixIdx){
30-
super(c1,c2);
29+
public EOpNodeData(Character c1, Character c2, Integer dim1, Integer dim2, int matrixIdx){
30+
super(c1,c2,dim1,dim2);
3131
this.matrixIdx = matrixIdx;
3232
}
3333
@Override
@@ -42,7 +42,7 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThread
4242
}
4343

4444
@Override
45-
public void reorderChildren(Character outChar1, Character outChar2) {
46-
45+
public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) {
46+
return this;
4747
}
4848
}

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ public enum EinsumRewriteType{
7474
AB_BA_B_A_AZ__ZB,
7575
}
7676

77-
public final EinsumRewriteType einsumRewriteType;
77+
public EinsumRewriteType einsumRewriteType;
7878
public final List<List<EOpNode>> operands;
7979

80-
private EOpNodeFuse(Character c1, Character c2, EinsumRewriteType einsumRewriteType, List<EOpNode>... operands) {
81-
super(c1,c2);
80+
private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer dim2, EinsumRewriteType einsumRewriteType, List<EOpNode>... operands) {
81+
super(c1,c2, dim1, dim2);
8282
this.einsumRewriteType = einsumRewriteType;
8383
this.operands = Arrays.asList(operands);
8484
}
@@ -202,7 +202,7 @@ else if(chars.charAt(1)==b){
202202
if(AZCandidates.size()==1){
203203
if(!doSumB) {
204204
// check if outer is possible AB,...,AZ->BZ
205-
if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),false)) {
205+
if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),true)) {
206206
includeAz=false;
207207
}
208208
}
@@ -253,8 +253,8 @@ else if(chars.charAt(1)==b){
253253
String A = AB.substring(0,1);
254254
char a = A.charAt(0);
255255
char b = B.charAt(0);
256-
Character c1 = null;
257-
Character c2 = null;
256+
Character c1 = null, c2 = null;
257+
Integer dim1 = null, dim2 = null;
258258
EinsumRewriteType t = null;
259259

260260
if(!AZs.isEmpty()){
@@ -311,9 +311,11 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) {
311311
}
312312
if(c1 != null){
313313
charToOccurences.put(c1, charToOccurences.get(c1)+1);
314+
dim1 = charToSize.get(c1);
314315
}
315316
if(c2 != null){
316317
charToOccurences.put(c2, charToOccurences.get(c2)+1);
318+
dim2 = charToSize.get(c2);
317319
}
318320
HashSet<EOpNode> usedOperands = new HashSet<>();
319321

@@ -340,7 +342,7 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) {
340342
}
341343
}
342344

343-
var e = new EOpNodeFuse(c1, c2, t,
345+
var e = new EOpNodeFuse(c1, c2, dim1, dim2, t,
344346
ABs,
345347
BAs,
346348
Bs,
@@ -445,8 +447,10 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
445447
}
446448

447449
@Override
448-
public void reorderChildren(Character outChar1, Character outChar2) {
449-
450+
public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) {
451+
for(List<EOpNode> list : operands)
452+
for(int i = 0; i < list.size(); i++) list.set(i,list.get(i).reorderChildrenAndOptimize(this, list.get(i).c1, list.get(i).c2));
453+
return this;
450454
}
451455

452456
private static @NotNull List<MatrixBlock> multiplyVectorsIntoOne(List<MatrixBlock> mbs, int size) {

0 commit comments

Comments
 (0)