Skip to content

Commit 51b53c3

Browse files
committed
[SYSTEMDS-3836] Fix distributive binary ops rewrite for broadcasting
The distributive binary operation rewrite, transforms the pattern X-X*Y into (1-Y)*X but was so far not aware of broadcasting semantics. If Y is a row or column vector but X a matrix, the rewrite yields mismatching dimension exceptions during runtime. We now simply rewrite the pattern to X*(1-Y) if Y is indeed a vector and X is not. Always rewriting the pattern to the latter cause the mmchain rewrite to no longer trigger (which is crucial for many end-to-end algorithms). The tests have, however, also shown that for the multiLogReg test we are not compiling mmchain (independent of this rewrite change) something that needs fixing before the release.
1 parent cd16f7a commit 51b53c3

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -886,10 +886,14 @@ private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int
886886
X = right;
887887
Y = ( right == leftC1 ) ? leftC2 : leftC1;
888888
}
889-
if( X != null ){ //rewrite 'binary +/-'
889+
if( X != null && Y.dimsKnown() ){ //rewrite 'binary +/-'
890890
LiteralOp literal = new LiteralOp(1);
891891
BinaryOp plus = HopRewriteUtils.createBinary(Y, literal, bop.getOp());
892-
BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
892+
893+
BinaryOp mult = (plus.getDim1()==1 || plus.getDim2() == 1)
894+
&& (X.getDim1()>1 && X.getDim2()>1) ?
895+
HopRewriteUtils.createBinary(X, plus, OpOp2.MULT) :
896+
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
893897
HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
894898
HopRewriteUtils.cleanupUnreferenced(hi, left);
895899
hi = mult;
@@ -908,10 +912,13 @@ private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int
908912
X = left;
909913
Y = ( left == rightC1 ) ? rightC2 : rightC1;
910914
}
911-
if( X != null ){ //rewrite '+/- binary'
915+
if( X != null && Y.dimsKnown() ){ //rewrite '+/- binary'
912916
LiteralOp literal = new LiteralOp(1);
913917
BinaryOp plus = HopRewriteUtils.createBinary(literal, Y, bop.getOp());
914-
BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
918+
BinaryOp mult = (plus.getDim1()==1 || plus.getDim2() == 1)
919+
&& (X.getDim1()>1 && X.getDim2()>1) ?
920+
HopRewriteUtils.createBinary(X, plus, OpOp2.MULT) :
921+
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
915922
HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
916923
HopRewriteUtils.cleanupUnreferenced(hi, right);
917924
hi = mult;

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyDistributiveBinaryOperationTest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ public void testDistrBinaryOpMultAddNoRewrite() {
8686
public void testDistrBinaryOpMultAddRewrite() {
8787
testSimplifyDistributiveBinaryOperation(4, true); //pattern: (Y*X+X) -> (Y+1)*X
8888
}
89+
90+
@Test
91+
public void testDistrBinaryOpMultMinusVectorRewrite() {
92+
testSimplifyDistributiveBinaryOperation(5, true); //pattern: (X*Y-X) -> (Y+1)*X, Y vector
93+
}
8994

9095
private void testSimplifyDistributiveBinaryOperation(int ID, boolean rewrites) {
9196
boolean oldFlag1 = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -104,7 +109,7 @@ private void testSimplifyDistributiveBinaryOperation(int ID, boolean rewrites) {
104109

105110
//create matrices
106111
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.60d, 3);
107-
double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.60d, 5);
112+
double[][] Y = getRandomMatrix(rows, ID==5?1:cols, -1, 1, 0.60d, 5);
108113
writeInputMatrixWithMTD("X", X, true);
109114
writeInputMatrixWithMTD("Y", Y, true);
110115

src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ if( type == 1 ) {
4444
R = (X+Y*X)
4545
} else if( type == 4 ) {
4646
R = (Y*X+X)
47+
} else if( type == 5 ) {
48+
R = (X*(Y%*%matrix(1,1,ncol(X)))-X) * 1
4749
}
4850

4951

src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.dml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ else if( type == 3 ) {
3838
else if( type == 4 ) {
3939
R = (Y*X+X) * 1
4040
}
41+
else if( type == 5 ) {
42+
R = (X*Y-X) * 1
43+
}
4144

4245

4346
# Write the result matrix R

0 commit comments

Comments
 (0)