Skip to content

Commit 4b2d83e

Browse files
aarnatymboehm7
authored andcommitted
[SYSTEMDS-3861] Fix redundant transposes due to multi-level rewrites
Closes #2249.
1 parent 693ef52 commit 4b2d83e

File tree

9 files changed

+519
-253
lines changed

9 files changed

+519
-253
lines changed

src/main/java/org/apache/sysds/hops/AggBinaryOp.java

Lines changed: 234 additions & 253 deletions
Large diffs are not rendered by default.

src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
120
package org.apache.sysds.hops.fedplanner;
221

322
import org.apache.commons.lang3.tuple.Pair;
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.sysds.test.functions.rewrite;
20+
21+
import org.apache.sysds.hops.OptimizerUtils;
22+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
23+
import org.apache.sysds.test.AutomatedTestBase;
24+
import org.apache.sysds.test.TestConfiguration;
25+
import org.apache.sysds.test.TestUtils;
26+
import org.apache.sysds.utils.Statistics;
27+
import org.junit.Assert;
28+
import org.junit.Test;
29+
import java.util.HashMap;
30+
31+
public class RewriteTransposeTest extends AutomatedTestBase {
32+
private final static String TEST_NAME1 = "RewriteTransposeCase1"; // t(X)%*%Y
33+
private final static String TEST_NAME2 = "RewriteTransposeCase2"; // X=t(A); t(X)%*%Y
34+
private final static String TEST_NAME3 = "RewriteTransposeCase3"; // Y=t(A); t(X)%*%Y
35+
36+
private final static String TEST_DIR = "functions/rewrite/";
37+
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteTransposeTest.class.getSimpleName() + "/";
38+
39+
private static final double eps = 1e-9;
40+
41+
@Override
42+
public void setUp() {
43+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION=false;
44+
45+
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}));
46+
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"}));
47+
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{"R"}));
48+
}
49+
50+
@Test
51+
public void testTransposeRewrite1CP() {
52+
runTransposeRewriteTest(TEST_NAME1, false);
53+
}
54+
55+
@Test
56+
public void testTransposeRewrite2CP() {
57+
runTransposeRewriteTest(TEST_NAME2, true);
58+
}
59+
60+
@Test
61+
public void testTransposeRewrite3CP() {
62+
runTransposeRewriteTest(TEST_NAME3, false);
63+
}
64+
65+
private void runTransposeRewriteTest(String testname, boolean expectedMerge) {
66+
TestConfiguration config = getTestConfiguration(testname);
67+
loadTestConfiguration(config);
68+
69+
String HOME = SCRIPT_DIR + TEST_DIR;
70+
fullDMLScriptName = HOME + testname + ".dml";
71+
72+
programArgs = new String[]{"-explain", "-stats", "-args", output("R")};
73+
74+
fullRScriptName = HOME + testname + ".R";
75+
rCmd = getRCmd(expectedDir());
76+
77+
runTest(true, false, null, -1);
78+
runRScript(true);
79+
80+
HashMap<MatrixValue.CellIndex, Double> dmlOutput = readDMLMatrixFromOutputDir("R");
81+
HashMap<MatrixValue.CellIndex, Double> rOutput = readRMatrixFromExpectedDir("R");
82+
TestUtils.compareMatrices(dmlOutput, rOutput, eps, "Stat-DML", "Stat-R");
83+
84+
Assert.assertTrue(Statistics.getCPHeavyHitterCount("r'") <= 2);
85+
}
86+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
args <- commandArgs(TRUE)
23+
24+
library("Matrix")
25+
library("matrixStats")
26+
27+
X <- matrix(seq(1, 20), nrow=4, ncol=5, byrow=TRUE)
28+
Y <- matrix(seq(1, 12), nrow=4, ncol=3, byrow=TRUE)
29+
30+
R <- t(t(Y)%*%X)
31+
32+
writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
X = matrix(seq(1, 20), rows=4, cols=5);
23+
Y = matrix(seq(1, 12), rows=4, cols=3);
24+
25+
R = t(X)%*%Y;
26+
27+
write(R, $1);
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
args <- commandArgs(TRUE)
23+
24+
library("Matrix")
25+
library("matrixStats")
26+
A = matrix(seq(1, 20), nrow=5, ncol=4, byrow=TRUE)
27+
Y = matrix(seq(1, 12), nrow=4, ncol=3, byrow=TRUE)
28+
X = t(A)
29+
30+
R <- t(t(Y)%*%X)
31+
32+
writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
A = matrix(seq(1, 20), rows=5, cols=4);
23+
Y = matrix(seq(1, 12), rows=4, cols=3);
24+
X = t(A);
25+
26+
R = t(X) %*% Y;
27+
28+
write(R, $1);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
args <- commandArgs(TRUE)
23+
24+
library("Matrix")
25+
library("matrixStats")
26+
27+
X <- matrix(seq(1, 20), nrow=4, ncol=5, byrow=TRUE)
28+
A <- matrix(seq(1, 12), nrow=3, ncol=4, byrow=TRUE)
29+
Y <- t(A)
30+
31+
R <- t(t(Y)%*%X)
32+
33+
writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
X = matrix(seq(1, 20), rows=4, cols=5);
23+
A = matrix(seq(1, 12), rows=3, cols=4);
24+
Y = t(A);
25+
26+
R = t(X) %*% Y;
27+
28+
write(R, $1);

0 commit comments

Comments
 (0)