Skip to content

Commit 7c504d0

Browse files
j143mboehm7
authored and committed
[SYSTEMDS-3914] New out-of-core transpose-self matmult instruction
Closes #2323.
1 parent df26e67 commit 7c504d0

File tree

6 files changed

+246
-10
lines changed

6 files changed

+246
-10
lines changed

src/main/java/org/apache/sysds/hops/AggBinaryOp.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ public Lop constructLops() {
179179
et = ExecType.CP;
180180
}
181181

182-
if (et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED) {
182+
if (et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED || et == ExecType.OOC) {
183183
//matrix mult operation selection part 3 (CP type)
184184
_method = optFindMMultMethodCP(input1.getDim1(), input1.getDim2(),
185185
input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput);
@@ -240,14 +240,6 @@ public Lop constructLops() {
240240
default:
241241
throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops.");
242242
}
243-
} else if (et == ExecType.OOC) {
244-
Lop in1 = getInput().get(0).constructLops();
245-
Lop in2 = getInput().get(1).constructLops();
246-
MatMultCP matmult = new MatMultCP(in1, in2, getDataType(), getValueType(),
247-
et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
248-
setOutputDimensions(matmult);
249-
setLineNumbers(matmult);
250-
setLops(matmult);
251243
}
252244
} else
253245
throw new HopsException(this.printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops.");

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
2828
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
2929
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
30+
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
3031
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
3132
import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
3233
import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction;
@@ -61,6 +62,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
6162
case AggregateBinary:
6263
case MAPMM:
6364
return MatrixVectorBinaryOOCInstruction.parseInstruction(str);
65+
case MMTSJ:
66+
return TSMMOOCInstruction.parseInstruction(str);
6467
case Reorg:
6568
return TransposeOOCInstruction.parseInstruction(str);
6669

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public abstract class OOCInstruction extends Instruction {
3333
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
3434

3535
public enum OOCType {
36-
Reblock, AggregateUnary, Binary, Unary, MAPMM, Reorg, AggregateBinary
36+
Reblock, AggregateUnary, Binary, Unary, MAPMM, Reorg, AggregateBinary, MMTSJ
3737
}
3838

3939
protected final OOCInstruction.OOCType _ooctype;
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.lops.MMTSJ;
24+
import org.apache.sysds.lops.MMTSJ.MMTSJType;
25+
import org.apache.sysds.runtime.DMLRuntimeException;
26+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
27+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
28+
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
29+
import org.apache.sysds.runtime.functionobjects.Multiply;
30+
import org.apache.sysds.runtime.functionobjects.Plus;
31+
import org.apache.sysds.runtime.instructions.InstructionUtils;
32+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
33+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
34+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
35+
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
36+
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
37+
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
38+
import org.apache.sysds.runtime.matrix.operators.Operator;
39+
40+
public class TSMMOOCInstruction extends ComputationOOCInstruction {
41+
private final MMTSJType _type;
42+
43+
protected TSMMOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out, MMTSJ.MMTSJType mmtsjType, String opcode, String istr) {
44+
super(type, op, in1, out, opcode, istr);
45+
_type = mmtsjType;
46+
}
47+
48+
public static TSMMOOCInstruction parseInstruction(String str) {
49+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
50+
InstructionUtils.checkNumFields(parts, 3);
51+
String opcode = parts[0];
52+
CPOperand in1 = new CPOperand(parts[1]); // the large matrix (streamed), columns <= blocksize
53+
CPOperand out = new CPOperand(parts[2]);
54+
MMTSJ.MMTSJType mmtsjType = MMTSJ.MMTSJType.valueOf(parts[3]);
55+
56+
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
57+
AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
58+
59+
return new TSMMOOCInstruction(OOCType.MMTSJ, ba, in1, out, mmtsjType, opcode, str);
60+
}
61+
62+
@Override
63+
public void processInstruction( ExecutionContext ec ) {
64+
MatrixObject min = ec.getMatrixObject(input1);
65+
int nRows = (int) min.getDataCharacteristics().getRows();
66+
int nCols = (int) min.getDataCharacteristics().getCols();
67+
int bLen = min.getDataCharacteristics().getBlocksize();
68+
69+
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
70+
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
71+
72+
//validation check TODO extend compiler to not create OOC otherwise
73+
if( (_type.isLeft() && nCols > bLen)
74+
|| (_type.isRight() && nRows > bLen) )
75+
{
76+
throw new UnsupportedOperationException();
77+
}
78+
79+
int dim = _type.isLeft() ? nCols : nRows;
80+
MatrixBlock resultBlock = new MatrixBlock(dim, dim, false);
81+
try {
82+
IndexedMatrixValue tmp = null;
83+
// aggregate partial tsmm outputs into result as inputs stream in
84+
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
85+
MatrixBlock partialResult = ((MatrixBlock) tmp.getValue())
86+
.transposeSelfMatrixMultOperations(new MatrixBlock(), _type);
87+
resultBlock.binaryOperationsInPlace(plus, partialResult);
88+
}
89+
}
90+
catch(Exception ex) {
91+
throw new DMLRuntimeException(ex);
92+
}
93+
94+
ec.setMatrixOutput(output.getName(), resultBlock);
95+
}
96+
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.lops.MMTSJ;
25+
import org.apache.sysds.runtime.instructions.Instruction;
26+
import org.apache.sysds.runtime.io.MatrixWriter;
27+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
28+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
29+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
30+
import org.apache.sysds.runtime.util.DataConverter;
31+
import org.apache.sysds.runtime.util.HDFSTool;
32+
import org.apache.sysds.test.AutomatedTestBase;
33+
import org.apache.sysds.test.TestConfiguration;
34+
import org.apache.sysds.test.TestUtils;
35+
import org.junit.Assert;
36+
import org.junit.Test;
37+
38+
import java.io.IOException;
39+
40+
public class TransposeSelfMMTest extends AutomatedTestBase {
41+
private final static String TEST_NAME1 = "TSMM";
42+
private final static String TEST_DIR = "functions/ooc/";
43+
private final static String TEST_CLASS_DIR = TEST_DIR + TransposeSelfMMTest.class.getSimpleName() + "/";
44+
private final static double eps = 1e-8;
45+
private static final String INPUT_NAME = "X";
46+
private static final String OUTPUT_NAME = "res";
47+
48+
private final static int rows = 2143;
49+
private final static int cols = 123;
50+
private final static double sparsity1 = 0.7;
51+
private final static double sparsity2 = 0.1;
52+
private final int k = 1;
53+
54+
@Override
55+
public void setUp() {
56+
TestUtils.clearAssertionInformation();
57+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
58+
addTestConfiguration(TEST_NAME1, config);
59+
}
60+
61+
@Test
62+
public void testTsmmDense() {
63+
runTSMMTest(cols, false);
64+
}
65+
66+
@Test
67+
public void testTsmmSparse() {
68+
runTSMMTest(cols, false);
69+
}
70+
71+
private void runTSMMTest(int cols, boolean sparse )
72+
{
73+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
74+
75+
try
76+
{
77+
getAndLoadTestConfiguration(TEST_NAME1);
78+
String HOME = SCRIPT_DIR + TEST_DIR;
79+
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
80+
programArgs = new String[]{"-explain", "-stats", "-ooc",
81+
"-args", input(INPUT_NAME), output(OUTPUT_NAME)};
82+
83+
// 1. Generate the data in-memory as MatrixBlock objects
84+
double[][] A_data = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 10);
85+
86+
// 2. Convert the double arrays to MatrixBlock objects
87+
MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data);
88+
89+
// 3. Create a binary matrix writer
90+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
91+
92+
// 4. Write matrix A to a binary SequenceFile
93+
writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, 1000, A_mb.getNonZeros());
94+
HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
95+
new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY);
96+
97+
runTest(true, false, null, -1);
98+
99+
//check tsmm OOC
100+
Assert.assertTrue("OOC wasn't used for TSMM",
101+
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.TSMM));
102+
103+
//compare results
104+
MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
105+
Types.FileFormat.BINARY, cols, cols, 1000, cols*cols);
106+
MatrixBlock ret2 = new MatrixBlock(rows, rows, false);
107+
A_mb.transposeSelfMatrixMultOperations(ret2, MMTSJ.MMTSJType.LEFT, k);
108+
TestUtils.compareMatrices(ret1, ret2, eps);
109+
}
110+
catch (IOException e) {
111+
throw new RuntimeException(e);
112+
}
113+
finally {
114+
resetExecMode(platformOld);
115+
}
116+
}
117+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Read input matrix from command line args
23+
X = read($1);  # $1: path of the input matrix
24+
25+
# Operation under test: transpose-self matrix multiplication t(X) %*% X
26+
res = t(X) %*% X;
27+
28+
# Write the result in binary format to the path given as $2
write(res, $2, format="binary")

0 commit comments

Comments
 (0)