Skip to content

Commit 3ce16d0

Browse files
aarnatymboehm7
authored and committed
[SYSTEMDS-3864] Additional trace simplification rewrites
Closes #2254.
1 parent b2d4e24 commit 3ce16d0

File tree

8 files changed

+355
-2
lines changed

8 files changed

+355
-2
lines changed

src/main/java/org/apache/sysds/hops/AggBinaryOp.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import org.apache.sysds.lops.PMMJ;
4444
import org.apache.sysds.lops.PMapMult;
4545
import org.apache.sysds.lops.Transform;
46-
import org.apache.sysds.runtime.DMLRuntimeException;
4746
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
4847
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
4948
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
176176
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
177177
hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X)
178178
hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y));
179+
hi = simplifyTraceSum(hop, hi, i); //e.g. , trace(A+B)->trace(A)+trace(B);
180+
hi = simplifyTraceTranspose(hop, hi, i); //e.g. , trace(t(A))->trace(A)
179181
hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
180182
hi = simplifyListIndexing(hi); //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
181183
hi = simplifyScalarIndexing(hop, hi, i); //e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output
@@ -201,7 +203,6 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
201203
hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B)
202204
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
203205

204-
205206
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
206207
if( !descendFirst )
207208
rule_AlgebraicSimplification(hi, descendFirst);
@@ -1603,6 +1604,45 @@ private static Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos)
16031604
return hi;
16041605
}
16051606

1607+
private static Hop simplifyTraceSum(Hop parent, Hop hi, int pos) {
1608+
if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE) {
1609+
Hop hi2 = hi.getInput().get(0);
1610+
if (HopRewriteUtils.isBinary(hi2, OpOp2.PLUS) && hi2.getParent().size() == 1) {
1611+
Hop left = hi2.getInput().get(0);
1612+
Hop right = hi2.getInput().get(1);
1613+
1614+
// Create trace nodes
1615+
AggUnaryOp traceLeft = HopRewriteUtils.createAggUnaryOp(left, AggOp.TRACE, Direction.RowCol);
1616+
AggUnaryOp traceRight = HopRewriteUtils.createAggUnaryOp(right, AggOp.TRACE, Direction.RowCol);
1617+
1618+
// Add them
1619+
BinaryOp sum = HopRewriteUtils.createBinary(traceLeft, traceRight, OpOp2.PLUS);
1620+
1621+
// Replace in DAG
1622+
HopRewriteUtils.replaceChildReference(parent, hi, sum, pos);
1623+
HopRewriteUtils.cleanupUnreferenced(hi, hi2);
1624+
1625+
LOG.debug("Applied simplifyTraceSum rewrite");
1626+
return sum;
1627+
}
1628+
}
1629+
return hi;
1630+
}
1631+
1632+
private static Hop simplifyTraceTranspose(Hop parent, Hop hi, int pos) {
1633+
// Check if the current Hop is a trace operation
1634+
if ( HopRewriteUtils.isAggUnaryOp(hi, AggOp.TRACE) ) {
1635+
Hop input = hi.getInput().get(0);
1636+
1637+
// Check if input is a transpose and it is only consumer
1638+
if (HopRewriteUtils.isReorg(input, ReOrgOp.TRANS) && input.getParent().size() == 1) {
1639+
HopRewriteUtils.replaceChildReference(hi, input, input.getInput(0));
1640+
LOG.debug("Applied simplifyTraceTranspose rewrite");
1641+
}
1642+
}
1643+
return hi;
1644+
}
1645+
16061646
private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos)
16071647
{
16081648
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.sysds.test.functions.rewrite;
20+
21+
import org.apache.sysds.hops.OptimizerUtils;
22+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
23+
import org.apache.sysds.test.AutomatedTestBase;
24+
import org.apache.sysds.test.TestConfiguration;
25+
import org.apache.sysds.test.TestUtils;
26+
import org.apache.sysds.utils.Statistics;
27+
import org.junit.Assert;
28+
import org.junit.Test;
29+
30+
import java.util.HashMap;
31+
32+
public class RewriteSimplifyTraceSumTest extends AutomatedTestBase {
33+
private static final String TEST_NAME = "RewriteSimplifyTraceSum";
34+
private static final String TEST_DIR = "functions/rewrite/";
35+
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTraceSumTest.class.getSimpleName() + "/";
36+
37+
private static final int rows = 500;
38+
private static final int cols = 500;
39+
private static final double eps = 1e-10;
40+
41+
@Override
42+
public void setUp() {
43+
TestUtils.clearAssertionInformation();
44+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
45+
}
46+
47+
@Test
48+
public void testSimplifyTraceSumRewrite() {
49+
runTraceRewriteTest(TEST_NAME, true);
50+
}
51+
52+
@Test
53+
public void testSimplifyTraceSumNoRewrite() {
54+
runTraceRewriteTest(TEST_NAME, false);
55+
}
56+
57+
private void runTraceRewriteTest(String testname, boolean rewrites) {
58+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
59+
try {
60+
TestConfiguration config = getTestConfiguration(testname);
61+
loadTestConfiguration(config);
62+
63+
String HOME = SCRIPT_DIR + TEST_DIR;
64+
fullDMLScriptName = HOME + testname + ".dml";
65+
fullRScriptName = HOME + testname + ".R";
66+
67+
programArgs = new String[]{"-explain", "-stats", "-args", input("A"), input("B"), output("R")};
68+
rCmd = getRCmd(inputDir(), expectedDir());
69+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
70+
double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7);
71+
double[][] B = getRandomMatrix(cols, rows, -1, 1, 0.70d, 6);
72+
writeInputMatrixWithMTD("A", A, true);
73+
writeInputMatrixWithMTD("B", B, true);
74+
// Run SystemDS and R scripts
75+
runTest(true, false, null, -1);
76+
runRScript(true);
77+
78+
// Compare DML and R outputs
79+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLScalarFromOutputDir("R");
80+
HashMap<MatrixValue.CellIndex, Double> rfile = readRScalarFromExpectedDir("R");
81+
82+
// Ensure they're equal (within tolerance)
83+
TestUtils.compareMatrices(dmlfile, rfile, eps, "DMLResult", "RResult");
84+
Assert.assertEquals(rewrites?2:1, Statistics.getCPHeavyHitterCount("uaktrace"));
85+
} finally {
86+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
87+
}
88+
}
89+
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.sysds.test.functions.rewrite;
20+
21+
import org.apache.sysds.hops.OptimizerUtils;
22+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
23+
import org.apache.sysds.test.AutomatedTestBase;
24+
import org.apache.sysds.test.TestConfiguration;
25+
import org.apache.sysds.test.TestUtils;
26+
import org.junit.Assert;
27+
import org.junit.Test;
28+
29+
import java.util.HashMap;
30+
31+
public class RewriteSimplifyTraceTransposeTest extends AutomatedTestBase {
32+
private static final String TEST_NAME = "RewriteSimplifyTraceTranspose";
33+
private static final String TEST_DIR = "functions/rewrite/";
34+
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTraceTransposeTest.class.getSimpleName() + "/";
35+
36+
private static final int rows = 100;
37+
private static final int cols = 100;
38+
private static final double eps = 1e-6;
39+
40+
@Override
41+
public void setUp() {
42+
TestUtils.clearAssertionInformation();
43+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
44+
}
45+
46+
@Test
47+
public void testRewriteEnabled() {
48+
runRewriteTest(true);
49+
}
50+
51+
@Test
52+
public void testRewriteDisabled() {
53+
runRewriteTest(false);
54+
}
55+
56+
private void runRewriteTest(boolean rewriteEnabled) {
57+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
58+
try {
59+
TestConfiguration config = getTestConfiguration(TEST_NAME);
60+
loadTestConfiguration(config);
61+
62+
String HOME = SCRIPT_DIR + TEST_DIR;
63+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
64+
fullRScriptName = HOME + TEST_NAME + ".R";
65+
programArgs = new String[]{"-stats", "-args", input("A"), output("R")};
66+
rCmd = getRCmd(inputDir(), expectedDir());
67+
68+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled;
69+
double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7);
70+
writeInputMatrixWithMTD("A", A, true);
71+
runTest(true, false, null, -1);
72+
runRScript(true);
73+
74+
// Read DML scalar output
75+
HashMap<MatrixValue.CellIndex, Double> dmlMap = readDMLScalarFromOutputDir("R");
76+
double dmlTrace = dmlMap.get(new MatrixValue.CellIndex(1, 1));
77+
78+
// Read R scalar output
79+
HashMap<MatrixValue.CellIndex, Double> rMap = readRScalarFromExpectedDir("R");
80+
double rTrace = rMap.get(new MatrixValue.CellIndex(1, 1));
81+
82+
// Compare the scalar values within the given tolerance
83+
Assert.assertEquals("Trace result mismatch", rTrace, dmlTrace, eps);
84+
Assert.assertTrue(heavyHittersContainsString("r'")!=rewriteEnabled);
85+
}
86+
finally {
87+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
88+
}
89+
}
90+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
args <- commandArgs(TRUE)

# Print numerics with full precision
options(digits=22)

library("Matrix")
library("matrixStats")

# Read both inputs from the input directory (first argument)
A <- as.matrix(readMM(paste0(args[1], "A.mtx")))
B <- as.matrix(readMM(paste0(args[1], "B.mtx")))

# Reference result: sum of the two traces
traceA <- sum(diag(A))
traceB <- sum(diag(B))
R <- traceA + traceB

# Store the scalar in the expected directory (second argument)
write(R, paste0(args[2], "R"))
37+
38+
39+
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
# Read the two input matrices
A = read($1);
B = read($2);

# Trace over the sum, the pattern targeted by the rewrite
R = trace(A+B);

# Store the scalar result
write(R, $3);
30+
31+
32+
33+
34+
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
args <- commandArgs(TRUE)

library("Matrix")
library("matrixStats")

# Read the input from the input directory (first argument)
A <- as.matrix(readMM(paste0(args[1], "A.mtx")))

# Reference result: trace of the transposed matrix
At <- t(A)
R <- sum(diag(At))

# Store the scalar in the expected directory (second argument)
write(R, paste0(args[2], "R"))
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
# Load the input matrix
A = read($1);

# Trace over the transpose, the pattern targeted by the rewrite
result = trace(t(A));

# Store the scalar result
write(result, $2);
29+
30+
31+

0 commit comments

Comments
 (0)