Skip to content

Commit 78c8d7f

Browse files
committed
[SYSTEMDS-3924] New primitive for OOC stream creation / reset
This patch improves the OOC backend by a new primitive that automatically creates or resets existing OOC streams on getStreamHandle such that OOC instruction don't need to handle this issue individually, but can still probe the existence of active streams via hasStreamHandle. As a result if there are OOC instructions that consume materialized intermediates, we automatically create new streams from these intermediates.
1 parent baacab4 commit 78c8d7f

File tree

7 files changed

+86
-5
lines changed

7 files changed

+86
-5
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
*/
5252
public class RewriteInjectOOCTee extends HopRewriteRule {
5353

54+
public static boolean APPLY_ONLY_XtX_PATTERN = false;
55+
5456
private static final Set<Long> rewrittenHops = new HashSet<>();
5557
private static final Map<Long, Hop> handledHop = new HashMap<>();
5658

@@ -140,7 +142,7 @@ private void findRewriteCandidates(Hop hop) {
140142
&& hop.getDataType().isMatrix()
141143
&& !HopRewriteUtils.isData(hop, OpOpData.TEE)
142144
&& hop.getParent().size() > 1
143-
&& isSelfTranposePattern(hop)) //FIXME remove
145+
&& (!APPLY_ONLY_XtX_PATTERN || isSelfTranposePattern(hop))) //FIXME remove
144146
{
145147
rewriteCandidates.add(hop);
146148
}

src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.HashMap;
2626
import java.util.Map;
2727
import java.util.concurrent.atomic.AtomicLong;
28+
import java.util.stream.LongStream;
2829

2930
import org.apache.commons.lang3.mutable.MutableBoolean;
3031
import org.apache.commons.logging.Log;
@@ -48,19 +49,22 @@
4849
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
4950
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
5051
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
52+
import org.apache.sysds.runtime.instructions.ooc.ResettableStream;
5153
import org.apache.sysds.runtime.instructions.spark.data.BroadcastObject;
5254
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
5355
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
5456
import org.apache.sysds.runtime.io.FileFormatProperties;
5557
import org.apache.sysds.runtime.io.IOUtilFunctions;
5658
import org.apache.sysds.runtime.io.ReaderWriterFederated;
5759
import org.apache.sysds.runtime.lineage.LineageItem;
60+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
5861
import org.apache.sysds.runtime.meta.DataCharacteristics;
5962
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
6063
import org.apache.sysds.runtime.meta.MetaData;
6164
import org.apache.sysds.runtime.meta.MetaDataFormat;
6265
import org.apache.sysds.runtime.util.HDFSTool;
6366
import org.apache.sysds.runtime.util.LocalFileUtils;
67+
import org.apache.sysds.runtime.util.UtilFunctions;
6468
import org.apache.sysds.utils.Statistics;
6569
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
6670

@@ -466,8 +470,44 @@ public boolean hasBroadcastHandle() {
466470
}
467471

468472
public LocalTaskQueue<IndexedMatrixValue> getStreamHandle() {
473+
if( !hasStreamHandle() ) {
474+
_streamHandle = new LocalTaskQueue<>();
475+
DataCharacteristics dc = getDataCharacteristics();
476+
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
477+
LongStream.range(0, dc.getNumBlocks())
478+
.mapToObj(i -> UtilFunctions.createIndexedMatrixBlock(src, dc, i))
479+
.forEach( blk -> {
480+
try{
481+
_streamHandle.enqueueTask(blk);
482+
}
483+
catch(Exception ex) {
484+
throw new DMLRuntimeException(ex);
485+
}});
486+
_streamHandle.closeInput();
487+
}
488+
else if(_streamHandle != null && _streamHandle.isProcessed()
489+
&& _streamHandle instanceof ResettableStream)
490+
{
491+
try {
492+
((ResettableStream)_streamHandle).reset();
493+
}
494+
catch(Exception ex) {
495+
throw new DMLRuntimeException(ex);
496+
}
497+
}
498+
469499
return _streamHandle;
470500
}
501+
502+
/**
503+
* Probes if stream handle is existing, because <code>getStreamHandle<code>
504+
* creates a new stream if not existing.
505+
*
506+
* @return true if existing, false otherwise
507+
*/
508+
public boolean hasStreamHandle() {
509+
return _streamHandle != null && !_streamHandle.isProcessed();
510+
}
471511

472512
@SuppressWarnings({ "rawtypes", "unchecked" })
473513
public void setBroadcastHandle( BroadcastObject bc ) {
@@ -592,7 +632,7 @@ && getRDDHandle() == null) ) {
592632
//mark for initial local write despite read operation
593633
_requiresLocalWrite = false;
594634
}
595-
else if( getStreamHandle() != null ) {
635+
else if( hasStreamHandle() ) {
596636
_data = readBlobFromStream( getStreamHandle() );
597637
}
598638
else if( getRDDHandle()==null || getRDDHandle().allowsShortCircuitRead() ) {
@@ -909,7 +949,7 @@ public synchronized void exportData (String fName, String outputFormat, int repl
909949
// a) get the matrix
910950
boolean federatedWrite = (outputFormat != null ) && outputFormat.contains("federated");
911951

912-
if(getStreamHandle()!=null) {
952+
if(hasStreamHandle()) {
913953
try {
914954
long totalNnz = writeStreamToHDFS(fName, outputFormat, replication, formatProperties);
915955
updateDataCharacteristics(new MatrixCharacteristics(

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ public synchronized void closeInput()
106106
_closedInput = true;
107107
notifyAll(); //notify all waiting readers
108108
}
109+
110+
public synchronized boolean isProcessed() {
111+
return _closedInput && _data.isEmpty();
112+
}
109113

110114
@Override
111115
public synchronized String toString()

src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public synchronized IndexedMatrixValue dequeueTask()
8383
* This can only be called once the stream is fully consumed once.
8484
*/
8585
public synchronized void reset() throws InterruptedException {
86-
if (_cacheInProgress) {
86+
while (_cacheInProgress) {
8787
// Attempted to reset a stream that's not been fully cached yet.
8888
wait();
8989
}
@@ -94,4 +94,9 @@ public synchronized void reset() throws InterruptedException {
9494
public synchronized void closeInput() {
9595
_source.closeInput();
9696
}
97+
98+
@Override
99+
public synchronized boolean isProcessed() {
100+
return false;
101+
}
97102
}

src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@
5353
import org.apache.sysds.runtime.frame.data.columns.HashIntegerArray;
5454
import org.apache.sysds.runtime.frame.data.columns.HashLongArray;
5555
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
56+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
5657
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
5758
import org.apache.sysds.runtime.matrix.data.Pair;
59+
import org.apache.sysds.runtime.meta.DataCharacteristics;
5860
import org.apache.sysds.runtime.meta.TensorCharacteristics;
5961
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
6062

@@ -1471,4 +1473,26 @@ public static String[] cleanAndTokenizeRow(String[] row) {
14711473

14721474
return joined.split("\\s+");
14731475
}
1476+
1477+
public static IndexedMatrixValue createIndexedMatrixBlock(MatrixBlock mb, DataCharacteristics mc, long ix) {
1478+
try {
1479+
//compute block indexes
1480+
long blockRow = ix / mc.getNumColBlocks();
1481+
long blockCol = ix % mc.getNumColBlocks();
1482+
//compute block sizes
1483+
int maxRow = UtilFunctions.computeBlockSize(mc.getRows(), blockRow+1, mc.getBlocksize());
1484+
int maxCol = UtilFunctions.computeBlockSize(mc.getCols(), blockCol+1, mc.getBlocksize());
1485+
//copy sub-matrix to block
1486+
MatrixBlock block = new MatrixBlock(maxRow, maxCol, mb.isInSparseFormat());
1487+
int row_offset = (int)blockRow*mc.getBlocksize();
1488+
int col_offset = (int)blockCol*mc.getBlocksize();
1489+
block = mb.slice( row_offset, row_offset+maxRow-1,
1490+
col_offset, col_offset+maxCol-1, false, block );
1491+
//create key-value pair
1492+
return new IndexedMatrixValue(new MatrixIndexes(blockRow+1, blockCol+1), block);
1493+
}
1494+
catch(DMLRuntimeException ex) {
1495+
throw new RuntimeException(ex);
1496+
}
1497+
}
14741498
}

src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.apache.sysds.test.functions.ooc;
2121

2222
import org.apache.sysds.common.Types;
23+
import org.apache.sysds.hops.rewrite.RewriteInjectOOCTee;
2324
import org.apache.sysds.runtime.io.MatrixWriter;
2425
import org.apache.sysds.runtime.io.MatrixWriterFactory;
2526
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
@@ -70,9 +71,12 @@ public void testlmDS2() {
7071
private void runMatrixVectorMultiplicationTest(int cols)
7172
{
7273
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
73-
74+
boolean oldFlag = RewriteInjectOOCTee.APPLY_ONLY_XtX_PATTERN;
75+
7476
try
7577
{
78+
RewriteInjectOOCTee.APPLY_ONLY_XtX_PATTERN = true;
79+
7680
getAndLoadTestConfiguration(TEST_NAME1);
7781
String HOME = SCRIPT_DIR + TEST_DIR;
7882
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
@@ -117,6 +121,7 @@ private void runMatrixVectorMultiplicationTest(int cols)
117121
}
118122
finally {
119123
resetExecMode(platformOld);
124+
RewriteInjectOOCTee.APPLY_ONLY_XtX_PATTERN = oldFlag;
120125
}
121126
}
122127
}

src/test/scripts/functions/ooc/lmDS.dml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ y = read($2)
2525
XtX = t(X) %*% X; # 500 x 500
2626
Xty = t(X) %*% y; # 500 x 1
2727
R = solve(XtX, Xty)
28+
print(sum(R!=0))
2829
write(R, $3, format="binary")

0 commit comments

Comments
 (0)