Skip to content

Commit 79122eb

Browse files
janniklindemboehm7
authored andcommitted
[SYSTEMDS-3891] New out-of-core instructions and improvements
Closes #2362.
1 parent b090e49 commit 79122eb

File tree

18 files changed

+826
-107
lines changed

18 files changed

+826
-107
lines changed

src/main/java/org/apache/sysds/hops/DataOp.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ public Lop constructLops()
297297

298298
case TEE:
299299
l = new Tee(getInput(0).constructLops(), getDataType(), getValueType());
300+
setOutputDimensions(l);
300301
break;
301302

302303
default:
@@ -488,7 +489,7 @@ else if ( getInput().get(0).areDimsBelowThreshold() )
488489

489490
@Override
490491
public void refreshSizeInformation() {
491-
if( _op == OpOpData.PERSISTENTWRITE || _op == OpOpData.TRANSIENTWRITE ) {
492+
if( _op == OpOpData.PERSISTENTWRITE || _op == OpOpData.TRANSIENTWRITE || _op == OpOpData.TEE ) {
492493
Hop input1 = getInput().get(0);
493494
setDim1(input1.getDim1());
494495
setDim2(input1.getDim2());

src/main/java/org/apache/sysds/lops/DataGen.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ private String getRandInstructionCPSpark(String output)
199199
sb.append(iLop == null ? "" : iLop.prepScalarLabel());
200200
sb.append(OPERAND_DELIMITOR);
201201

202-
if( getExecType() == ExecType.CP ) {
202+
if( getExecType() == ExecType.CP || getExecType() == ExecType.OOC ) {
203203
//append degree of parallelism
204204
sb.append( _numThreads );
205205
sb.append( OPERAND_DELIMITOR );

src/main/java/org/apache/sysds/lops/Transform.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ private String getInstructions(String input1, int numInputs, String output) {
179179
sb.append( OPERAND_DELIMITOR );
180180
sb.append( this.prepOutputOperand(output));
181181

182-
if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED)
182+
if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED || getExecType()==ExecType.OOC)
183183
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) {
184184
sb.append( OPERAND_DELIMITOR );
185185
sb.append( _numThreads );

src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import java.io.IOException;
2323
import java.lang.ref.SoftReference;
24+
import java.util.ArrayList;
2425
import java.util.List;
2526
import java.util.concurrent.Future;
2627

@@ -528,7 +529,12 @@ protected MatrixBlock readBlobFromRDD(RDDObject rdd, MutableBoolean writeStatus)
528529

529530
@Override
530531
protected MatrixBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) throws IOException {
531-
MatrixBlock ret = new MatrixBlock((int)getNumRows(), (int)getNumColumns(), false);
532+
boolean dimsUnknown = getNumRows() < 0 || getNumColumns() < 0;
533+
int nrows = (int)getNumRows();
534+
int ncols = (int)getNumColumns();
535+
MatrixBlock ret = dimsUnknown ? null : new MatrixBlock((int)getNumRows(), (int)getNumColumns(), false);
536+
// TODO if stream is CachingStream, block parts might be evicted resulting in null pointer exceptions
537+
List<IndexedMatrixValue> blockCache = dimsUnknown ? new ArrayList<>() : null;
532538
IndexedMatrixValue tmp = null;
533539
try {
534540
int blen = getBlocksize(), lnnz = 0;
@@ -537,12 +543,31 @@ protected MatrixBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stre
537543
final int row_offset = (int) (tmp.getIndexes().getRowIndex() - 1) * blen;
538544
final int col_offset = (int) (tmp.getIndexes().getColumnIndex() - 1) * blen;
539545

540-
// Add the values of this block into the output block.
541-
((MatrixBlock)tmp.getValue()).putInto(ret, row_offset, col_offset, true);
546+
if (dimsUnknown) {
547+
nrows = Math.max(nrows, row_offset + tmp.getValue().getNumRows());
548+
ncols = Math.max(ncols, col_offset + tmp.getValue().getNumColumns());
549+
blockCache.add(tmp);
550+
} else {
551+
// Add the values of this block into the output block.
552+
((MatrixBlock) tmp.getValue()).putInto(ret, row_offset, col_offset, true);
553+
}
542554

543555
// incremental maintenance nnz
544556
lnnz += tmp.getValue().getNonZeros();
545557
}
558+
559+
if (dimsUnknown) {
560+
ret = new MatrixBlock(nrows, ncols, false);
561+
562+
for (IndexedMatrixValue _tmp : blockCache) {
563+
// compute row/column block offsets
564+
final int row_offset = (int) (_tmp.getIndexes().getRowIndex() - 1) * blen;
565+
final int col_offset = (int) (_tmp.getIndexes().getColumnIndex() - 1) * blen;
566+
567+
((MatrixBlock) _tmp.getValue()).putInto(ret, row_offset, col_offset, true);
568+
}
569+
}
570+
546571
ret.setNonZeros(lnnz);
547572
}
548573
catch(Exception ex) {

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
3737
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
3838
import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
39-
import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction;
39+
import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
4040
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
4141

4242
public class OOCInstructionParser extends InstructionParser {
@@ -74,7 +74,7 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
7474
case MMTSJ:
7575
return TSMMOOCInstruction.parseInstruction(str);
7676
case Reorg:
77-
return TransposeOOCInstruction.parseInstruction(str);
77+
return ReorgOOCInstruction.parseInstruction(str);
7878
case Tee:
7979
return TeeOOCInstruction.parseInstruction(str);
8080
case CentralMoment:

src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
package org.apache.sysds.runtime.instructions.ooc;
2121

22+
import org.apache.commons.lang3.NotImplementedException;
23+
import org.apache.sysds.runtime.DMLRuntimeException;
2224
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
2325
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
2426
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -67,12 +69,55 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) {
6769
OOCStream<IndexedMatrixValue> qOut = new SubscribableTaskQueue<>();
6870
ec.getMatrixObject(output).setStreamHandle(qOut);
6971

70-
joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> {
71-
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
72-
tmpOut.set(tmp1.getIndexes(),
73-
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue()));
74-
return tmpOut;
75-
}, IndexedMatrixValue::getIndexes);
72+
if (m1.getNumRows() < 0 || m1.getNumColumns() < 0 || m2.getNumRows() < 0 || m2.getNumColumns() < 0)
73+
throw new DMLRuntimeException("Cannot process (matrix, matrix) BinaryOOCInstruction with unknown dimensions.");
74+
75+
boolean isColBroadcast = m1.getNumColumns() > 1 && m2.getNumColumns() == 1;
76+
boolean isRowBroadcast = m1.getNumRows() > 1 && m2.getNumRows() == 1;
77+
78+
if (isColBroadcast && !isRowBroadcast) {
79+
final long maxProcessesPerBroadcast = m1.getNumColumns() / m1.getBlocksize();
80+
81+
broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
82+
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
83+
tmpOut.set(tmp1.getIndexes(),
84+
tmp1.getValue().binaryOperations((BinaryOperator)_optr, b.getValue().getValue(), tmpOut.getValue()));
85+
86+
if (b.incrProcessCtrAndGet() >= maxProcessesPerBroadcast)
87+
b.release();
88+
89+
return tmpOut;
90+
}, tmp -> tmp.getIndexes().getRowIndex());
91+
}
92+
else if (isRowBroadcast && !isColBroadcast) {
93+
final long maxProcessesPerBroadcast = m1.getNumRows() / m1.getBlocksize();
94+
95+
broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
96+
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
97+
tmpOut.set(tmp1.getIndexes(),
98+
tmp1.getValue().binaryOperations((BinaryOperator)_optr, b.getValue().getValue(), tmpOut.getValue()));
99+
100+
if (b.incrProcessCtrAndGet() >= maxProcessesPerBroadcast)
101+
b.release();
102+
103+
return tmpOut;
104+
}, tmp -> tmp.getIndexes().getColumnIndex());
105+
}
106+
else {
107+
if (m1.getNumColumns() != m2.getNumColumns() || m1.getNumRows() != m2.getNumRows())
108+
throw new NotImplementedException("Invalid dimensions for matrix-matrix binary op: "
109+
+ m1.getNumRows() + "x" + m1.getNumColumns() + " <=> "
110+
+ m2.getNumRows() + "x" + m2.getNumColumns());
111+
112+
joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> {
113+
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
114+
tmpOut.set(tmp1.getIndexes(),
115+
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue()));
116+
return tmpOut;
117+
}, IndexedMatrixValue::getIndexes);
118+
}
119+
120+
76121
}
77122

78123
protected void processScalarMatrixInstruction(ExecutionContext ec) {

src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ public class CachingStream implements OOCStreamable<IndexedMatrixValue> {
5252
private boolean _cacheInProgress = true; // caching in progress, in the first pass.
5353
private Map<MatrixIndexes, Integer> _index;
5454

55+
private DMLRuntimeException _failure;
56+
5557
public CachingStream(OOCStream<IndexedMatrixValue> source) {
5658
this(source, _streamSeq.getNextID());
5759
}
@@ -76,6 +78,22 @@ public CachingStream(OOCStream<IndexedMatrixValue> source, long streamId) {
7678
}
7779
} catch (InterruptedException e) {
7880
throw new DMLRuntimeException(e);
81+
} catch (DMLRuntimeException e) {
82+
// Propagate failure to subscribers
83+
_failure = e;
84+
synchronized (this) {
85+
notifyAll();
86+
}
87+
88+
Runnable[] mSubscribers = _subscribers;
89+
if(mSubscribers != null) {
90+
for(Runnable mSubscriber : mSubscribers) {
91+
try {
92+
mSubscriber.run();
93+
} catch (Exception ignored) {
94+
}
95+
}
96+
}
7997
}
8098
});
8199
}
@@ -103,7 +121,9 @@ private synchronized boolean fetchFromStream() throws InterruptedException {
103121

104122
public synchronized IndexedMatrixValue get(int idx) throws InterruptedException {
105123
while (true) {
106-
if (idx < _numBlocks) {
124+
if (_failure != null)
125+
throw _failure;
126+
else if (idx < _numBlocks) {
107127
IndexedMatrixValue out = OOCEvictionManager.get(_streamId, idx);
108128

109129
if (_index != null) // Ensure index is up to date

0 commit comments

Comments
 (0)