Skip to content

Commit baacab4

Browse files
committed
[SYSTEMDS-3908] Fix OOC matmult compilation w/ transpose rewrite
In CP, we rewrite t(X)%*%y to t(t(y)%*%X) if the two transposes are much smaller and especially if they are vectors because vector transpose is a meta data operation. However, if y is an OOC stream, this rewrite destroyed the pipeline (and incomplete exception handling and other primitives) made the resulting issue hard to debug.
1 parent 01ec1d7 commit baacab4

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

src/main/java/org/apache/sysds/hops/AggBinaryOp.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,8 @@ private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et) {
624624

625625
//Handle Y or actualY for transpose
626626
Lop yLop = isYTransposed ? actualY.constructLops() : Y.constructLops();
627-
ExecType inputReorgExecType = (Y.hasFederatedOutput()) ? ExecType.FED : ExecType.CP;
627+
ExecType inputReorgExecType = (Y.hasFederatedOutput()) ? ExecType.FED :
628+
(et==ExecType.OOC) ? ExecType.OOC : ExecType.CP;
628629

629630
//right vector transpose
630631
Lop tY = (yLop instanceof Transform && ((Transform)yLop).getOp() == ReOrgOp.TRANS) ?

src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ private void findRewriteCandidates(Hop hop) {
139139
if (DMLScript.USE_OOC
140140
&& hop.getDataType().isMatrix()
141141
&& !HopRewriteUtils.isData(hop, OpOpData.TEE)
142-
&& hop.getParent().size() > 1)
142+
&& hop.getParent().size() > 1
143+
&& isSelfTranposePattern(hop)) //FIXME remove
143144
{
144145
rewriteCandidates.add(hop);
145146
}

src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import org.apache.sysds.test.TestConfiguration;
3434
import org.apache.sysds.test.TestUtils;
3535
import org.junit.Assert;
36-
import org.junit.Ignore;
3736
import org.junit.Test;
3837

3938
import java.io.IOException;
@@ -47,7 +46,7 @@ public class lmDSTest extends AutomatedTestBase {
4746
private static final String INPUT_NAME2 = "y";
4847
private static final String OUTPUT_NAME = "R";
4948

50-
private final static int rows = 100000;
49+
private final static int rows = 10000;
5150
private final static int cols_wide = 500; //TODO larger than 1000
5251
private final static int cols_skinny = 10;
5352

@@ -59,13 +58,11 @@ public void setUp() {
5958
}
6059

6160
@Test
62-
@Ignore //FIXME
6361
public void testlmDS1() {
6462
runMatrixVectorMultiplicationTest(cols_wide);
6563
}
6664

6765
@Test
68-
@Ignore //FIXME
6966
public void testlmDS2() {
7067
runMatrixVectorMultiplicationTest(cols_skinny);
7168
}
@@ -80,7 +77,7 @@ private void runMatrixVectorMultiplicationTest(int cols)
8077
String HOME = SCRIPT_DIR + TEST_DIR;
8178
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
8279
programArgs = new String[]{"-explain", "-stats", "-ooc",
83-
"-args", input(INPUT_NAME), input(INPUT_NAME2), output(OUTPUT_NAME)};
80+
"-args", input(INPUT_NAME), input(INPUT_NAME2), output(OUTPUT_NAME)};
8481

8582
// 1. Generate the data in-memory as MatrixBlock objects
8683
double[][] X_data = getRandomMatrix(rows, cols, 0, 1, 1.0, 7);
@@ -105,7 +102,7 @@ private void runMatrixVectorMultiplicationTest(int cols)
105102

106103
runTest(true, false, null, -1);
107104
MatrixBlock C = DataConverter.readMatrixFromHDFS(
108-
output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, cols, 1000, 1000);
105+
output(OUTPUT_NAME), Types.FileFormat.BINARY, cols, 1, 1000, 1000);
109106

110107
//expected results
111108
MatrixBlock xtx = LibMatrixMult.matrixMultTransposeSelf(X_mb, new MatrixBlock(cols,cols,false), true);

0 commit comments

Comments
 (0)