Skip to content

Commit 0c7e046

Browse files
j143 authored and mboehm7 committed
[SYSTEMDS-3904] New OOC matrix-vector multiplication
Closes #2305.
1 parent a61919e commit 0c7e046

File tree

6 files changed

+316
-1
lines changed

6 files changed

+316
-1
lines changed

src/main/java/org/apache/sysds/hops/AggBinaryOp.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,14 @@ public Lop constructLops() {
240240
default:
241241
throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops.");
242242
}
243+
} else if (et == ExecType.OOC) {
244+
Lop in1 = getInput().get(0).constructLops();
245+
Lop in2 = getInput().get(1).constructLops();
246+
MatMultCP matmult = new MatMultCP(in1, in2, getDataType(), getValueType(),
247+
et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
248+
setOutputDimensions(matmult);
249+
setLineNumbers(matmult);
250+
setLops(matmult);
243251
}
244252
} else
245253
throw new HopsException(this.printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops.");

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
2929
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
3030
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
31+
import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
3132

3233
public class OOCInstructionParser extends InstructionParser {
3334
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -56,6 +57,9 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
5657
return UnaryOOCInstruction.parseInstruction(str);
5758
case Binary:
5859
return BinaryOOCInstruction.parseInstruction(str);
60+
case AggregateBinary:
61+
case MAPMM:
62+
return MatrixVectorBinaryOOCInstruction.parseInstruction(str);
5963

6064
default:
6165
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.ooc;
21+
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
import java.util.concurrent.ExecutorService;
25+
26+
import org.apache.sysds.common.Opcodes;
27+
import org.apache.sysds.conf.ConfigurationManager;
28+
import org.apache.sysds.runtime.DMLRuntimeException;
29+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
30+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
31+
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
32+
import org.apache.sysds.runtime.functionobjects.Multiply;
33+
import org.apache.sysds.runtime.functionobjects.Plus;
34+
import org.apache.sysds.runtime.instructions.InstructionUtils;
35+
import org.apache.sysds.runtime.instructions.cp.CPOperand;
36+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
37+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
38+
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
39+
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
40+
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
41+
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
42+
import org.apache.sysds.runtime.matrix.operators.Operator;
43+
import org.apache.sysds.runtime.util.CommonThreadPool;
44+
45+
public class MatrixVectorBinaryOOCInstruction extends ComputationOOCInstruction {
46+
47+
48+
protected MatrixVectorBinaryOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
49+
super(type, op, in1, in2, out, opcode, istr);
50+
}
51+
52+
public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) {
53+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
54+
InstructionUtils.checkNumFields(parts, 4);
55+
String opcode = parts[0];
56+
CPOperand in1 = new CPOperand(parts[1]); // the larget matrix (streamed)
57+
CPOperand in2 = new CPOperand(parts[2]); // the small vector (in-memory)
58+
CPOperand out = new CPOperand(parts[3]);
59+
60+
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
61+
AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
62+
63+
return new MatrixVectorBinaryOOCInstruction(OOCType.MAPMM, ba, in1, in2, out, opcode, str);
64+
}
65+
66+
@Override
67+
public void processInstruction( ExecutionContext ec ) {
68+
// 1. Identify the inputs
69+
MatrixObject min = ec.getMatrixObject(input1); // big matrix
70+
MatrixBlock vin = ec.getMatrixObject(input2)
71+
.acquireReadAndRelease(); // in-memory vector
72+
73+
// 2. Pre-partition the in-memory vector into a hashmap
74+
HashMap<Long, MatrixBlock> partitionedVector = new HashMap<>();
75+
int blksize = vin.getDataCharacteristics().getBlocksize();
76+
if (blksize < 0)
77+
blksize = ConfigurationManager.getBlocksize();
78+
for (int i=0; i<vin.getNumRows(); i+=blksize) {
79+
long key = (long) (i/blksize) + 1; // the key starts at 1
80+
int end_row = Math.min(i + blksize, vin.getNumRows());
81+
MatrixBlock vectorSlice = vin.slice(i, end_row - 1);
82+
partitionedVector.put(key, vectorSlice);
83+
}
84+
85+
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
86+
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
87+
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
88+
ec.getMatrixObject(output).setStreamHandle(qOut);
89+
90+
ExecutorService pool = CommonThreadPool.get();
91+
try {
92+
// Core logic: background thread
93+
pool.submit(() -> {
94+
IndexedMatrixValue tmp = null;
95+
try {
96+
HashMap<Long, MatrixBlock> partialResults = new HashMap<>();
97+
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
98+
MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue();
99+
long rowIndex = tmp.getIndexes().getRowIndex();
100+
long colIndex = tmp.getIndexes().getColumnIndex();
101+
MatrixBlock vectorSlice = partitionedVector.get(colIndex);
102+
103+
// Now, call the operation with the correct, specific operator.
104+
MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations(
105+
matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr);
106+
107+
// for single column block, no aggregation neeeded
108+
if( min.getNumColumns() <= min.getBlocksize() ) {
109+
qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult));
110+
}
111+
else {
112+
MatrixBlock currAgg = partialResults.get(rowIndex);
113+
if (currAgg == null)
114+
partialResults.put(rowIndex, partialResult);
115+
else
116+
currAgg.binaryOperationsInPlace(plus, partialResult);
117+
}
118+
}
119+
120+
// emit aggregated blocks
121+
if( min.getNumColumns() > min.getBlocksize() ) {
122+
for (Map.Entry<Long, MatrixBlock> entry : partialResults.entrySet()) {
123+
MatrixIndexes outIndexes = new MatrixIndexes(entry.getKey(), 1L);
124+
qOut.enqueueTask(new IndexedMatrixValue(outIndexes, entry.getValue()));
125+
}
126+
}
127+
}
128+
catch(Exception ex) {
129+
throw new DMLRuntimeException(ex);
130+
}
131+
finally {
132+
qOut.closeInput();
133+
}
134+
});
135+
} catch (Exception e) {
136+
throw new DMLRuntimeException(e);
137+
}
138+
finally {
139+
pool.shutdown();
140+
}
141+
}
142+
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction {
3030
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
3131

3232
public enum OOCType {
33-
Reblock, AggregateUnary, Binary, Unary
33+
Reblock, AggregateUnary, Binary, Unary, MAPMM, AggregateBinary
3434
}
3535

3636
protected final OOCInstruction.OOCType _ooctype;
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.ooc;
21+
22+
import org.apache.sysds.common.Types;
23+
import org.apache.sysds.runtime.io.MatrixWriter;
24+
import org.apache.sysds.runtime.io.MatrixWriterFactory;
25+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
26+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
27+
import org.apache.sysds.runtime.util.DataConverter;
28+
import org.apache.sysds.runtime.util.HDFSTool;
29+
import org.apache.sysds.test.AutomatedTestBase;
30+
import org.apache.sysds.test.TestConfiguration;
31+
import org.apache.sysds.test.TestUtils;
32+
import org.junit.Assert;
33+
import org.junit.Test;
34+
35+
import java.io.IOException;
36+
37+
public class MatrixVectorBinaryMultiplicationTest extends AutomatedTestBase {
38+
private final static String TEST_NAME1 = "MatrixVectorMultiplication";
39+
private final static String TEST_DIR = "functions/ooc/";
40+
private final static String TEST_CLASS_DIR = TEST_DIR + MatrixVectorBinaryMultiplicationTest.class.getSimpleName() + "/";
41+
private final static double eps = 1e-10;
42+
private static final String INPUT_NAME = "X";
43+
private static final String INPUT_NAME2 = "v";
44+
private static final String OUTPUT_NAME = "res";
45+
46+
private final static int rows = 5000;
47+
private final static int cols_wide = 2000;
48+
private final static int cols_skinny = 500;
49+
50+
private final static double sparsity1 = 0.7;
51+
private final static double sparsity2 = 0.1;
52+
53+
@Override
54+
public void setUp() {
55+
TestUtils.clearAssertionInformation();
56+
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
57+
addTestConfiguration(TEST_NAME1, config);
58+
}
59+
60+
@Test
61+
public void testMVBinaryMultiplication1() {
62+
runMatrixVectorMultiplicationTest(cols_wide, false);
63+
}
64+
65+
@Test
66+
public void testMVBinaryMultiplication2() {
67+
runMatrixVectorMultiplicationTest(cols_skinny, false);
68+
}
69+
70+
private void runMatrixVectorMultiplicationTest(int cols, boolean sparse )
71+
{
72+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
73+
74+
try
75+
{
76+
getAndLoadTestConfiguration(TEST_NAME1);
77+
String HOME = SCRIPT_DIR + TEST_DIR;
78+
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
79+
programArgs = new String[]{"-explain", "-stats", "-ooc",
80+
"-args", input(INPUT_NAME), input(INPUT_NAME2), output(OUTPUT_NAME)};
81+
82+
// 1. Generate the data in-memory as MatrixBlock objects
83+
double[][] A_data = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 10);
84+
double[][] x_data = getRandomMatrix(cols, 1, 0, 1, 1.0, 10);
85+
86+
// 2. Convert the double arrays to MatrixBlock objects
87+
MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data);
88+
MatrixBlock x_mb = DataConverter.convertToMatrixBlock(x_data);
89+
90+
// 3. Create a binary matrix writer
91+
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
92+
93+
// 4. Write matrix A to a binary SequenceFile
94+
writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, 1000, A_mb.getNonZeros());
95+
HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
96+
new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY);
97+
98+
// 5. Write vector x to a binary SequenceFile
99+
writer.writeMatrixToHDFS(x_mb, input(INPUT_NAME2), cols, 1, 1000, x_mb.getNonZeros());
100+
HDFSTool.writeMetaDataFile(input(INPUT_NAME2 + ".mtd"), Types.ValueType.FP64,
101+
new MatrixCharacteristics(cols, 1, 1000, x_mb.getNonZeros()), Types.FileFormat.BINARY);
102+
103+
boolean exceptionExpected = false;
104+
runTest(true, exceptionExpected, null, -1);
105+
106+
double[][] C1 = readMatrix(output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, cols, 1000, 1000);
107+
double result = 0.0;
108+
for(int i = 0; i < rows; i++) { // verify the results with Java
109+
double expected = 0.0;
110+
for(int j = 0; j < cols; j++) {
111+
expected += A_mb.get(i, j) * x_mb.get(j,0);
112+
}
113+
result = C1[i][0];
114+
Assert.assertEquals(expected, result, eps);
115+
}
116+
}
117+
catch (IOException e) {
118+
throw new RuntimeException(e);
119+
}
120+
finally {
121+
resetExecMode(platformOld);
122+
}
123+
}
124+
125+
private static double[][] readMatrix(String fname, Types.FileFormat fmt, long rows, long cols, int brows, int bcols )
126+
throws IOException
127+
{
128+
MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols);
129+
double[][] C = DataConverter.convertToDoubleMatrix(mb);
130+
return C;
131+
}
132+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# Read input matrix and vector from command line args
23+
X = read($1);
24+
v = read($2);
25+
26+
# Operation under test
27+
res = X %*% v;
28+
29+
write(res, $3, format="binary")

0 commit comments

Comments
 (0)