Skip to content

Commit deeb51f

Browse files
committed
test for lmds
1 parent 4fb9ee2 commit deeb51f

File tree

5 files changed

+260
-2
lines changed

5 files changed

+260
-2
lines changed

src/main/java/org/apache/sysds/hops/Hop.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ else if ( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE && _etypeForced
265265
if(_etypeForced != ExecType.CP && _etypeForced != ExecType.GPU)
266266
_etypeForced = ExecType.CP;
267267
}
268-
else if (DMLScript.USE_OOC){
268+
else if (DMLScript.USE_OOC && !(this instanceof BinaryOp)){
269269
_etypeForced = ExecType.OOC;
270270
}
271271
else {

src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.sysds.common.Types.OpOpData;
2525
import org.apache.sysds.hops.DataOp;
2626
import org.apache.sysds.hops.Hop;
27+
import org.apache.sysds.hops.ReorgOp;
2728

2829
import java.util.ArrayList;
2930
import java.util.HashMap;
@@ -138,7 +139,9 @@ private void findRewriteCandidates(Hop hop) {
138139
if (DMLScript.USE_OOC
139140
&& hop.getDataType().isMatrix()
140141
&& !HopRewriteUtils.isData(hop, OpOpData.TEE)
141-
&& hop.getParent().size() > 1)
142+
&& hop.getParent().size() > 1
143+
&& isSelfTranposePattern(hop)
144+
)
142145
{
143146
rewriteCandidates.add(hop);
144147
}
@@ -174,4 +177,22 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
174177
handledHop.put(sharedInput.getHopID(), teeOp);
175178
rewrittenHops.add(sharedInput.getHopID());
176179
}
180+
181+
private boolean isSelfTranposePattern (Hop hop) {
182+
boolean hasTransposeConsumer = false; // t(X)
183+
boolean hasMatrixMultiplyConsumer = false; // %*%
184+
185+
for (Hop parent: hop.getParent()) {
186+
String opString = parent.getOpString();
187+
if (parent instanceof ReorgOp) {
188+
if (HopRewriteUtils.isTransposeOperation(parent)) {
189+
hasTransposeConsumer = true;
190+
}
191+
}
192+
else if (HopRewriteUtils.isMatrixMultiply(parent)) {
193+
hasMatrixMultiplyConsumer = true;
194+
}
195+
}
196+
return hasTransposeConsumer && hasMatrixMultiplyConsumer;
197+
}
177198
}
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Types;
23+
import org.apache.sysds.runtime.io.MatrixWriter;
24+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
25+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
26+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
27+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
28+
import org.apache.sysds.runtime.util.DataConverter;
29+
import org.apache.sysds.runtime.util.HDFSTool;
30+
import org.apache.sysds.test.AutomatedTestBase;
31+
import org.apache.sysds.test.TestConfiguration;
32+
import org.apache.sysds.test.TestUtils;
33+
import org.junit.Assert;
34+
import org.junit.Test;
35+
36+
import java.io.IOException;
37+
import java.util.HashMap;
38+
import java.util.Random;
39+
40+
public class lmDSTest extends AutomatedTestBase {
41+
private final static String TEST_NAME1 = "lmDS";
42+
private final static String TEST_DIR = "functions/ooc/";
43+
private final static String TEST_CLASS_DIR = TEST_DIR + lmDSTest.class.getSimpleName() + "/";
44+
private final static double eps = 1e-10;
45+
private static final String INPUT_NAME = "X";
46+
private static final String INPUT_NAME2 = "y";
47+
private static final String OUTPUT_NAME = "R";
48+
49+
private final static int rows = 100000;
50+
private final static int cols_wide = 500;
51+
private final static int cols_skinny = 500;
52+
53+
private final static double sparsity1 = 0.7;
54+
private final static double sparsity2 = 0.1;
55+
56+
@Override
57+
public void setUp() {
58+
TestUtils.clearAssertionInformation();
59+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
60+
addTestConfiguration(TEST_NAME1, config);
61+
}
62+
63+
@Test
64+
public void testlmDS1() {
65+
runMatrixVectorMultiplicationTest(cols_wide, false);
66+
}
67+
68+
@Test
69+
public void testlmDS2() {
70+
runMatrixVectorMultiplicationTest(cols_skinny, false);
71+
}
72+
73+
private void runMatrixVectorMultiplicationTest(int cols, boolean sparse )
74+
{
75+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
76+
77+
try
78+
{
79+
getAndLoadTestConfiguration(TEST_NAME1);
80+
String HOME = SCRIPT_DIR + TEST_DIR;
81+
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
82+
programArgs = new String[]{"-explain", "-stats", "-ooc",
83+
"-args", input(INPUT_NAME), input(INPUT_NAME2), output(OUTPUT_NAME)};
84+
85+
// 1. Generate the data in-memory as MatrixBlock objects
86+
double[][] A_data = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 7);
87+
// double[][] A_data = generateFullRankMatrix(rows, cols, 10L);
88+
double[][] x_data = getRandomMatrix(rows, 1, 0, 1, 1.0, 3);
89+
// double[][] x_data = getRandomMatrix(rows, 1, 0, 1, 1.0, 20L);
90+
91+
// 2. Convert the double arrays to MatrixBlock objects
92+
MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data);
93+
MatrixBlock x_mb = DataConverter.convertToMatrixBlock(x_data);
94+
95+
// 3. Create a binary matrix writer
96+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
97+
98+
// 4. Write matrix A to a binary SequenceFile
99+
writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, 1000, A_mb.getNonZeros());
100+
HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
101+
new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY);
102+
103+
// 5. Write vector x to a binary SequenceFile
104+
writer.writeMatrixToHDFS(x_mb, input(INPUT_NAME2), rows, 1, 1000, x_mb.getNonZeros());
105+
HDFSTool.writeMetaDataFile(input(INPUT_NAME2 + ".mtd"), Types.ValueType.FP64,
106+
new MatrixCharacteristics(rows, 1, 1000, x_mb.getNonZeros()), Types.FileFormat.BINARY);
107+
108+
fullRScriptName = HOME + TEST_NAME1 + ".R";
109+
rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
110+
111+
boolean exceptionExpected = false;
112+
runTest(true, exceptionExpected, null, -1);
113+
// runRScript(true);
114+
115+
// HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir(OUTPUT_NAME);
116+
117+
double[][] C1 = readMatrix(output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, cols, 1000, 1000);
118+
double result = 0.0;
119+
for(int i = 0; i < 100; i++) { // verify the results with Java
120+
double expected = 0.0;
121+
for(int j = 0; j < 100; j++) {
122+
expected += A_mb.get(i, j) * x_mb.get(j,0);
123+
}
124+
result = C1[i][0];
125+
System.out.println("(i): " + i + " ->> expected" + expected + ", result: " + result);
126+
// Assert.assertEquals(expected, result, eps);
127+
}
128+
}
129+
catch (IOException e) {
130+
throw new RuntimeException(e);
131+
}
132+
finally {
133+
resetExecMode(platformOld);
134+
}
135+
}
136+
137+
private static double[][] readMatrix(String fname, Types.FileFormat fmt, long rows, long cols, int brows, int bcols )
138+
throws IOException
139+
{
140+
MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols);
141+
double[][] C = DataConverter.convertToDoubleMatrix(mb);
142+
return C;
143+
}
144+
145+
/**
146+
* Generates a matrix that is guaranteed to have full column rank,
147+
* preventing a singular t(X)%*%X matrix.
148+
*
149+
* @param rows Number of rows
150+
* @param cols Number of columns (must be <= rows)
151+
* @param seed Random seed
152+
* @return A new double[][] matrix
153+
*/
154+
private double[][] generateFullRankMatrix(int rows, int cols, long seed) {
155+
if (cols > rows) {
156+
throw new IllegalArgumentException("For a full-rank matrix, cols must be <= rows.");
157+
}
158+
double[][] A = new double[rows][cols];
159+
Random rand = new Random(seed);
160+
161+
// 1. Create a dominant diagonal by starting with an identity-like structure
162+
for (int i = 0; i < cols; i++) {
163+
A[i][i] = 1.0;
164+
}
165+
166+
// 2. Add small random noise to all other elements to ensure non-singularity
167+
for (int i = 0; i < rows; i++) {
168+
for (int j = 0; j < cols; j++) {
169+
if (i != j) { // Don't overwrite the dominant diagonal
170+
A[i][j] = rand.nextDouble() * 0.1; // Small noise
171+
}
172+
}
173+
}
174+
return A;
175+
}
176+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
args<-commandArgs(TRUE)
23+
options(digits=22)
24+
library("Matrix")
25+
26+
X = as.matrix(readMM(paste(args[1], "X.mtd", sep="")))
27+
y = as.matrix(readMM(paste(args[1], "y.mtd", sep="")))
28+
# C = lm.fit(X, y)$coefficients
29+
XtX <- t(X) %*% X
30+
Xty <- t(X) %*% y
31+
R <- solve(XtX, Xty)
32+
33+
writeMM(as(R, "CsparseMatrix"), paste(args[2], "C", sep=""))
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
X = read($1)
23+
y = read($2)
24+
25+
XtX = t(X) %*% X; # 500 x 10000 -- 10000 x 500 == 500 x 500
26+
Xty = t(X) %*% y; # 500 x 10000 -- 10000 x 1 == 500 x 1
27+
R = solve(XtX, Xty)
28+
write(R, $3, format="binary")

0 commit comments

Comments
 (0)