Skip to content

Commit e400154

Browse files
committed
add placeholder logic for iotrace
1 parent d463874 commit e400154

File tree

5 files changed

+147
-8
lines changed

5 files changed

+147
-8
lines changed

src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ public class UnifiedMemoryManager
112112
// Prescient policy
113113
private static PrescientPolicy _prescientPolicy;
114114
private static IOTrace _ioTrace;
115+
private static long _currentTime = 0;
115116

116117
// Pinned size of physical memory. Starts from 0 for each operation. Max is 70% of heap
117118
// This increases only if the input is not present in the cache and read from FS/rdd/fed/gpu
@@ -207,6 +208,66 @@ public static void cleanup() {
207208
_pinnedVirtualMemSize = 0;
208209
}
209210

211+
/**
212+
* Sets the I/O trace for the prescient policy.
213+
* This should be called once by the ExecutionContext after trace generation.
214+
* @param trace The generated IOTrace
215+
*/
216+
public static void setTrace(IOTrace trace) {
217+
_ioTrace = trace;
218+
if (_evictionPolicy instanceof PrescientPolicy) {
219+
_prescientPolicy = (PrescientPolicy) _evictionPolicy;
220+
_prescientPolicy.setTrace(_ioTrace);
221+
}
222+
else {
223+
// Optional: Log a warning if the trace is set but the policy isn't Prescient
224+
// LOG.warn("IOTrace was provided, but eviction policy is not prescient!");
225+
}
226+
}
227+
228+
/**
229+
* Updates the UMM's logical time.
230+
* This should be called by the ExecutionContext *before* each instruction.
231+
* @param logicalTime The new logical time
232+
*/
233+
public static void updateTime(long logicalTime) {
234+
_currentTime = logicalTime;
235+
prefetch();
236+
}
237+
238+
/**
239+
* Prefetches blocks that will be needed soon, based on the I/O trace.
240+
*/
241+
private static void prefetch() {
242+
if (_ioTrace == null || _prescientPolicy == null) {
243+
return; // No trace or policy, cannot prefetch
244+
}
245+
246+
// Get the list of blocks to prefetch from our policy
247+
List<String> blocksToPrefetch = _prescientPolicy.getBlocksToPrefetch(_currentTime);
248+
249+
// A real implementation MUST use an asynchronous thread pool
250+
// (e.g., from _fClean) to load these blocks without blocking the main thread.
251+
252+
for (String blockID : blocksToPrefetch) {
253+
synchronized (_mQueue) {
254+
// Check again inside lock if block was already loaded or pinned
255+
if (_mQueue.containsKey(blockID) || _pinnedEntries.contains(blockID)) {
256+
continue; // Already in memory
257+
}
258+
}
259+
260+
// --- This is a simplified version for now ---
261+
// TODO: Submit an async prefetch task to _fClean's thread pool
262+
// The task should: 1. Get block size (from metadata)
263+
// 2. Call makeSpace(blockSize)
264+
// 3. Load block from disk
265+
// 4. Add block to _mQueue (synchronized)
266+
267+
System.out.println("UMM PREFETCH [T="+_currentTime+"]: Planning to prefetch " + blockID);
268+
}
269+
}
270+
210271
/**
211272
* Print current status of UMM, including all entries.
212273
* NOTE: use only for debugging or testing.
@@ -312,10 +373,35 @@ public static int makeSpace(long reqSpace) {
312373
synchronized(_mQueue) {
313374
// Evict blobs to make room (by default FIFO)
314375
while (getUMMFree() < reqSpace && !_mQueue.isEmpty()) {
315-
//remove first unpinned entry from eviction queue
316-
var entry = _mQueue.removeFirstUnpinned(_pinnedEntries);
317-
String ftmp = entry.getKey();
318-
ByteBuffer bb = entry.getValue();
376+
// --- NEW PRESCIENT LOGIC ---
377+
String ftmp; // Block ID / filename to evict
378+
379+
if (_prescientPolicy != null && _ioTrace != null) {
380+
// Use prescient policy to find the best block to evict
381+
ftmp = _prescientPolicy.evict(_mQueue.keySet(), _pinnedEntries, _currentTime);
382+
} else {
383+
// Fallback to default LRU if prescient policy isn't set or has no trace
384+
var entry = _mQueue.removeFirstUnpinned(_pinnedEntries);
385+
ftmp = (entry != null) ? entry.getKey() : null;
386+
}
387+
388+
if (ftmp == null) {
389+
// Policy couldn't find a block to evict (e.g., all are pinned)
390+
if(!_pinnedEntries.containsAll(_mQueue.keySet())) {
391+
// This case should ideally not be reached if unpinned blocks exist
392+
throw new DMLRuntimeException("UMM: Eviction policy failed to find a candidate.");
393+
}
394+
// If we are here, all blocks are pinned, and we cannot make space.
395+
// The original exception will be thrown later.
396+
break; // Exit the while loop
397+
}
398+
399+
// Remove the chosen block from the queue
400+
ByteBuffer bb = _mQueue.remove(ftmp);
401+
// //remove first unpinned entry from eviction queue
402+
// var entry = _mQueue.removeFirstUnpinned(_pinnedEntries);
403+
// String ftmp = entry.getKey();
404+
// ByteBuffer bb = entry.getValue();
319405

320406
if(bb != null) {
321407
// Wait for pending serialization

src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@ public IOTrace() {
4242
/**
4343
* Access to the block at a current time
4444
*/
45-
public void recordAccess(String blockID) {
46-
_trace.computeIfAbsent(blockID, k -> new ArrayList<>()).add(_currentTime);
47-
_currentTime++;
45+
public void recordAccess(String blockID, long logicalTime) {
46+
_trace.computeIfAbsent(blockID, k -> new ArrayList<>()).add(logicalTime);
4847
}
4948

5049
/**

src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ private static void processInstruction(Instruction inst, ExecutionContext ec, IO
145145
for (long j = 1; j <= numColBlocks; j++) {
146146

147147
String blockID = createBlockID(fname, i, j);
148-
trace.recordAccess(blockID);
148+
trace.recordAccess(blockID, logicalTime);
149149
}
150150
}
151151
}

src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.sysds.runtime.controlprogram.caching.EvictionPolicy;
2323

2424
import java.util.HashMap;
25+
import java.util.List;
2526
import java.util.Map;
2627
import java.util.Set;
2728

@@ -32,6 +33,7 @@ public class PrescientPolicy implements EvictionPolicy {
3233

3334
// Map of block ID, access times
3435
private final Map<String, Long> accessTimeMap = new HashMap<>();
36+
private IOTrace _trace;
3537

3638
// register blocks with access time
3739
public void setAccessTime(String blockId, long accessTime) {
@@ -66,4 +68,41 @@ public String selectBlockForEviction(Set<String> candidates) {
6668
return selected;
6769
}
6870

71+
/**
72+
* Called by UMM's makeSpace() to decide which block to evict.
73+
* * @param cache The set of all block IDs currently in the buffer
74+
* @param pinned The list of all block IDs that are pinned
75+
* @param currentTime The current logical time
76+
* @return The block ID to evict
77+
*/
78+
public String evict(Set<String> cache, List<String> pinned, long currentTime) {
79+
// TODO: Implement "evict-furthest-in-future" logic here
80+
// 1. Iterate through every 'blockID' in 'cache'
81+
// 2. If 'blockID' is in 'pinned', ignore it.
82+
// 3. Use '_trace.getAccessTime(blockID)' to find its next access time > currentTime
83+
// 4. The block with the (largest next access time) or (no future access) is the winner.
84+
// 5. Return the winner's blockID.
85+
86+
return null; // Placeholder
87+
}
88+
89+
/**
90+
* Called by UMM's prefetch() to decide which blocks to load.
91+
* * @param currentTime The current logical time
92+
* @return A list of block IDs to prefetch
93+
*/
94+
public List<String> getBlocksToPrefetch(long currentTime) {
95+
// TODO: Implement prefetch logic here
96+
// 1. Define a "prefetch window" (e.g., time T+1 to T+5)
97+
// 2. Iterate through all blocks in '_trace.getTrace()'
98+
// 3. Check if a block has an access time within that window
99+
// 4. If yes, add it to a list.
100+
// 5. Return the list of blocks.
101+
102+
return java.util.Collections.emptyList(); // Placeholder
103+
}
104+
105+
public void setTrace(IOTrace ioTrace) {
106+
_trace = ioTrace;
107+
}
69108
}

src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.apache.commons.logging.Log;
2323
import org.apache.commons.logging.LogFactory;
24+
import org.apache.hadoop.yarn.webapp.hamlet2.HamletSpec;
2425
import org.apache.sysds.api.DMLScript;
2526
import org.apache.sysds.common.Types;
2627
import org.apache.sysds.common.Types.FileFormat;
@@ -37,6 +38,7 @@
3738
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
3839
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
3940
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
41+
import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTrace;
4042
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
4143
import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient;
4244
import org.apache.sysds.runtime.data.TensorBlock;
@@ -93,6 +95,19 @@ public class ExecutionContext {
9395
//parfor temporary functions (created by eval)
9496
protected Set<String> _fnNames;
9597

98+
private IOTrace _ioTrace;
99+
100+
public IOTrace getIOTrace() {
101+
if (_ioTrace == null) {
102+
_ioTrace = new IOTrace();
103+
}
104+
return _ioTrace;
105+
}
106+
107+
public void setIOTrace(IOTrace ioTrace) {
108+
_ioTrace = ioTrace;
109+
}
110+
96111
/**
97112
* List of {@link GPUContext}s owned by this {@link ExecutionContext}
98113
*/

0 commit comments

Comments
 (0)