Skip to content

Commit 4a48a68

Browse files
aarnatymboehm7
authored andcommitted
[SYSTEMDS-3889] New simplification rewrite for matrix-scalar ops
e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A Closes #2272.
1 parent 61afba5 commit 4a48a68

File tree

6 files changed

+249
-0
lines changed

6 files changed

+249
-0
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
202202
hi = simplifyNegatedSubtraction(hop, hi, i); //e.g., -(B-A)->A-B
203203
hi = simplifyTransposeAddition(hop, hi, i); //e.g., t(A+s1)+s2 -> t(A)+(s1+s2) + potential constant folding
204204
hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B)
205+
hi = simplifyMatrixScalarPMOperation(hop, hi, i); //e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A
205206
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
206207

207208
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
@@ -212,6 +213,40 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
212213
hop.setVisited();
213214
}
214215

216+
private Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) {
217+
if (!(hi instanceof BinaryOp))
218+
return hi;
219+
220+
BinaryOp outer = (BinaryOp) hi;
221+
Hop left = outer.getInput(0);
222+
Hop right = outer.getInput(1);
223+
OpOp2 outerOp = outer.getOp();
224+
225+
if((outerOp != OpOp2.PLUS && outerOp != OpOp2.MINUS) || !(left instanceof BinaryOp))
226+
return hi;
227+
228+
Hop a = left.getInput(0);
229+
Hop A = left.getInput(1);
230+
Hop b = right;
231+
232+
java.util.function.Predicate<Hop> isScalar = h -> h.getDataType().isScalar();
233+
if (!isScalar.test(a) || !isScalar.test(b) || A.getDataType() != DataType.MATRIX)
234+
return hi;
235+
236+
// Determine the scalarOp (between a and b) and matrixOp (with A)
237+
OpOp2 innerOp = ((BinaryOp)left).getOp();
238+
if( innerOp != OpOp2.PLUS && innerOp != OpOp2.MINUS )
239+
return hi;
240+
OpOp2 scalarOp = (outerOp == OpOp2.PLUS) ? OpOp2.PLUS : OpOp2.MINUS;
241+
OpOp2 matrixOp = (innerOp == OpOp2.PLUS) ? OpOp2.PLUS : OpOp2.MINUS;
242+
Hop scalarCombined = HopRewriteUtils.createBinary(a, b, scalarOp);
243+
Hop result = HopRewriteUtils.createBinary(scalarCombined, A, matrixOp);
244+
245+
HopRewriteUtils.replaceChildReference(parent, hi, result, pos);
246+
LOG.debug("Applied simplifyMatrixScalarPMOperation");
247+
return result;
248+
}
249+
215250
private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) {
216251
//pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant folding
217252
if (HopRewriteUtils.isBinary(hi, OpOp2.PLUS)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.sysds.test.functions.rewrite;
20+
21+
import org.apache.sysds.hops.OptimizerUtils;
22+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
23+
import org.apache.sysds.test.AutomatedTestBase;
24+
import org.apache.sysds.test.TestConfiguration;
25+
import org.apache.sysds.test.TestUtils;
26+
import org.junit.Test;
27+
28+
import java.util.HashMap;
29+
30+
public class RewriteSimplifyScalarMatrixPMOperationTest extends AutomatedTestBase {
31+
private static final String TEST_NAME1 = "RewriteScalarMinusMatrixMinusScalar";
32+
private static final String TEST_NAME2 = "RewriteScalarPlusMatrixMinusScalar";
33+
private static final String TEST_DIR = "functions/rewrite/";
34+
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyScalarMatrixPMOperationTest.class.getSimpleName() + "/";
35+
private static final int rows = 100;
36+
private static final int cols = 100;
37+
private static final double eps = 1e-6;
38+
39+
@Override
40+
public void setUp() {
41+
TestUtils.clearAssertionInformation();
42+
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"A", "a", "b", "R"}));
43+
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"A", "a", "b", "R"}));
44+
}
45+
46+
@Test
47+
public void testScalarMinusMatrixMinusScalarRewriteEnabled() {
48+
runRewriteTest(TEST_NAME1, true);
49+
}
50+
51+
@Test
52+
public void testScalarMinusMatrixMinusScalarRewriteDisabled() {
53+
runRewriteTest(TEST_NAME1, false);
54+
}
55+
56+
@Test
57+
public void testScalarPlusMatrixMinusScalarRewriteEnabled() {
58+
runRewriteTest(TEST_NAME2, true);
59+
}
60+
61+
@Test
62+
public void testScalarPlusMatrixMinusScalarRewriteDisabled() {
63+
runRewriteTest(TEST_NAME2, false);
64+
}
65+
66+
private void runRewriteTest(String testName, boolean rewriteEnabled) {
67+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
68+
try {
69+
TestConfiguration config = getTestConfiguration(testName);
70+
loadTestConfiguration(config);
71+
72+
String HOME = SCRIPT_DIR + TEST_DIR;
73+
fullDMLScriptName = HOME + testName + ".dml";
74+
fullRScriptName = HOME + testName + ".R";
75+
programArgs = new String[]{"-stats", "-args", input("A"), input("a"), input("b"), output("R")};
76+
rCmd = getRCmd(inputDir(), expectedDir());
77+
78+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled;
79+
80+
double[][] A = getRandomMatrix(rows, cols, -100, 100, 0.9, 3);
81+
double[][] a = getRandomMatrix(1, 1, -10, 10, 1.0, 7);
82+
double[][] b = getRandomMatrix(1, 1, -10, 10, 1.0, 5);
83+
84+
writeInputMatrixWithMTD("A", A, true);
85+
writeInputMatrixWithMTD("a", a, true);
86+
writeInputMatrixWithMTD("b", b, true);
87+
88+
runTest(true, false, null, -1);
89+
runRScript(true);
90+
91+
HashMap<MatrixValue.CellIndex, Double> dml = readDMLMatrixFromOutputDir("R");
92+
HashMap<MatrixValue.CellIndex, Double> r = readRMatrixFromExpectedDir("R");
93+
TestUtils.compareMatrices(dml, r, eps, "Stat-DML", "Stat-R");
94+
} finally {
95+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
96+
}
97+
}
98+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
args <- commandArgs(TRUE)
22+
library("Matrix")
23+
24+
A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
25+
a <- as.numeric(readMM(paste(args[1], "a.mtx", sep="")))
26+
b <- as.numeric(readMM(paste(args[1], "b.mtx", sep="")))
27+
28+
R <- (a-b)-A
29+
30+
writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
A = read($1);
22+
a = read($2);
23+
b = read($3);
24+
25+
R = a - A - b;
26+
27+
write(R, $4);
28+
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
args <- commandArgs(TRUE)
22+
library("Matrix")
23+
24+
A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
25+
a <- as.numeric(readMM(paste(args[1], "a.mtx", sep="")))
26+
b <- as.numeric(readMM(paste(args[1], "b.mtx", sep="")))
27+
28+
R <- (a-b)+A
29+
30+
writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
A = read($1);
22+
a = as.scalar(read($2));
23+
b = as.scalar(read($3));
24+
25+
# Original form: a + A - b
26+
R = a + A - b;
27+
28+
write(R, $4);

0 commit comments

Comments
 (0)