Skip to content

Commit fcf5c6f

Browse files
author
Parth
committed
[SYSTEMDS-3933] Generalize OOC matrix-vector binary to support streamed vector input
1 parent 8de93a1 commit fcf5c6f

File tree

1 file changed

+29
-24
lines changed

1 file changed

+29
-24
lines changed

src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.HashMap;
2323

2424
import org.apache.sysds.common.Opcodes;
25-
import org.apache.sysds.conf.ConfigurationManager;
2625
import org.apache.sysds.runtime.DMLRuntimeException;
2726
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
2827
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -51,7 +50,7 @@ public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) {
5150
InstructionUtils.checkNumFields(parts, 4);
5251
String opcode = parts[0];
5352
CPOperand in1 = new CPOperand(parts[1]); // the larget matrix (streamed)
54-
CPOperand in2 = new CPOperand(parts[2]); // the small vector (in-memory)
53+
CPOperand in2 = new CPOperand(parts[2]); // vector operand (may be OOC)
5554
CPOperand out = new CPOperand(parts[3]);
5655

5756
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
@@ -62,48 +61,54 @@ public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) {
6261

6362
@Override
6463
public void processInstruction( ExecutionContext ec ) {
65-
// 1. Identify the inputs
66-
MatrixObject min = ec.getMatrixObject(input1); // big matrix
67-
MatrixBlock vin = ec.getMatrixObject(input2)
68-
.acquireReadAndRelease(); // in-memory vector
69-
70-
// 2. Pre-partition the in-memory vector into a hashmap
71-
HashMap<Long, MatrixBlock> partitionedVector = new HashMap<>();
72-
int blksize = vin.getDataCharacteristics().getBlocksize();
73-
if (blksize < 0)
74-
blksize = ConfigurationManager.getBlocksize();
75-
for (int i=0; i<vin.getNumRows(); i+=blksize) {
76-
long key = (long) (i/blksize) + 1; // the key starts at 1
77-
int end_row = Math.min(i + blksize, vin.getNumRows());
78-
MatrixBlock vectorSlice = vin.slice(i, end_row - 1);
79-
partitionedVector.put(key, vectorSlice);
80-
}
64+
// Fetch both inputs without assuming which one fits in memory
65+
MatrixObject min = ec.getMatrixObject(input1);
66+
MatrixObject vin = ec.getMatrixObject(input2);
8167

8268
// number of colBlocks for early block output
8369
long emitThreshold = min.getDataCharacteristics().getNumColBlocks();
8470
OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emitThreshold);
8571

86-
OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
72+
OOCStream<IndexedMatrixValue> qIn1 = min.getStreamHandle();
73+
OOCStream<IndexedMatrixValue> qIn2 = vin.getStreamHandle(); // Stream handles for matrix and vector (both may be OOC)
8774
OOCStream<IndexedMatrixValue> qOut = createWritableStream();
8875
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
8976
ec.getMatrixObject(output).setStreamHandle(qOut);
9077

9178
submitOOCTask(() -> {
92-
IndexedMatrixValue tmp = null;
79+
9380
try {
94-
while((tmp = qIn.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
81+
// Cache vector blocks indexed by their block row id
82+
// This removes the assumption that the vector is fully in-memory
83+
HashMap<Long, MatrixBlock> vectorCache = new HashMap<>();
84+
85+
// Consume the entire vector stream and cache it block-wise
86+
IndexedMatrixValue vecVal;
87+
while ((vecVal = qIn2.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
88+
vectorCache.put(
89+
vecVal.getIndexes().getRowIndex(),
90+
(MatrixBlock) vecVal.getValue());
91+
}
92+
93+
// Stream through matrix blocks and match them with vector blocks
94+
IndexedMatrixValue tmp = null;
95+
while((tmp = qIn1.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
9596
MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue();
9697
long rowIndex = tmp.getIndexes().getRowIndex();
9798
long colIndex = tmp.getIndexes().getColumnIndex();
98-
MatrixBlock vectorSlice = partitionedVector.get(colIndex);
99+
MatrixBlock vectorSlice = vectorCache.get(colIndex);
100+
101+
// Fail fast if the corresponding vector block is missing
102+
if (vectorSlice == null)
103+
throw new DMLRuntimeException("Missing vector block for column block " + colIndex);
99104

100105
// Now, call the operation with the correct, specific operator.
101106
MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations(
102107
matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr);
103108

104109
// for single column block, no aggregation neeeded
105110
if(emitThreshold == 1) {
106-
qOut.enqueue(new IndexedMatrixValue(tmp.getIndexes(), partialResult));
111+
qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(rowIndex, 1), partialResult));
107112
}
108113
else {
109114
// aggregation
@@ -129,6 +134,6 @@ public void processInstruction( ExecutionContext ec ) {
129134
finally {
130135
qOut.closeInput();
131136
}
132-
}, qIn, qOut);
137+
}, qIn1, qIn2, qOut);
133138
}
134139
}

0 commit comments

Comments
 (0)