From e12989b3b861841358f2d69b70f5889116577196 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Thu, 18 Dec 2025 16:28:21 +0100 Subject: [PATCH] Unification of Cache and Source Reads to prevent Eviction of Disk Backed Tiles --- .../instructions/ooc/CachingStream.java | 22 +- .../ooc/ReblockOOCInstruction.java | 56 +--- .../runtime/ooc/cache/OOCCacheManager.java | 26 ++ .../runtime/ooc/cache/OOCCacheScheduler.java | 22 ++ .../sysds/runtime/ooc/cache/OOCIOHandler.java | 82 ++++++ .../ooc/cache/OOCLRUCacheScheduler.java | 28 +- .../runtime/ooc/cache/OOCMatrixIOHandler.java | 276 +++++++++++++++++- .../runtime/ooc/stream/OOCSourceStream.java | 52 ++++ .../cache/SourceBackedCacheSchedulerTest.java | 106 +++++++ .../SourceBackedReadOOCIOHandlerTest.java | 100 +++++++ .../ooc/SourceReadOOCIOHandlerTest.java | 143 +++++++++ 11 files changed, 855 insertions(+), 58 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java create mode 100644 src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java create mode 100644 src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/SourceReadOOCIOHandlerTest.java diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index 9cda04e0c77..f9869b20f9a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -25,7 +25,9 @@ import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.ooc.cache.BlockKey; +import org.apache.sysds.runtime.ooc.cache.OOCIOHandler; import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; +import 
org.apache.sysds.runtime.ooc.stream.OOCSourceStream; import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.HashMap; @@ -89,10 +91,22 @@ public CachingStream(OOCStream source, long streamId) { if(task != LocalTaskQueue.NO_MORE_TASKS) { if (!_cacheInProgress) throw new DMLRuntimeException("Stream is closed"); - if (mSubscribers == null || mSubscribers.length == 0) - OOCCacheManager.put(_streamId, _numBlocks, task); - else - mCallback = OOCCacheManager.putAndPin(_streamId, _numBlocks, task); + OOCIOHandler.SourceBlockDescriptor descriptor = null; + if (_source instanceof OOCSourceStream src) { + descriptor = src.getDescriptor(task.getIndexes()); + } + if (descriptor == null) { + if (mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.put(_streamId, _numBlocks, task); + else + mCallback = OOCCacheManager.putAndPin(_streamId, _numBlocks, task); + } + else { + if (mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.putSourceBacked(_streamId, _numBlocks, task, descriptor); + else + mCallback = OOCCacheManager.putAndPinSourceBacked(_streamId, _numBlocks, task, descriptor); + } if (_index != null) _index.put(task.getIndexes(), _numBlocks); blk = _numBlocks; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java index 74b15c9fb0e..f744b97506b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java @@ -19,24 +19,19 @@ package org.apache.sysds.runtime.instructions.ooc; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.SequenceFile; -import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.common.Opcodes; -import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.common.Types; import 
org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; -import org.apache.sysds.runtime.io.IOUtilFunctions; -import org.apache.sysds.runtime.io.MatrixReader; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; +import org.apache.sysds.runtime.ooc.cache.OOCIOHandler; +import org.apache.sysds.runtime.ooc.stream.OOCSourceStream; public class ReblockOOCInstruction extends ComputationOOCInstruction { private int blen; @@ -74,40 +69,19 @@ public void processInstruction(ExecutionContext ec) { //TODO support other formats than binary //create queue, spawn thread for asynchronous reading, and return - OOCStream q = createWritableStream(); - submitOOCTask(() -> readBinaryBlock(q, min.getFileName()), q); + OOCStream q = new OOCSourceStream(); + OOCIOHandler io = OOCCacheManager.getIOHandler(); + OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest( + min.getFileName(), Types.FileFormat.BINARY, mc.getRows(), mc.getCols(), blen, mc.getNonZeros(), + Long.MAX_VALUE, true, q); + io.scheduleSourceRead(req).whenComplete((res, err) -> { + if (err != null) { + Exception ex = err instanceof Exception ? 
(Exception) err : new Exception(err); + q.propagateFailure(new DMLRuntimeException(ex)); + } + }); MatrixObject mout = ec.getMatrixObject(output); mout.setStreamHandle(q); } - - @SuppressWarnings("resource") - private void readBinaryBlock(OOCStream q, String fname) { - try { - //prepare file access - JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); - Path path = new Path( fname ); - FileSystem fs = IOUtilFunctions.getFileSystem(path, job); - - //check existence and non-empty file - MatrixReader.checkValidInputFile(fs, path); - - //core reading - for( Path lpath : IOUtilFunctions.getSequenceFilePaths(fs, path) ) { //1..N files - //directly read from sequence files (individual partfiles) - try( SequenceFile.Reader reader = new SequenceFile - .Reader(job, SequenceFile.Reader.file(lpath)) ) - { - MatrixIndexes key = new MatrixIndexes(); - MatrixBlock value = new MatrixBlock(); - while( reader.next(key, value) ) - q.enqueue(new IndexedMatrixValue(key, new MatrixBlock(value))); - } - } - q.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java index 50b5cf78218..bbf4cfb314c 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java @@ -100,6 +100,15 @@ public static OOCCacheScheduler getCache() { } } + public static OOCIOHandler getIOHandler() { + OOCIOHandler io = _ioHandler.get(); + if(io != null) + return io; + // Ensure initialization happens + getCache(); + return _ioHandler.get(); + } + /** * Removes a block from the cache without setting its data to null. 
*/ @@ -116,11 +125,28 @@ public static void put(long streamId, int blockId, IndexedMatrixValue value) { getCache().put(key, value, ((MatrixBlock)value.getValue()).getExactSerializedSize()); } + /** + * Store a source-backed block in the OOC cache and register its source location. + */ + public static void putSourceBacked(long streamId, int blockId, IndexedMatrixValue value, + OOCIOHandler.SourceBlockDescriptor descriptor) { + BlockKey key = new BlockKey(streamId, blockId); + getCache().putSourceBacked(key, value, ((MatrixBlock) value.getValue()).getExactSerializedSize(), descriptor); + } + public static OOCStream.QueueCallback putAndPin(long streamId, int blockId, IndexedMatrixValue value) { BlockKey key = new BlockKey(streamId, blockId); return new CachedQueueCallback<>(getCache().putAndPin(key, value, ((MatrixBlock)value.getValue()).getExactSerializedSize()), null); } + public static OOCStream.QueueCallback putAndPinSourceBacked(long streamId, int blockId, + IndexedMatrixValue value, OOCIOHandler.SourceBlockDescriptor descriptor) { + BlockKey key = new BlockKey(streamId, blockId); + return new CachedQueueCallback<>( + getCache().putAndPinSourceBacked(key, value, ((MatrixBlock) value.getValue()).getExactSerializedSize(), + descriptor), null); + } + public static CompletableFuture> requestBlock(long streamId, long blockId) { BlockKey key = new BlockKey(streamId, blockId); return getCache().request(key).thenApply(e -> new CachedQueueCallback<>(e, null)); diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java index 5346b819cfe..cd04f9879aa 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java @@ -56,6 +56,28 @@ public interface OOCCacheScheduler { */ BlockEntry putAndPin(BlockKey key, Object data, long size); + /** + * Places a new source-backed block in the 
cache and registers the location with the IO handler. The entry is + * treated as backed by disk, so eviction does not schedule spill writes. + * + * @param key the associated key of the block + * @param data the block data + * @param size the size of the data + * @param descriptor the source location descriptor + */ + void putSourceBacked(BlockKey key, Object data, long size, OOCIOHandler.SourceBlockDescriptor descriptor); + + /** + * Places a new source-backed block in the cache and returns a pinned handle. + * + * @param key the associated key of the block + * @param data the block data + * @param size the size of the data + * @param descriptor the source location descriptor + */ + BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size, + OOCIOHandler.SourceBlockDescriptor descriptor); + /** * Forgets a block from the cache. * @param key the associated key of the block diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java index dbfda4e56d7..b4d14646e0e 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java @@ -20,6 +20,7 @@ package org.apache.sysds.runtime.ooc.cache; import java.util.concurrent.CompletableFuture; +import java.util.List; public interface OOCIOHandler { void shutdown(); @@ -29,4 +30,85 @@ public interface OOCIOHandler { CompletableFuture scheduleRead(BlockEntry block); CompletableFuture scheduleDeletion(BlockEntry block); + + /** + * Registers the source location of a block for future direct reads. + */ + void registerSourceLocation(BlockKey key, SourceBlockDescriptor descriptor); + + /** + * Schedule an asynchronous read from an external source into the provided target stream. + * The returned future completes when either EOF is reached or the requested byte budget + * is exhausted. 
When the budget is reached and keepOpenOnLimit is true, the target stream + * is kept open and a continuation token is provided so the caller can resume. + */ + CompletableFuture scheduleSourceRead(SourceReadRequest request); + + /** + * Continue a previously throttled source read using the provided continuation token. + */ + CompletableFuture continueSourceRead(SourceReadContinuation continuation, long maxBytesInFlight); + + interface SourceReadContinuation {} + + class SourceReadRequest { + public final String path; + public final org.apache.sysds.common.Types.FileFormat format; + public final long rows; + public final long cols; + public final int blen; + public final long estNnz; + public final long maxBytesInFlight; + public final boolean keepOpenOnLimit; + public final org.apache.sysds.runtime.instructions.ooc.OOCStream target; + + public SourceReadRequest(String path, org.apache.sysds.common.Types.FileFormat format, long rows, long cols, + int blen, long estNnz, long maxBytesInFlight, boolean keepOpenOnLimit, + org.apache.sysds.runtime.instructions.ooc.OOCStream target) { + this.path = path; + this.format = format; + this.rows = rows; + this.cols = cols; + this.blen = blen; + this.estNnz = estNnz; + this.maxBytesInFlight = maxBytesInFlight; + this.keepOpenOnLimit = keepOpenOnLimit; + this.target = target; + } + } + + class SourceReadResult { + public final long bytesRead; + public final boolean eof; + public final SourceReadContinuation continuation; + public final List blocks; + + public SourceReadResult(long bytesRead, boolean eof, SourceReadContinuation continuation, + List blocks) { + this.bytesRead = bytesRead; + this.eof = eof; + this.continuation = continuation; + this.blocks = blocks; + } + } + + class SourceBlockDescriptor { + public final String path; + public final org.apache.sysds.common.Types.FileFormat format; + public final org.apache.sysds.runtime.matrix.data.MatrixIndexes indexes; + public final long offset; + public final int recordLength; 
+ public final long serializedSize; + + public SourceBlockDescriptor(String path, org.apache.sysds.common.Types.FileFormat format, + org.apache.sysds.runtime.matrix.data.MatrixIndexes indexes, long offset, int recordLength, + long serializedSize) { + this.path = path; + this.format = format; + this.indexes = indexes; + this.offset = offset; + this.recordLength = recordLength; + this.serializedSize = serializedSize; + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java index 1dbba2e3d8f..0f30914770a 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java @@ -169,22 +169,36 @@ private void scheduleDeferredRead(DeferredReadRequest deferredReadRequest) { @Override public void put(BlockKey key, Object data, long size) { - put(key, data, size, false); + put(key, data, size, false, null); } @Override public BlockEntry putAndPin(BlockKey key, Object data, long size) { - return put(key, data, size, true); + return put(key, data, size, true, null); } - private BlockEntry put(BlockKey key, Object data, long size, boolean pin) { + @Override + public void putSourceBacked(BlockKey key, Object data, long size, OOCIOHandler.SourceBlockDescriptor descriptor) { + put(key, data, size, false, descriptor); + } + + @Override + public BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size, OOCIOHandler.SourceBlockDescriptor descriptor) { + return put(key, data, size, true, descriptor); + } + + private BlockEntry put(BlockKey key, Object data, long size, boolean pin, OOCIOHandler.SourceBlockDescriptor descriptor) { if (!this._running) throw new IllegalStateException(); if (data == null) throw new IllegalArgumentException(); + if (descriptor != null) + _ioHandler.registerSourceLocation(key, descriptor); Statistics.incrementOOCEvictionPut(); 
BlockEntry entry = new BlockEntry(key, size, data); + if (descriptor != null) + entry.setState(BlockState.WARM); if (pin) entry.pin(); synchronized(this) { @@ -301,15 +315,15 @@ private void onCacheSizeChanged(boolean incr) { } private synchronized void sanityCheck() { - if (_cacheSize > _hardLimit) { + if (_cacheSize > _hardLimit * 1.1) { if (!_warnThrottling) { _warnThrottling = true; - System.out.println("[INFO] Throttling: " + _cacheSize/1000 + "KB - " + _bytesUpForEviction/1000 + "KB > " + _hardLimit/1000 + "KB"); + System.out.println("[WARN] Cache hard limit exceeded by over 10%: " + String.format("%.2f", _cacheSize/1000000.0) + "MB (-" + String.format("%.2f", _bytesUpForEviction/1000000.0) + "MB) > " + String.format("%.2f", _hardLimit/1000000.0) + "MB"); } } - else if (_warnThrottling) { + else if (_warnThrottling && _cacheSize < _hardLimit) { _warnThrottling = false; - System.out.println("[INFO] No more throttling: " + _cacheSize/1000 + "KB - " + _bytesUpForEviction/1000 + "KB <= " + _hardLimit/1000 + "KB"); + System.out.println("[INFO] Cache within limit: " + String.format("%.2f", _cacheSize/1000000.0) + "MB (-" + String.format("%.2f", _bytesUpForEviction/1000000.0) + "MB) <= " + String.format("%.2f", _hardLimit/1000000.0) + "MB"); } if (!SANITY_CHECKS) diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java index 3cd16272d2b..a9da3ccd294 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java @@ -20,12 +20,20 @@ package org.apache.sysds.runtime.ooc.cache; import org.apache.sysds.api.DMLScript; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.mapred.JobConf; +import org.apache.sysds.common.Types; +import org.apache.sysds.conf.ConfigurationManager; 
import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.io.IOUtilFunctions; +import org.apache.sysds.runtime.io.MatrixReader; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.ooc.stats.OOCEventLog; +import org.apache.sysds.runtime.ooc.stream.OOCSourceStream; import org.apache.sysds.runtime.util.FastBufferedDataInputStream; import org.apache.sysds.runtime.util.FastBufferedDataOutputStream; import org.apache.sysds.runtime.util.LocalFileUtils; @@ -40,6 +48,9 @@ import java.io.RandomAccessFile; import java.nio.channels.Channels; import java.nio.channels.ClosedByInterruptException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -50,9 +61,13 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicIntegerArray; +import java.util.concurrent.atomic.AtomicLongArray; +import java.util.concurrent.atomic.AtomicReference; public class OOCMatrixIOHandler implements OOCIOHandler { - private static final int WRITER_SIZE = 2; + private static final int WRITER_SIZE = 4; + private static final int READER_SIZE = 10; private static final long OVERFLOW = 8192 * 1024; private static final long MAX_PARTITION_SIZE = 8192 * 8192; @@ -63,6 +78,7 @@ public class OOCMatrixIOHandler implements OOCIOHandler { // Spill related structures private final ConcurrentHashMap _spillLocations = new ConcurrentHashMap<>(); private final ConcurrentHashMap _partitions = new ConcurrentHashMap<>(); + private final ConcurrentHashMap _sourceLocations = new ConcurrentHashMap<>(); private final AtomicInteger 
_partitionCounter = new AtomicInteger(0); private final CloseableQueue>>[] _q; private final AtomicLong _wCtr; @@ -70,6 +86,7 @@ public class OOCMatrixIOHandler implements OOCIOHandler { private final int _evictCallerId = OOCEventLog.registerCaller("write"); private final int _readCallerId = OOCEventLog.registerCaller("read"); + private final int _srcReadCallerId = OOCEventLog.registerCaller("read_src"); @SuppressWarnings("unchecked") public OOCMatrixIOHandler() { @@ -81,8 +98,8 @@ public OOCMatrixIOHandler() { TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(100000)); _readExec = new ThreadPoolExecutor( - 5, - 5, + READER_SIZE, + READER_SIZE, 0L, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(100000)); @@ -161,14 +178,245 @@ public CompletableFuture scheduleRead(final BlockEntry block) { @Override public CompletableFuture scheduleDeletion(BlockEntry block) { - // TODO + _sourceLocations.remove(block.getKey()); return CompletableFuture.completedFuture(true); } + @Override + public void registerSourceLocation(BlockKey key, SourceBlockDescriptor descriptor) { + _sourceLocations.put(key, descriptor); + } + + @Override + public CompletableFuture scheduleSourceRead(SourceReadRequest request) { + return submitSourceRead(request, null, request.maxBytesInFlight); + } + + @Override + public CompletableFuture continueSourceRead(SourceReadContinuation continuation, long maxBytesInFlight) { + if (!(continuation instanceof SourceReadState state)) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(new DMLRuntimeException("Unsupported continuation type: " + continuation)); + return failed; + } + return submitSourceRead(state.request, state, maxBytesInFlight); + } + + private CompletableFuture submitSourceRead(SourceReadRequest request, SourceReadState state, + long maxBytesInFlight) { + if(request.format != Types.FileFormat.BINARY) + return CompletableFuture.failedFuture( + new DMLRuntimeException("Unsupported format for source read: " + request.format));
+		return readBinarySourceParallel(request, state, maxBytesInFlight);
+	}
+
+	/**
+	 * Reads every not-yet-completed sequence file of a binary source on the read executor, one task per
+	 * file, and enqueues the blocks that were read into {@code request.target}.
+	 *
+	 * @param request          the source read request (path, format, target stream)
+	 * @param state            continuation state of a previously throttled read, or null for a fresh read
+	 * @param maxBytesInFlight byte budget of this call; values &lt;= 0 disable the limit
+	 * @return future completed by the last finishing reader task, or immediately if no file is left to read
+	 * @throws DMLRuntimeException if validating the input or listing its part files fails with an IOException
+	 */
+	private CompletableFuture readBinarySourceParallel(SourceReadRequest request,
+		SourceReadState state, long maxBytesInFlight) {
+		final long byteLimit = maxBytesInFlight > 0 ? maxBytesInFlight : Long.MAX_VALUE;
+		final AtomicLong bytesRead = new AtomicLong(0);
+		final AtomicBoolean stop = new AtomicBoolean(false);
+		final AtomicBoolean budgetHit = new AtomicBoolean(false);
+		final AtomicReference error = new AtomicReference<>();
+		final Object budgetLock = new Object();
+		final CompletableFuture result = new CompletableFuture<>();
+		final ConcurrentLinkedDeque descriptors = new ConcurrentLinkedDeque<>();
+
+		JobConf job = new JobConf(ConfigurationManager.getCachedJobConf());
+		Path path = new Path(request.path);
+
+		Path[] files;
+		AtomicLongArray filePositions;
+		AtomicIntegerArray completed;
+
+		try {
+			FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
+			MatrixReader.checkValidInputFile(fs, path);
+
+			if(state == null) {
+				List seqFiles = new ArrayList<>(Arrays.asList(IOUtilFunctions.getSequenceFilePaths(fs, path)));
+				files = seqFiles.toArray(Path[]::new);
+				filePositions = new AtomicLongArray(files.length);
+				completed = new AtomicIntegerArray(files.length);
+			}
+			else {
+				files = state.paths;
+				filePositions = state.filePositions;
+				completed = state.completed;
+			}
+		}
+		catch(IOException e) {
+			throw new DMLRuntimeException(e);
+		}
+
+		int activeTasks = 0;
+		for(int i = 0; i < files.length; i++)
+			if(completed.get(i) == 0)
+				activeTasks++;
+
+		final AtomicInteger remaining = new AtomicInteger(activeTasks);
+		boolean anyTask = activeTasks > 0;
+
+		for(int i = 0; i < files.length; i++) {
+			if(completed.get(i) == 1)
+				continue;
+			final int fileIdx = i;
+			try {
+				_readExec.submit(() -> {
+					try {
+						readSequenceFile(job, files[fileIdx], request, fileIdx, filePositions, completed, stop,
+							budgetHit, bytesRead, byteLimit, budgetLock, descriptors);
+					}
+					catch(Throwable t) {
+						error.compareAndSet(null, t);
+						stop.set(true);
+					}
+					finally {
+						if(remaining.decrementAndGet() == 0)
+							completeResult(result, bytesRead, budgetHit, error, request, files, filePositions,
+								completed, descriptors);
+					}
+				});
+			}
+			catch(RejectedExecutionException e) {
+				error.compareAndSet(null, e);
+				stop.set(true);
+				// Neither this file nor any later pending file gets a task, so release all of their slots at
+				// once; otherwise 'remaining' never reaches zero and the returned future never completes.
+				int unsubmitted = 0;
+				for(int j = fileIdx; j < files.length; j++)
+					if(completed.get(j) == 0)
+						unsubmitted++;
+				if(remaining.addAndGet(-unsubmitted) == 0)
+					completeResult(result, bytesRead, budgetHit, error, request, files, filePositions, completed,
+						descriptors);
+				break;
+			}
+		}
+
+		if(!anyTask) {
+			tryCloseTarget(request.target, true);
+			result.complete(new SourceReadResult(bytesRead.get(), true, null, List.of()));
+		}
+
+		return result;
+	}
+
+	/**
+	 * Completes the future of a source read: exceptionally if an error was recorded, with a continuation
+	 * if the byte budget was hit, and as EOF (closing the target stream) otherwise.
+	 */
+	private void completeResult(CompletableFuture future, AtomicLong bytesRead, AtomicBoolean budgetHit,
+		AtomicReference error, SourceReadRequest request, Path[] files, AtomicLongArray filePositions,
+		AtomicIntegerArray completed, ConcurrentLinkedDeque descriptors) {
+		Throwable err = error.get();
+		if (err != null) {
+			future.completeExceptionally(err instanceof Exception ? 
err : new Exception(err)); + return; + } + + if (budgetHit.get()) { + if (!request.keepOpenOnLimit) + tryCloseTarget(request.target, false); + SourceReadContinuation cont = new SourceReadState(request, files, filePositions, completed); + future.complete(new SourceReadResult(bytesRead.get(), false, cont, new ArrayList<>(descriptors))); + return; + } + + tryCloseTarget(request.target, true); + future.complete(new SourceReadResult(bytesRead.get(), true, null, new ArrayList<>(descriptors))); + } + + private void readSequenceFile(JobConf job, Path path, SourceReadRequest request, int fileIdx, + AtomicLongArray filePositions, AtomicIntegerArray completed, AtomicBoolean stop, AtomicBoolean budgetHit, + AtomicLong bytesRead, long byteLimit, Object budgetLock, ConcurrentLinkedDeque descriptors) + throws IOException { + MatrixIndexes key = new MatrixIndexes(); + MatrixBlock value = new MatrixBlock(); + + try(SequenceFile.Reader reader = new SequenceFile.Reader(job, SequenceFile.Reader.file(path))) { + long pos = filePositions.get(fileIdx); + if (pos > 0) + reader.seek(pos); + + long ioStart = DMLScript.OOC_LOG_EVENTS ? 
System.nanoTime() : 0; + while(!stop.get()) { + long recordStart = reader.getPosition(); + if (!reader.next(key, value)) + break; + long recordEnd = reader.getPosition(); + long blockSize = value.getExactSerializedSize(); + boolean shouldBreak = false; + + synchronized(budgetLock) { + if (stop.get()) + shouldBreak = true; + else if (bytesRead.get() + blockSize > byteLimit) { + stop.set(true); + budgetHit.set(true); + shouldBreak = true; + } + bytesRead.addAndGet(blockSize); + } + + MatrixIndexes outIdx = new MatrixIndexes(key); + MatrixBlock outBlk = new MatrixBlock(value); + IndexedMatrixValue imv = new IndexedMatrixValue(outIdx, outBlk); + SourceBlockDescriptor descriptor = new SourceBlockDescriptor(path.toString(), request.format, outIdx, + recordStart, (int)(recordEnd - recordStart), blockSize); + + if (request.target instanceof OOCSourceStream src) + src.enqueue(imv, descriptor); + else + request.target.enqueue(imv); + + descriptors.add(descriptor); + filePositions.set(fileIdx, reader.getPosition()); + + if (DMLScript.OOC_LOG_EVENTS) { + long currTime = System.nanoTime(); + OOCEventLog.onDiskReadEvent(_srcReadCallerId, ioStart, currTime, blockSize); + ioStart = currTime; + } + + if (shouldBreak) + break; // Note that we knowingly go over limit, which could result in READER_SIZE*8MB overshoot + } + + if (!stop.get()) + completed.set(fileIdx, 1); + } + } + + private void tryCloseTarget(org.apache.sysds.runtime.instructions.ooc.OOCStream target, boolean close) { + if (close) { + try { + target.closeInput(); + } + catch(Exception ignored) { + } + } + } + private void loadFromDisk(BlockEntry block) { String key = block.getKey().toFileKey(); + SourceBlockDescriptor src = _sourceLocations.get(block.getKey()); + if (src != null) { + loadFromSource(block, src); + return; + } + long ioDuration = 0; // 1. 
find the blocks address (spill location) SpillLocation sloc = _spillLocations.get(key); @@ -207,6 +435,28 @@ private void loadFromDisk(BlockEntry block) { } } + private void loadFromSource(BlockEntry block, SourceBlockDescriptor src) { + if (src.format != Types.FileFormat.BINARY) + throw new DMLRuntimeException("Unsupported format for source read: " + src.format); + + JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); + Path path = new Path(src.path); + + MatrixIndexes ix = new MatrixIndexes(); + MatrixBlock mb = new MatrixBlock(); + + try(SequenceFile.Reader reader = new SequenceFile.Reader(job, SequenceFile.Reader.file(path))) { + reader.seek(src.offset); + if (!reader.next(ix, mb)) + throw new DMLRuntimeException("Failed to read source block at offset " + src.offset + " in " + src.path); + } + catch(IOException e) { + throw new DMLRuntimeException(e); + } + + block.setDataUnsafe(new IndexedMatrixValue(ix, mb)); + } + private void evictTask(CloseableQueue>> q) { long byteCtr = 0; @@ -276,8 +526,7 @@ private void evictTask(CloseableQueue catch(IOException | InterruptedException ex) { throw new DMLRuntimeException(ex); } - catch(Exception e) { - // TODO + catch(Exception ignored) { } finally { IOUtilFunctions.closeSilently(dos); @@ -356,4 +605,19 @@ public int getCount() { return _count; } } + + private static class SourceReadState implements SourceReadContinuation { + final SourceReadRequest request; + final Path[] paths; + final AtomicLongArray filePositions; + final AtomicIntegerArray completed; + + SourceReadState(SourceReadRequest request, Path[] paths, AtomicLongArray filePositions, + AtomicIntegerArray completed) { + this.request = request; + this.paths = paths; + this.filePositions = filePositions; + this.completed = completed; + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java new file mode 100644 index 
00000000000..c48aaa45ab2
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.stream;
+
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.ooc.cache.OOCIOHandler;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Task queue for blocks read from an external source that remembers, per block index, the
+ * {@link OOCIOHandler.SourceBlockDescriptor} the block was enqueued with.
+ */
+public class OOCSourceStream extends SubscribableTaskQueue<IndexedMatrixValue> {
+	private final ConcurrentHashMap<MatrixIndexes, OOCIOHandler.SourceBlockDescriptor> _idx;
+
+	public OOCSourceStream() {
+		this._idx = new ConcurrentHashMap<>();
+	}
+
+	/**
+	 * Records the descriptor under its block indexes and then enqueues the block.
+	 *
+	 * @param value      the block to enqueue
+	 * @param descriptor the source location of the block, must not be null
+	 * @throws IllegalArgumentException if the descriptor is null
+	 */
+	public void enqueue(IndexedMatrixValue value, OOCIOHandler.SourceBlockDescriptor descriptor) {
+		if(descriptor == null)
+			throw new IllegalArgumentException("Source descriptor must not be null");
+		MatrixIndexes key = new MatrixIndexes(descriptor.indexes);
+		_idx.put(key, descriptor);
+		super.enqueue(value);
+	}
+
+	/**
+	 * Not supported: blocks of a source stream must be enqueued together with their descriptor.
+	 *
+	 * @throws UnsupportedOperationException always
+	 */
+	@Override
+	public void enqueue(IndexedMatrixValue val) {
+		throw new UnsupportedOperationException("Use enqueue(value, descriptor) for source streams");
+	}
+
+	/**
+	 * Looks up the descriptor recorded for the given block indexes.
+	 *
+	 * @param indexes the block indexes
+	 * @return the descriptor, or null if no block with these indexes was enqueued
+	 */
+	public OOCIOHandler.SourceBlockDescriptor getDescriptor(MatrixIndexes indexes) {
+		return _idx.get(indexes);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java b/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java
new file mode 100644
index 00000000000..423c2b7f425
--- /dev/null
+++ b/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class SourceBackedCacheSchedulerTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "SourceBackedCacheScheduler";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + SourceBackedCacheSchedulerTest.class.getSimpleName() + "/";
+
+	private OOCMatrixIOHandler handler;
+	private OOCLRUCacheScheduler scheduler;
+
+	@Override
+	@Before
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+		handler = new OOCMatrixIOHandler();
+		scheduler = new OOCLRUCacheScheduler(handler, 0, Long.MAX_VALUE); // NOTE(review): limits presumably force eviction on unpin - confirm
+	}
+
+	@After
+	public void tearDown() {
+		if (scheduler != null)
+			scheduler.shutdown();
+		if (handler != null)
+			handler.shutdown();
+	}
+
+	@Test
+	public void testPutSourceBackedAndReload() throws Exception {
+		getAndLoadTestConfiguration(TEST_NAME);
+		final int rows = 4;
+		final int cols = 4;
+		final int blen = 2;
+
+		MatrixBlock src = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 23);
+		String fname = input("binary_src_cache");
+		writeBinaryMatrix(src, fname, blen);
+
+		SubscribableTaskQueue<IndexedMatrixValue> target = new SubscribableTaskQueue<>();
+		OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest(fname, Types.FileFormat.BINARY,
+			rows, cols, blen, src.getNonZeros(), Long.MAX_VALUE, true, target);
+
+		OOCIOHandler.SourceReadResult res = handler.scheduleSourceRead(req).get();
+		IndexedMatrixValue imv = target.dequeue();
+		OOCIOHandler.SourceBlockDescriptor desc = res.blocks.get(0);
+
+		BlockKey key = new BlockKey(11, 0);
+		BlockEntry entry = scheduler.putAndPinSourceBacked(key, imv,
+			((MatrixBlock) imv.getValue()).getExactSerializedSize(), desc);
+		org.junit.Assert.assertEquals(BlockState.WARM, entry.getState());
+
+		scheduler.unpin(entry);
+		org.junit.Assert.assertEquals(BlockState.COLD, entry.getState());
+		org.junit.Assert.assertNull(entry.getDataUnsafe());
+
+		BlockEntry reloaded = scheduler.request(key).get(); // in-memory data was dropped above, so the block has to be loaded again
+		IndexedMatrixValue reloadImv = (IndexedMatrixValue) reloaded.getData();
+		MatrixBlock expected = expectedBlock(src, desc.indexes, blen);
+		TestUtils.compareMatrices(expected, (MatrixBlock) reloadImv.getValue(), 1e-12);
+	}
+
+	private void writeBinaryMatrix(MatrixBlock mb, String fname, int blen) throws Exception {
+		MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		writer.writeMatrixToHDFS(mb, fname, mb.getNumRows(), mb.getNumColumns(), blen, mb.getNonZeros());
+	}
+
+	private MatrixBlock expectedBlock(MatrixBlock src, org.apache.sysds.runtime.matrix.data.MatrixIndexes idx, int blen) {
+		int rowStart = (int) ((idx.getRowIndex() - 1) * blen);
+		int colStart = (int) ((idx.getColumnIndex() - 1) * blen);
+		int rowEnd = Math.min(rowStart + blen - 1, src.getNumRows() - 1);
+		int colEnd = Math.min(colStart + blen - 1, src.getNumColumns() - 1);
+		return src.slice(rowStart, rowEnd, colStart, colEnd);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java b/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java
new file mode 100644
index 00000000000..e688bf0f1c0
--- /dev/null
+++ b/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class SourceBackedReadOOCIOHandlerTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "SourceBackedReadOOCIOHandler";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + SourceBackedReadOOCIOHandlerTest.class.getSimpleName() + "/";
+
+	private OOCMatrixIOHandler handler;
+
+	@Override
+	@Before
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+		handler = new OOCMatrixIOHandler();
+	}
+
+	@After
+	public void tearDown() {
+		if (handler != null)
+			handler.shutdown();
+	}
+
+	@Test
+	public void testSourceBackedScheduleRead() throws Exception {
+		getAndLoadTestConfiguration(TEST_NAME);
+		final int rows = 4;
+		final int cols = 4;
+		final int blen = 2;
+
+		MatrixBlock src = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 17);
+		String fname = input("binary_src");
+		writeBinaryMatrix(src, fname, blen);
+
+		SubscribableTaskQueue<IndexedMatrixValue> target = new SubscribableTaskQueue<>();
+		OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest(fname, Types.FileFormat.BINARY,
+			rows, cols, blen, src.getNonZeros(), Long.MAX_VALUE, true, target);
+
+		OOCIOHandler.SourceReadResult res = handler.scheduleSourceRead(req).get();
+		org.junit.Assert.assertFalse(res.blocks.isEmpty());
+
+		OOCIOHandler.SourceBlockDescriptor desc = res.blocks.get(0);
+		BlockKey key = new BlockKey(7, 0);
+		handler.registerSourceLocation(key, desc);
+
+		BlockEntry entry = new BlockEntry(key, desc.serializedSize, null);
+		entry.setState(BlockState.COLD);
+		handler.scheduleRead(entry).get(); // expected to populate the entry's data, which is read back below
+
+		IndexedMatrixValue imv = (IndexedMatrixValue) entry.getDataUnsafe();
+		MatrixBlock readBlock = (MatrixBlock) imv.getValue();
+		MatrixBlock expected = expectedBlock(src, desc.indexes, blen);
+		TestUtils.compareMatrices(expected, readBlock, 1e-12);
+	}
+
+	private void writeBinaryMatrix(MatrixBlock mb, String fname, int blen) throws Exception {
+		MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		writer.writeMatrixToHDFS(mb, fname, mb.getNumRows(), mb.getNumColumns(), blen, mb.getNonZeros());
+	}
+
+	private MatrixBlock expectedBlock(MatrixBlock src, org.apache.sysds.runtime.matrix.data.MatrixIndexes idx, int blen) {
+		int rowStart = (int) ((idx.getRowIndex() - 1) * blen);
+		int colStart = (int) ((idx.getColumnIndex() - 1) * blen);
+		int rowEnd = Math.min(rowStart + blen - 1, src.getNumRows() - 1);
+		int colEnd = Math.min(colStart + blen - 1, src.getNumColumns() - 1);
+		return src.slice(rowStart, rowEnd, colStart, colEnd);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SourceReadOOCIOHandlerTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SourceReadOOCIOHandlerTest.java
new file mode 100644
index 00000000000..34dd01d6620
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/SourceReadOOCIOHandlerTest.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.ooc.cache.OOCIOHandler;
+import org.apache.sysds.runtime.ooc.cache.OOCMatrixIOHandler;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class SourceReadOOCIOHandlerTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "SourceReadOOCIOHandler";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + SourceReadOOCIOHandlerTest.class.getSimpleName() + "/";
+
+	private OOCMatrixIOHandler handler;
+
+	@Override
+	@Before
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+		handler = new OOCMatrixIOHandler();
+	}
+
+	@After
+	public void tearDown() {
+		if (handler != null)
+			handler.shutdown();
+	}
+
+	@Test
+	public void testSourceReadCompletes() throws Exception {
+		getAndLoadTestConfiguration(TEST_NAME);
+		final int rows = 4;
+		final int cols = 4;
+		final int blen = 2;
+
+		MatrixBlock src = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7);
+		String fname = input("binary_full");
+		writeBinaryMatrix(src, fname, blen);
+
+		SubscribableTaskQueue<IndexedMatrixValue> target = new SubscribableTaskQueue<>();
+		OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest(fname, Types.FileFormat.BINARY,
+			rows, cols, blen, src.getNonZeros(), Long.MAX_VALUE, true, target);
+
+		OOCIOHandler.SourceReadResult res = handler.scheduleSourceRead(req).get();
+		// Drain after EOF
+		MatrixBlock reconstructed = drainToMatrix(target, rows, cols, blen);
+
+		TestUtils.compareMatrices(src, reconstructed, 1e-12);
+		org.junit.Assert.assertTrue(res.eof);
+		org.junit.Assert.assertNull(res.continuation);
+		org.junit.Assert.assertNotNull(res.blocks);
+		org.junit.Assert.assertEquals((rows / blen) * (cols / blen), res.blocks.size());
+		org.junit.Assert.assertTrue(res.blocks.stream().allMatch(b -> b.indexes != null));
+	}
+
+	@Test
+	public void testSourceReadStopsOnBudgetAndContinues() throws Exception {
+		getAndLoadTestConfiguration(TEST_NAME);
+		final int rows = 4;
+		final int cols = 4;
+		final int blen = 2;
+
+		MatrixBlock src = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 13);
+		String fname = input("binary_budget");
+		writeBinaryMatrix(src, fname, blen);
+
+		long singleBlockSize = new MatrixBlock(blen, blen, false).getExactSerializedSize();
+		long budget = singleBlockSize + 1; // ensure we stop before the second block
+
+		SubscribableTaskQueue<IndexedMatrixValue> target = new SubscribableTaskQueue<>();
+		OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest(fname, Types.FileFormat.BINARY,
+			rows, cols, blen, src.getNonZeros(), budget, true, target);
+
+		OOCIOHandler.SourceReadResult first = handler.scheduleSourceRead(req).get();
+		org.junit.Assert.assertFalse(first.eof);
+		org.junit.Assert.assertNotNull(first.continuation);
+		org.junit.Assert.assertNotNull(first.blocks);
+
+		OOCIOHandler.SourceReadResult second = handler.continueSourceRead(first.continuation, Long.MAX_VALUE).get();
+		org.junit.Assert.assertTrue(second.eof);
+		org.junit.Assert.assertNull(second.continuation);
+		org.junit.Assert.assertNotNull(second.blocks);
+		org.junit.Assert.assertEquals((rows / blen) * (cols / blen), first.blocks.size() + second.blocks.size());
+
+		MatrixBlock reconstructed = drainToMatrix(target, rows, cols, blen);
+		TestUtils.compareMatrices(src, reconstructed, 1e-12);
+	}
+
+	private void writeBinaryMatrix(MatrixBlock mb, String fname, int blen) throws Exception {
+		MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		writer.writeMatrixToHDFS(mb, fname, mb.getNumRows(), mb.getNumColumns(), blen, mb.getNonZeros());
+	}
+
+	private MatrixBlock drainToMatrix(SubscribableTaskQueue<IndexedMatrixValue> target, int rows, int cols, int blen) {
+		List<IndexedMatrixValue> blocks = new ArrayList<>();
+		IndexedMatrixValue tmp;
+		while((tmp = target.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { // stop at the end-of-stream marker
+			blocks.add(tmp);
+		}
+
+		MatrixBlock out = new MatrixBlock(rows, cols, false);
+		for (IndexedMatrixValue imv : blocks) {
+			int rowOffset = (int)((imv.getIndexes().getRowIndex() - 1) * blen);
+			int colOffset = (int)((imv.getIndexes().getColumnIndex() - 1) * blen);
+			((MatrixBlock)imv.getValue()).putInto(out, rowOffset, colOffset, true);
+		}
+		out.recomputeNonZeros();
+		return out;
+	}
+}