Skip to content

Commit d255fc3

Browse files
bugfix and optimize outer product decision
1 parent 80aa9f4 commit d255fc3

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads,
269269
res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads);
270270
}
271271
case aB_aC -> {
272-
if(false && LibMatrixMult.isSkinnyRightHandSide(left.getNumRows(), left.getNumColumns(), right.getNumRows(), right.getNumColumns(), true)){
272+
if(false && LibMatrixMult.isSkinnyRightHandSide(left.getNumRows(), left.getNumColumns(), right.getNumRows(), right.getNumColumns(), false)){
273273
res = new MatrixBlock(left.getNumColumns(), right.getNumColumns(),false);
274274
res.allocateDenseBlock();
275275
double[] m1 = left.getDenseBlock().values(0);
@@ -341,7 +341,7 @@ public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Ch
341341
var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1;
342342
}
343343
if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left instanceof EOpNodeFuse fuse && fuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB &&
344-
LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, true)) {
344+
left.dim1 * left.dim2 * 8 > LibMatrixMult.L3_CACHESIZE && LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, false)) {
345345
fuse.operands.get(4).add(right);
346346
fuse.einsumRewriteType = EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ;
347347
fuse.c1 = fuse.c2;

src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ else if(chars.charAt(1)==b){
202202
if(AZCandidates.size()==1){
203203
if(!doSumB) {
204204
// check if outer is possible AB,...,AZ->BZ
205-
if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),true)) {
205+
if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),false)) {
206206
includeAz=false;
207207
}
208208
}
@@ -264,7 +264,7 @@ else if(chars.charAt(1)==b){
264264
c1 = azC2;
265265
}
266266
else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) {
267-
if(LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(azC2),false)){
267+
if(charToSize.get(AB.charAt(0)) * charToSize.get(AB.charAt(1)) * 8 > LibMatrixMult.L3_CACHESIZE && LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(azC2),false)){
268268
if (outChar1 == azC2 && outChar2 == b) {
269269
t = EinsumRewriteType.AB_BA_B_A_AZ__ZB;
270270
c1 = azC2;
@@ -280,10 +280,12 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) {
280280
}
281281

282282
}else{
283+
doSumA=false;
283284
t=null;
284285
AZs=new ArrayList<>();
285286
}
286287
}else{
288+
doSumA=false;
287289
t=null;
288290
AZs=new ArrayList<>();
289291
}

0 commit comments

Comments
 (0)