Skip to content

Commit 1e86da3

Browse files
committed
[SYSTEMDS-3806] Robustness simplifyDotProductSum rewrite
This patch fixes an issue of incorrect application of the simplifyDotProductSum rewrite. Specifically, sum(s*V) was rewritten to t(s) %*% V because s was assumed to be a vector of equal size than V but was a scalar. The root cause of an incorrect size propagation for the new scalar right indexing, but for robustness we now also check that both inputs are actually matrices.
1 parent 0743613 commit 1e86da3

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

src/main/java/org/apache/sysds/hops/IndexingOp.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,14 @@ private static IndexingMethod optFindIndexingMethod( boolean singleRow, boolean
370370
@Override
371371
public void refreshSizeInformation()
372372
{
373+
// early abort for scalar right indexing
374+
// (important to prevent incorrect dynamic rewrites)
375+
if( isScalar() ) {
376+
setDim1(0);
377+
setDim2(0);
378+
return;
379+
}
380+
373381
Hop input1 = getInput().get(0); //matrix
374382
Hop input2 = getInput().get(1); //inpRowL
375383
Hop input3 = getInput().get(2); //inpRowU

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,6 +2312,7 @@ private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
23122312
//check for sum(v1*v2), but prevent to rewrite sum(v1*v2*v3) which is later compiled into a ta+* lop
23132313
else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than sum
23142314
&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
2315+
&& hi2.getInput().get(0).isMatrix() && hi2.getInput().get(1).isMatrix()
23152316
&& !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
23162317
&& !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
23172318
&& ( !ALLOW_SUM_PRODUCT_REWRITES

src/test/scripts/functions/unary/matrix/eigen.dml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ numEval = $2;
3333
D = matrix(1, numEval, 1);
3434
for ( i in 1:numEval ) {
3535
Av = A %*% evec[,i];
36-
while(FALSE){} #fix incorrect rewrite sequence
3736
rhs = as.scalar(eval[i,1]) * evec[,i];
38-
while(FALSE){} #fix incorrect rewrite sequence
3937
diff = sum(Av-rhs);
4038
D[i,1] = diff;
4139
}

0 commit comments

Comments
 (0)