Skip to content

Commit 9e649c8

Browse files
aarnatymboehm7
authored andcommitted
[SYSTEMDS-3884] Additional rewrites subtraction and addition
-(B-A)->A-B t(A+1)+2 -> t(A)+1+2 -> t(A)+3 Closes #2258.
1 parent 467c553 commit 9e649c8

File tree

7 files changed

+399
-1
lines changed

7 files changed

+399
-1
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
199199
hi = simplifyBinaryComparisonChain(hop, hi, i); //e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> outer(v1,v2,"!="),
200200
hi = simplifyCumsumColOrFullAggregates(hi); //e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
201201
hi = simplifyCumsumReverse(hop, hi, i); //e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
202-
202+
hi = simplifyNegatedSubtraction(hop, hi, i); //e.g., -(B-A)->A-B
203+
hi = simplifyTransposeAddition(hop, hi, i); //e.g., t(A+1)+2 -> t(A)+1+2 -> t(A)+3
203204
hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B)
204205
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
205206

@@ -211,6 +212,106 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
211212
hop.setVisited();
212213
}
213214

215+
private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) {
216+
if (!(hi instanceof BinaryOp)
217+
|| ((BinaryOp)hi).getOp() != OpOp2.PLUS
218+
|| hi.getDataType() != DataType.MATRIX)
219+
return hi;
220+
221+
BinaryOp bop = (BinaryOp)hi;
222+
223+
ReorgOp tSide = null;
224+
LiteralOp litSide = null;
225+
Hop in0 = bop.getInput().get(0), in1 = bop.getInput().get(1);
226+
if (in0 instanceof ReorgOp && ((ReorgOp)in0).getOp() == ReOrgOp.TRANS
227+
&& in1 instanceof LiteralOp) {
228+
tSide = (ReorgOp)in0;
229+
litSide = (LiteralOp)in1;
230+
}
231+
else if (in1 instanceof ReorgOp && ((ReorgOp)in1).getOp() == ReOrgOp.TRANS
232+
&& in0 instanceof LiteralOp) {
233+
tSide = (ReorgOp)in1;
234+
litSide = (LiteralOp)in0;
235+
}
236+
else
237+
return hi;
238+
239+
//check if only consumer
240+
if (tSide.getParent().size() > 1) {
241+
return hi;
242+
}
243+
244+
Hop inner = tSide.getInput().get(0);
245+
if (!(inner instanceof BinaryOp)
246+
|| ((BinaryOp)inner).getOp() != OpOp2.PLUS
247+
|| inner.getDataType() != DataType.MATRIX)
248+
return hi;
249+
250+
BinaryOp ib = (BinaryOp)inner;
251+
252+
Hop X = null;
253+
LiteralOp lit1 = null;
254+
Hop i0 = ib.getInput().get(0), i1 = ib.getInput().get(1);
255+
if (i0 instanceof LiteralOp) {
256+
lit1 = (LiteralOp)i0;
257+
X = i1;
258+
}
259+
else if (i1 instanceof LiteralOp) {
260+
lit1 = (LiteralOp)i1;
261+
X = i0;
262+
}
263+
else
264+
return hi;
265+
266+
double c = lit1.getDoubleValue() + litSide.getDoubleValue();
267+
268+
ReorgOp newT = HopRewriteUtils.createTranspose(X);
269+
newT.setDim1(tSide.getDim1());
270+
newT.setDim2(tSide.getDim2());
271+
272+
LiteralOp newLit = new LiteralOp(c);
273+
newLit.setDim1(1);
274+
newLit.setDim2(1);
275+
276+
//creating new binaryOp
277+
BinaryOp newPlus = HopRewriteUtils.createBinary(newT, newLit, OpOp2.PLUS);
278+
newPlus.setDim1(bop.getDim1());
279+
newPlus.setDim2(bop.getDim2());
280+
281+
HopRewriteUtils.replaceChildReference(parent, bop, newPlus, pos);
282+
HopRewriteUtils.cleanupUnreferenced(bop, tSide, ib, litSide);
283+
284+
LOG.debug("Applied simplifyTransposeAddition (line " + hi.getBeginLine() + ").");
285+
286+
return newPlus;
287+
}
288+
289+
private static Hop simplifyNegatedSubtraction(Hop parent, Hop hi, int pos) {
290+
if (hi instanceof BinaryOp
291+
&& ((BinaryOp) hi).getOp() == OpOp2.MINUS
292+
&& HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 0)
293+
&& hi.getParent().size() == 1
294+
&& hi.getInput().get(1) instanceof BinaryOp
295+
&& ((BinaryOp) hi.getInput().get(1)).getOp() == OpOp2.MINUS
296+
&& hi.getInput().get(1).getParent().size() == 1)
297+
{
298+
Hop innerMinus = hi.getInput().get(1);
299+
Hop B = innerMinus.getInput().get(0);
300+
Hop A = innerMinus.getInput().get(1);
301+
302+
BinaryOp newHop = HopRewriteUtils.createBinary(A, B, OpOp2.MINUS);
303+
304+
HopRewriteUtils.copyLineNumbers(hi, newHop);
305+
HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
306+
HopRewriteUtils.cleanupUnreferenced(hi);
307+
hi = newHop;
308+
309+
LOG.debug("Applied simplifyNegatedSubtraction (line " + hi.getBeginLine() + ").");
310+
}
311+
return hi;
312+
}
313+
314+
214315
private static Hop removeUnnecessaryVectorizeOperation(Hop hi)
215316
{
216317
//applies to all binary matrix operations, if one input is unnecessarily vectorized
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.sysds.test.functions.rewrite;
20+
21+
import org.apache.sysds.hops.OptimizerUtils;
22+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
23+
import org.apache.sysds.test.AutomatedTestBase;
24+
import org.apache.sysds.test.TestConfiguration;
25+
import org.apache.sysds.test.TestUtils;
26+
import org.apache.sysds.utils.Statistics;
27+
import org.junit.Test;
28+
import org.junit.Assert;
29+
import java.util.HashMap;
30+
31+
public class RewriteSimplifyNegatedSubtractionTest extends AutomatedTestBase {
32+
private static final String TEST_NAME = "RewriteNegatedSubtraction";
33+
private static final String TEST_DIR = "functions/rewrite/";
34+
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyNegatedSubtractionTest.class.getSimpleName() + "/";
35+
private static final int rows = 100;
36+
private static final int cols = 100;
37+
38+
@Override
39+
public void setUp() {
40+
TestUtils.clearAssertionInformation();
41+
addTestConfiguration(TEST_NAME,
42+
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
43+
}
44+
45+
@Test
46+
public void testRewriteEnabled() {
47+
runRewriteTest(true);
48+
}
49+
50+
@Test
51+
public void testRewriteDisabled() {
52+
runRewriteTest(false);
53+
}
54+
55+
private void runRewriteTest(boolean rewriteEnabled) {
56+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
57+
try {
58+
TestConfiguration config = getTestConfiguration(TEST_NAME);
59+
loadTestConfiguration(config);
60+
61+
String HOME = SCRIPT_DIR + TEST_DIR;
62+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
63+
fullRScriptName = HOME + TEST_NAME + ".R";
64+
programArgs = new String[]{"-stats", "-args", input("A"), input("B"), output("R")};
65+
rCmd = getRCmd(inputDir(), expectedDir());
66+
67+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled;
68+
69+
// Generate input matrices
70+
double[][] A = getRandomMatrix(rows, cols, -10, 10, 0.7, 3);
71+
double[][] B = getRandomMatrix(rows, cols, -10, 10, 0.7, 7);
72+
writeInputMatrixWithMTD("A", A, true);
73+
writeInputMatrixWithMTD("B", B, true);
74+
75+
// Run DML script
76+
runTest(true, false, null, -1);
77+
runRScript(true);
78+
79+
HashMap<MatrixValue.CellIndex, Double> dml = readDMLMatrixFromOutputDir("R");
80+
HashMap<MatrixValue.CellIndex, Double> r = readRMatrixFromExpectedDir("R");
81+
82+
Assert.assertEquals("DML and R outputs do not match", r, dml);
83+
if( rewriteEnabled )
84+
Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("-"));
85+
}
86+
finally {
87+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
88+
}
89+
}
90+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.sysds.test.functions.rewrite;
20+
21+
import java.util.HashMap;
22+
import org.apache.sysds.hops.OptimizerUtils;
23+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
24+
import org.apache.sysds.test.AutomatedTestBase;
25+
import org.apache.sysds.test.TestConfiguration;
26+
import org.apache.sysds.test.TestUtils;
27+
import org.apache.sysds.utils.Statistics;
28+
import org.junit.Assert;
29+
import org.junit.Test;
30+
31+
public class RewriteSimplifyTransposeAdditionTest extends AutomatedTestBase {
32+
private static final String TEST_NAME = "RewriteSimplifyTransposeAddition";
33+
private static final String TEST_DIR = "functions/rewrite/";
34+
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTransposeAdditionTest.class.getSimpleName() + "/";
35+
36+
private static final int rows = 100;
37+
private static final int cols = 100;
38+
39+
@Override
40+
public void setUp() {
41+
TestUtils.clearAssertionInformation();
42+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
43+
}
44+
45+
@Test
46+
public void testRewriteEnabled() {
47+
runRewriteTest(true);
48+
}
49+
50+
@Test
51+
public void testRewriteDisabled() {
52+
runRewriteTest(false);
53+
}
54+
55+
private void runRewriteTest(boolean rewriteEnabled) {
56+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
57+
try {
58+
TestConfiguration config = getTestConfiguration(TEST_NAME);
59+
loadTestConfiguration(config);
60+
61+
String HOME = SCRIPT_DIR + TEST_DIR;
62+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
63+
fullRScriptName = HOME + TEST_NAME + ".R";
64+
65+
// DML script parameters
66+
programArgs = new String[]{"-stats", "-args", input("A"), output("R")};
67+
rCmd = getRCmd(inputDir(), expectedDir());
68+
69+
// Set optimizer flags
70+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled;
71+
72+
// Generate input matrix
73+
double[][] A = getRandomMatrix(rows, cols, -10, 10, 0.7, 3);
74+
writeInputMatrixWithMTD("A", A, true);
75+
76+
// Run DML and R scripts
77+
runTest(true, false, null, -1);
78+
runRScript(true);
79+
80+
// Compare output matrices
81+
HashMap<CellIndex, Double> dml = readDMLMatrixFromOutputDir("R");
82+
HashMap<CellIndex, Double> r = readRMatrixFromExpectedDir("R");
83+
84+
Assert.assertEquals("DML and R outputs do not match", r, dml);
85+
if( rewriteEnabled )
86+
Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("+"));
87+
}
88+
finally {
89+
// Reset optimizer flags
90+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
91+
}
92+
}
93+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
library("Matrix")
22+
23+
args <- commandArgs(TRUE)
24+
25+
A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
26+
B <- as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
27+
28+
R <- A - B
29+
30+
writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
31+
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
A = read($1);
22+
B = read($2);
23+
24+
# Expression that will be rewritten
25+
R = 0 - (B - A);
26+
27+
write(R, $3);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
args <- commandArgs(TRUE)
22+
library("Matrix")
23+
24+
A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
25+
26+
# Compute t(A)+3
27+
R <- t(A)+3
28+
29+
# Write the result matrix
30+
writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))

0 commit comments

Comments
 (0)