Skip to content

Commit 6514887

Browse files
janniklindemboehm7
authored andcommitted
[SYSTEMDS-3918] New out-of-core queues and primitives
Closes #2347.
1 parent e40bbfe commit 6514887

27 files changed

+1178
-318
lines changed

src/main/java/org/apache/sysds/hops/BinaryOp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ op, getDataType(), getValueType(), et,
478478
setLineNumbers(softmax);
479479
setLops(softmax);
480480
}
481-
else if ( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED )
481+
else if ( et == ExecType.CP || et == ExecType.GPU || et == ExecType.FED || et == ExecType.OOC )
482482
{
483483
Lop binary = null;
484484

src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@
4949
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
5050
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
5151
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
52-
import org.apache.sysds.runtime.instructions.ooc.ResettableStream;
52+
import org.apache.sysds.runtime.instructions.ooc.OOCStream;
53+
import org.apache.sysds.runtime.instructions.ooc.OOCStreamable;
54+
import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
5355
import org.apache.sysds.runtime.instructions.spark.data.BroadcastObject;
5456
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
5557
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
@@ -223,7 +225,7 @@ public enum CacheStatus {
223225
private BroadcastObject<T> _bcHandle = null; //Broadcast handle
224226
protected HashMap<GPUContext, GPUObject> _gpuObjects = null; //Per GPUContext object allocated on GPU
225227
//TODO generalize for frames
226-
private LocalTaskQueue<IndexedMatrixValue> _streamHandle = null;
228+
private OOCStreamable<IndexedMatrixValue> _streamHandle = null;
227229

228230
private LineageItem _lineage = null;
229231

@@ -469,34 +471,25 @@ public boolean hasBroadcastHandle() {
469471
return _bcHandle != null && _bcHandle.hasBackReference();
470472
}
471473

472-
public LocalTaskQueue<IndexedMatrixValue> getStreamHandle() {
474+
public OOCStream<IndexedMatrixValue> getStreamHandle() {
473475
if( !hasStreamHandle() ) {
474-
_streamHandle = new LocalTaskQueue<>();
476+
final SubscribableTaskQueue<IndexedMatrixValue> _mStream = new SubscribableTaskQueue<>();
477+
_streamHandle = _mStream;
475478
DataCharacteristics dc = getDataCharacteristics();
476479
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
477480
LongStream.range(0, dc.getNumBlocks())
478481
.mapToObj(i -> UtilFunctions.createIndexedMatrixBlock(src, dc, i))
479482
.forEach( blk -> {
480483
try{
481-
_streamHandle.enqueueTask(blk);
484+
_mStream.enqueue(blk);
482485
}
483486
catch(Exception ex) {
484-
throw new DMLRuntimeException(ex);
487+
throw ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
485488
}});
486-
_streamHandle.closeInput();
487-
}
488-
else if(_streamHandle != null && _streamHandle.isProcessed()
489-
&& _streamHandle instanceof ResettableStream)
490-
{
491-
try {
492-
((ResettableStream)_streamHandle).reset();
493-
}
494-
catch(Exception ex) {
495-
throw new DMLRuntimeException(ex);
496-
}
489+
_mStream.closeInput();
497490
}
498491

499-
return _streamHandle;
492+
return _streamHandle.getReadStream();
500493
}
501494

502495
/**
@@ -539,7 +532,7 @@ public synchronized void removeGPUObject(GPUContext gCtx) {
539532
_gpuObjects.remove(gCtx);
540533
}
541534

542-
public synchronized void setStreamHandle(LocalTaskQueue<IndexedMatrixValue> q) {
535+
public synchronized void setStreamHandle(OOCStreamable<IndexedMatrixValue> q) {
543536
_streamHandle = q;
544537
}
545538

@@ -633,7 +626,7 @@ && getRDDHandle() == null) ) {
633626
_requiresLocalWrite = false;
634627
}
635628
else if( hasStreamHandle() ) {
636-
_data = readBlobFromStream( getStreamHandle() );
629+
_data = readBlobFromStream( getStreamHandle().toLocalTaskQueue() );
637630
}
638631
else if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() ) {
639632
if( DMLScript.STATISTICS )

src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ protected long writeStreamToHDFS(String fname, String ofmt, int rep, FileFormatP
611611
MetaDataFormat iimd = (MetaDataFormat) _metaData;
612612
FileFormat fmt = (ofmt != null ? FileFormat.safeValueOf(ofmt) : iimd.getFileFormat());
613613
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(fmt, rep, fprop);
614-
return writer.writeMatrixFromStream(fname, getStreamHandle(),
614+
return writer.writeMatrixFromStream(fname, getStreamHandle().toLocalTaskQueue(),
615615
getNumRows(), getNumColumns(), ConfigurationManager.getBlocksize());
616616
}
617617

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ public class LocalTaskQueue<T>
4343
public static final int MAX_SIZE = 100000; //main memory constraint
4444
public static final Object NO_MORE_TASKS = null; //object to signal NO_MORE_TASKS
4545

46-
private LinkedList<T> _data = null;
47-
private boolean _closedInput = false;
46+
protected LinkedList<T> _data = null;
47+
protected boolean _closedInput = false;
4848
private DMLRuntimeException _failure = null;
4949
private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName());
5050

@@ -60,21 +60,19 @@ public LocalTaskQueue()
6060
* @param t task
6161
* @throws InterruptedException if InterruptedException occurs
6262
*/
63-
public synchronized void enqueueTask( T t )
63+
public synchronized void enqueueTask( T t )
6464
throws InterruptedException
6565
{
66-
while( _data.size() + 1 > MAX_SIZE && _failure == null )
67-
{
66+
while(_data.size() + 1 > MAX_SIZE && _failure == null) {
6867
LOG.warn("MAX_SIZE of task queue reached.");
6968
wait(); //max constraint reached, wait for read
7069
}
7170

72-
if ( _failure != null )
71+
if(_failure != null)
7372
throw _failure;
74-
75-
_data.addLast( t );
76-
77-
notify(); //notify waiting readers
73+
74+
_data.addLast(t);
75+
notify();
7876
}
7977

8078
/**
@@ -97,22 +95,22 @@ public synchronized T dequeueTask()
9795

9896
if ( _failure != null )
9997
throw _failure;
100-
98+
10199
T t = _data.removeFirst();
102100

103101
notify(); // notify waiting writers
104102

105103
return t;
106104
}
107-
105+
108106
/**
109107
* Synchronized (logical) insert of a NO_MORE_TASKS symbol at the end of the FIFO queue in order to
110108
* mark that no more tasks will be inserted into the queue.
111109
*/
112110
public synchronized void closeInput()
113111
{
114112
_closedInput = true;
115-
notifyAll(); //notify all waiting readers
113+
notifyAll();
116114
}
117115

118116
public synchronized boolean isProcessed() {

src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void processInstruction( ExecutionContext ec ) {
7676
//setup operators and input queue
7777
AggregateUnaryOperator aggun = (AggregateUnaryOperator) getOperator();
7878
MatrixObject min = ec.getMatrixObject(input1);
79-
LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle();
79+
OOCStream<IndexedMatrixValue> q = min.getStreamHandle();
8080
int blen = ConfigurationManager.getBlocksize();
8181

8282
if (aggun.isRowAggregate() || aggun.isColAggregate()) {
@@ -86,13 +86,13 @@ public void processInstruction( ExecutionContext ec ) {
8686
OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emitThreshold);
8787
HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // correction blocks
8888

89-
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
89+
OOCStream<IndexedMatrixValue> qOut = createWritableStream();
9090
ec.getMatrixObject(output).setStreamHandle(qOut);
9191

9292
submitOOCTask(() -> {
9393
IndexedMatrixValue tmp = null;
9494
try {
95-
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
95+
while((tmp = q.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
9696
long idx = aggun.isRowAggregate() ?
9797
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
9898
MatrixBlock ret = aggTracker.get(idx);
@@ -139,7 +139,7 @@ public void processInstruction( ExecutionContext ec ) {
139139
new MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
140140
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);
141141

142-
qOut.enqueueTask(tmpOut);
142+
qOut.enqueue(tmpOut);
143143
// drop intermediate states
144144
aggTracker.remove(idx);
145145
corrs.remove(idx);
@@ -159,7 +159,7 @@ public void processInstruction( ExecutionContext ec ) {
159159
MatrixBlock ret = new MatrixBlock(1,1+extra,false);
160160
MatrixBlock corr = new MatrixBlock(1,1+extra,false);
161161
try {
162-
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
162+
while((tmp = q.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
163163
//block aggregation
164164
MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue())
165165
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());

src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,14 @@
1919

2020
package org.apache.sysds.runtime.instructions.ooc;
2121

22-
import org.apache.sysds.common.Types.DataType;
23-
import org.apache.sysds.runtime.DMLRuntimeException;
2422
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
2523
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
26-
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
2724
import org.apache.sysds.runtime.instructions.InstructionUtils;
2825
import org.apache.sysds.runtime.instructions.cp.CPOperand;
2926
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
3027
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
3128
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
29+
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
3230
import org.apache.sysds.runtime.matrix.operators.Operator;
3331
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
3432

@@ -54,33 +52,46 @@ public static BinaryOOCInstruction parseInstruction(String str) {
5452

5553
@Override
5654
public void processInstruction( ExecutionContext ec ) {
57-
//TODO support all types, currently only binary matrix-scalar
58-
55+
if (input1.isMatrix() && input2.isMatrix())
56+
processMatrixMatrixInstruction(ec);
57+
else
58+
processScalarMatrixInstruction(ec);
59+
}
60+
61+
protected void processMatrixMatrixInstruction(ExecutionContext ec) {
62+
MatrixObject m1 = ec.getMatrixObject(input1);
63+
MatrixObject m2 = ec.getMatrixObject(input2);
64+
65+
OOCStream<IndexedMatrixValue> qIn1 = m1.getStreamHandle();
66+
OOCStream<IndexedMatrixValue> qIn2 = m2.getStreamHandle();
67+
OOCStream<IndexedMatrixValue> qOut = new SubscribableTaskQueue<>();
68+
ec.getMatrixObject(output).setStreamHandle(qOut);
69+
70+
joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> {
71+
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
72+
tmpOut.set(tmp1.getIndexes(),
73+
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue()));
74+
return tmpOut;
75+
}, IndexedMatrixValue::getIndexes);
76+
}
77+
78+
protected void processScalarMatrixInstruction(ExecutionContext ec) {
5979
//get operator and scalar
60-
CPOperand scalar = ( input1.getDataType() == DataType.MATRIX ) ? input2 : input1;
80+
CPOperand scalar = input1.isMatrix() ? input2 : input1;
6181
ScalarObject constant = ec.getScalarInput(scalar);
6282
ScalarOperator sc_op = ((ScalarOperator)_optr).setConstant(constant.getDoubleValue());
63-
83+
6484
//create thread and process binary operation
65-
MatrixObject min = ec.getMatrixObject(input1);
66-
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
67-
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
85+
MatrixObject min = ec.getMatrixObject(input1.isMatrix() ? input1 : input2);
86+
OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
87+
OOCStream<IndexedMatrixValue> qOut = createWritableStream();
6888
ec.getMatrixObject(output).setStreamHandle(qOut);
69-
70-
submitOOCTask(() -> {
71-
IndexedMatrixValue tmp = null;
72-
try {
73-
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
74-
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
75-
tmpOut.set(tmp.getIndexes(),
76-
tmp.getValue().scalarOperations(sc_op, new MatrixBlock()));
77-
qOut.enqueueTask(tmpOut);
78-
}
79-
qOut.closeInput();
80-
}
81-
catch(Exception ex) {
82-
throw new DMLRuntimeException(ex);
83-
}
84-
}, qIn, qOut);
89+
90+
mapOOC(qIn, qOut, tmp -> {
91+
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
92+
tmpOut.set(tmp.getIndexes(),
93+
tmp.getValue().scalarOperations(sc_op, new MatrixBlock()));
94+
return tmpOut;
95+
});
8596
}
8697
}

0 commit comments

Comments
 (0)