Skip to content

Commit 01ec1d7

Browse files
j143mboehm7
authored andcommitted
[SYSTEMDS-3907] lmDS Algorithm Test for OOC Backend
Closes #2338.
1 parent c8d5460 commit 01ec1d7

File tree

5 files changed

+176
-3
lines changed

5 files changed

+176
-3
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.sysds.common.Types.OpOpData;
2525
import org.apache.sysds.hops.DataOp;
2626
import org.apache.sysds.hops.Hop;
27+
import org.apache.sysds.hops.ReorgOp;
2728

2829
import java.util.ArrayList;
2930
import java.util.HashMap;
@@ -138,7 +139,7 @@ private void findRewriteCandidates(Hop hop) {
138139
if (DMLScript.USE_OOC
139140
&& hop.getDataType().isMatrix()
140141
&& !HopRewriteUtils.isData(hop, OpOpData.TEE)
141-
&& hop.getParent().size() > 1)
142+
&& hop.getParent().size() > 1)
142143
{
143144
rewriteCandidates.add(hop);
144145
}
@@ -174,4 +175,22 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
174175
handledHop.put(sharedInput.getHopID(), teeOp);
175176
rewrittenHops.add(sharedInput.getHopID());
176177
}
178+
179+
@SuppressWarnings("unused")
180+
private boolean isSelfTranposePattern (Hop hop) {
181+
boolean hasTransposeConsumer = false; // t(X)
182+
boolean hasMatrixMultiplyConsumer = false; // %*%
183+
184+
for (Hop parent: hop.getParent()) {
185+
if (parent instanceof ReorgOp) {
186+
if (HopRewriteUtils.isTransposeOperation(parent)) {
187+
hasTransposeConsumer = true;
188+
}
189+
}
190+
else if (HopRewriteUtils.isMatrixMultiply(parent)) {
191+
hasMatrixMultiplyConsumer = true;
192+
}
193+
}
194+
return hasTransposeConsumer && hasMatrixMultiplyConsumer;
195+
}
177196
}

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ public static MatrixBlock matrixMatrixOperations(MatrixBlock in1, MatrixBlock in
201201
* @param in2 matrix object 2
202202
* @return matrix block
203203
*/
204-
private static MatrixBlock computeSolve(MatrixBlock in1, MatrixBlock in2) {
204+
public static MatrixBlock computeSolve(MatrixBlock in1, MatrixBlock in2) {
205205
//convert to commons math BlockRealMatrix instead of Array2DRowRealMatrix
206206
//to avoid unnecessary conversion as QR internally creates a BlockRealMatrix
207207
BlockRealMatrix matrixInput = DataConverter.convertToBlockRealMatrix(in1);

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,9 @@ public static void matrixMultChain(MatrixBlock mX, MatrixBlock mV, MatrixBlock m
452452
// "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
453453
}
454454

455-
public static void matrixMultTransposeSelf( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose ) {
455+
public static MatrixBlock matrixMultTransposeSelf( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose ) {
456456
matrixMultTransposeSelf(m1, ret, leftTranspose, true);
457+
return ret;
457458
}
458459

459460
public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, boolean copyToLowerTriangle){
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Types;
23+
import org.apache.sysds.runtime.io.MatrixWriter;
24+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
25+
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
26+
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
27+
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
28+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
29+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
30+
import org.apache.sysds.runtime.util.DataConverter;
31+
import org.apache.sysds.runtime.util.HDFSTool;
32+
import org.apache.sysds.test.AutomatedTestBase;
33+
import org.apache.sysds.test.TestConfiguration;
34+
import org.apache.sysds.test.TestUtils;
35+
import org.junit.Assert;
36+
import org.junit.Ignore;
37+
import org.junit.Test;
38+
39+
import java.io.IOException;
40+
41+
public class lmDSTest extends AutomatedTestBase {
42+
private final static String TEST_NAME1 = "lmDS";
43+
private final static String TEST_DIR = "functions/ooc/";
44+
private final static String TEST_CLASS_DIR = TEST_DIR + lmDSTest.class.getSimpleName() + "/";
45+
private final static double eps = 1e-10;
46+
private static final String INPUT_NAME = "X";
47+
private static final String INPUT_NAME2 = "y";
48+
private static final String OUTPUT_NAME = "R";
49+
50+
private final static int rows = 100000;
51+
private final static int cols_wide = 500; //TODO larger than 1000
52+
private final static int cols_skinny = 10;
53+
54+
@Override
55+
public void setUp() {
56+
TestUtils.clearAssertionInformation();
57+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
58+
addTestConfiguration(TEST_NAME1, config);
59+
}
60+
61+
@Test
62+
@Ignore //FIXME
63+
public void testlmDS1() {
64+
runMatrixVectorMultiplicationTest(cols_wide);
65+
}
66+
67+
@Test
68+
@Ignore //FIXME
69+
public void testlmDS2() {
70+
runMatrixVectorMultiplicationTest(cols_skinny);
71+
}
72+
73+
private void runMatrixVectorMultiplicationTest(int cols)
74+
{
75+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
76+
77+
try
78+
{
79+
getAndLoadTestConfiguration(TEST_NAME1);
80+
String HOME = SCRIPT_DIR + TEST_DIR;
81+
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
82+
programArgs = new String[]{"-explain", "-stats", "-ooc",
83+
"-args", input(INPUT_NAME), input(INPUT_NAME2), output(OUTPUT_NAME)};
84+
85+
// 1. Generate the data in-memory as MatrixBlock objects
86+
double[][] X_data = getRandomMatrix(rows, cols, 0, 1, 1.0, 7);
87+
double[][] y_data = getRandomMatrix(rows, 1, 0, 1, 1.0, 3);
88+
89+
// 2. Convert the double arrays to MatrixBlock objects
90+
MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data);
91+
MatrixBlock y_mb = DataConverter.convertToMatrixBlock(y_data);
92+
93+
// 3. Create a binary matrix writer
94+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
95+
96+
// 4. Write matrix A to a binary SequenceFile
97+
writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME), rows, cols, 1000, X_mb.getNonZeros());
98+
HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
99+
new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY);
100+
101+
// 5. Write vector x to a binary SequenceFile
102+
writer.writeMatrixToHDFS(y_mb, input(INPUT_NAME2), rows, 1, 1000, y_mb.getNonZeros());
103+
HDFSTool.writeMetaDataFile(input(INPUT_NAME2 + ".mtd"), Types.ValueType.FP64,
104+
new MatrixCharacteristics(rows, 1, 1000, y_mb.getNonZeros()), Types.FileFormat.BINARY);
105+
106+
runTest(true, false, null, -1);
107+
MatrixBlock C = DataConverter.readMatrixFromHDFS(
108+
output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, cols, 1000, 1000);
109+
110+
//expected results
111+
MatrixBlock xtx = LibMatrixMult.matrixMultTransposeSelf(X_mb, new MatrixBlock(cols,cols,false), true);
112+
MatrixBlock xt = LibMatrixReorg.transpose(X_mb);
113+
MatrixBlock xty = LibMatrixMult.matrixMult(xt, y_mb);
114+
MatrixBlock ret = LibCommonsMath.computeSolve(xtx, xty);
115+
for(int i = 0; i < cols; i++)
116+
Assert.assertEquals(ret.get(i, 0), C.get(i,0), eps);
117+
}
118+
catch (IOException e) {
119+
throw new RuntimeException(e);
120+
}
121+
finally {
122+
resetExecMode(platformOld);
123+
}
124+
}
125+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
X = read($1)
23+
y = read($2)
24+
25+
XtX = t(X) %*% X; # 500 x 500
26+
Xty = t(X) %*% y; # 500 x 1
27+
R = solve(XtX, Xty)
28+
write(R, $3, format="binary")

0 commit comments

Comments
 (0)