Skip to content

Commit 2cb1a60

Browse files
jessicapriebemboehm7
authored andcommitted
[SYSTEMDS-3908] Improved OOC matrix multiplication w/ early outputs
Closes #2310.
1 parent 9d2985d commit 2cb1a60

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
package org.apache.sysds.runtime.instructions.ooc;
2121

2222
import java.util.HashMap;
23-
import java.util.Map;
2423
import java.util.concurrent.ExecutorService;
2524

2625
import org.apache.sysds.common.Opcodes;
@@ -82,6 +81,9 @@ public void processInstruction( ExecutionContext ec ) {
8281
partitionedVector.put(key, vectorSlice);
8382
}
8483

84+
// number of colBlocks for early block output
85+
long nBlocks = min.getDataCharacteristics().getNumColBlocks();
86+
8587
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
8688
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
8789
BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
@@ -94,6 +96,7 @@ public void processInstruction( ExecutionContext ec ) {
9496
IndexedMatrixValue tmp = null;
9597
try {
9698
HashMap<Long, MatrixBlock> partialResults = new HashMap<>();
99+
HashMap<Long, Integer> cnt = new HashMap<>();
97100
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
98101
MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue();
99102
long rowIndex = tmp.getIndexes().getRowIndex();
@@ -109,19 +112,29 @@ public void processInstruction( ExecutionContext ec ) {
109112
qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult));
110113
}
111114
else {
115+
// aggregation
112116
MatrixBlock currAgg = partialResults.get(rowIndex);
113-
if (currAgg == null)
117+
if (currAgg == null) {
114118
partialResults.put(rowIndex, partialResult);
115-
else
119+
cnt.put(rowIndex, 1);
120+
}
121+
else {
116122
currAgg.binaryOperationsInPlace(plus, partialResult);
117-
}
118-
}
119-
120-
// emit aggregated blocks
121-
if( min.getNumColumns() > min.getBlocksize() ) {
122-
for (Map.Entry<Long, MatrixBlock> entry : partialResults.entrySet()) {
123-
MatrixIndexes outIndexes = new MatrixIndexes(entry.getKey(), 1L);
124-
qOut.enqueueTask(new IndexedMatrixValue(outIndexes, entry.getValue()));
123+
int newCnt = cnt.get(rowIndex) + 1;
124+
125+
if(newCnt == nBlocks){
126+
// early block output: emit aggregated block
127+
MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L);
128+
MatrixBlock result = partialResults.get(rowIndex);
129+
qOut.enqueueTask(new IndexedMatrixValue(idx, result));
130+
partialResults.remove(rowIndex);
131+
cnt.remove(rowIndex);
132+
}
133+
else {
134+
// maintain aggregation counts if not output-ready yet
135+
cnt.replace(rowIndex, newCnt);
136+
}
137+
}
125138
}
126139
}
127140
}

0 commit comments

Comments
 (0)