Skip to content

Commit c49e25d

Browse files
janniklindemboehm7
authored andcommitted
[SYSTEMDS-3891] Improved Stream Handling and PCA support
Closes #2368.
1 parent d78165b commit c49e25d

File tree

21 files changed

+1060
-272
lines changed

21 files changed

+1060
-272
lines changed

src/main/java/org/apache/sysds/api/DMLScript.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel;
7272
import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
7373
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
74+
import org.apache.sysds.runtime.instructions.ooc.OOCEvictionManager;
7475
import org.apache.sysds.runtime.io.IOUtilFunctions;
7576
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
7677
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy;
@@ -497,6 +498,8 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map<Stri
497498
ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, ConfigurationManager.getDMLConfig(), STATISTICS ? STATISTICS_COUNT : 0, null);
498499
}
499500
finally {
501+
//cleanup OOC streams and cache
502+
OOCEvictionManager.reset();
500503
//cleanup scratch_space and all working dirs
501504
cleanupHadoopExecution(ConfigurationManager.getDMLConfig());
502505
FederatedData.clearWorkGroup();

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
7777
//add static HOP DAG rewrite rules
7878
_dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize
7979
_dagRuleSet.add( new RewriteBlockSizeAndReblock() );
80-
_dagRuleSet.add( new RewriteInjectOOCTee() );
8180
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
8281
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
8382
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
@@ -94,6 +93,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
9493
if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
9594
_dagRuleSet.add( new RewriteQuantizationFusedCompression() );
9695

96+
9797
//add statement block rewrite rules
9898
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
9999
_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
@@ -152,6 +152,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
152152
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
153153
_sbRuleSet.add( new RewriteRemoveEmptyBasicBlocks() );
154154
_sbRuleSet.add( new RewriteRemoveEmptyForLoops() );
155+
_sbRuleSet.add( new RewriteInjectOOCTee() );
155156
}
156157

157158
/**

src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java

Lines changed: 152 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.sysds.hops.DataOp;
2626
import org.apache.sysds.hops.Hop;
2727
import org.apache.sysds.hops.ReorgOp;
28+
import org.apache.sysds.parser.StatementBlock;
2829

2930
import java.util.ArrayList;
3031
import java.util.HashMap;
@@ -49,73 +50,20 @@
4950
* 2. <b>Apply Rewrites (Modification):</b> Iterate over the collected candidate and put
5051
* {@code TeeOp}, and safely rewire the graph.
5152
*/
52-
public class RewriteInjectOOCTee extends HopRewriteRule {
53+
public class RewriteInjectOOCTee extends StatementBlockRewriteRule {
5354

5455
public static boolean APPLY_ONLY_XtX_PATTERN = false;
56+
57+
private static final Map<String, Integer> _transientVars = new HashMap<>();
58+
private static final Map<String, List<Hop>> _transientHops = new HashMap<>();
59+
private static final Set<String> teeTransientVars = new HashSet<>();
5560

5661
private static final Set<Long> rewrittenHops = new HashSet<>();
5762
private static final Map<Long, Hop> handledHop = new HashMap<>();
5863

5964
// Maintain a list of candidates to rewrite in the second pass
6065
private final List<Hop> rewriteCandidates = new ArrayList<>();
61-
62-
/**
63-
* Handle a generic (last-level) hop DAG with multiple roots.
64-
*
65-
* @param roots high-level operator roots
66-
* @param state program rewrite status
67-
* @return list of high-level operators
68-
*/
69-
@Override
70-
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
71-
if (roots == null) {
72-
return null;
73-
}
74-
75-
// Clear candidates for this pass
76-
rewriteCandidates.clear();
77-
78-
// PASS 1: Identify candidates without modifying the graph
79-
for (Hop root : roots) {
80-
root.resetVisitStatus();
81-
findRewriteCandidates(root);
82-
}
83-
84-
// PASS 2: Apply rewrites to identified candidates
85-
for (Hop candidate : rewriteCandidates) {
86-
applyTopDownTeeRewrite(candidate);
87-
}
88-
89-
return roots;
90-
}
91-
92-
/**
93-
* Handle a predicate hop DAG with exactly one root.
94-
*
95-
* @param root high-level operator root
96-
* @param state program rewrite status
97-
* @return high-level operator
98-
*/
99-
@Override
100-
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
101-
if (root == null) {
102-
return null;
103-
}
104-
105-
// Clear candidates for this pass
106-
rewriteCandidates.clear();
107-
108-
// PASS 1: Identify candidates without modifying the graph
109-
root.resetVisitStatus();
110-
findRewriteCandidates(root);
111-
112-
// PASS 2: Apply rewrites to identified candidates
113-
for (Hop candidate : rewriteCandidates) {
114-
applyTopDownTeeRewrite(candidate);
115-
}
116-
117-
return root;
118-
}
66+
private boolean forceTee = false;
11967

12068
/**
12169
* First pass: Find candidates for rewrite without modifying the graph.
@@ -137,6 +85,35 @@ private void findRewriteCandidates(Hop hop) {
13785
findRewriteCandidates(input);
13886
}
13987

88+
boolean isRewriteCandidate = DMLScript.USE_OOC
89+
&& hop.getDataType().isMatrix()
90+
&& !HopRewriteUtils.isData(hop, OpOpData.TEE)
91+
&& hop.getParent().size() > 1
92+
&& (!APPLY_ONLY_XtX_PATTERN || isSelfTranposePattern(hop));
93+
94+
if (HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) && hop.getDataType().isMatrix()) {
95+
_transientVars.compute(hop.getName(), (key, ctr) -> {
96+
int incr = (isRewriteCandidate || forceTee) ? 2 : 1;
97+
98+
int ret = ctr == null ? 0 : ctr;
99+
ret += incr;
100+
101+
if (ret > 1)
102+
teeTransientVars.add(hop.getName());
103+
104+
return ret;
105+
});
106+
107+
_transientHops.compute(hop.getName(), (key, hops) -> {
108+
if (hops == null)
109+
return new ArrayList<>(List.of(hop));
110+
hops.add(hop);
111+
return hops;
112+
});
113+
114+
return; // We do not tee transient reads but rather inject before TWrite or PRead as caching stream
115+
}
116+
140117
// Check if this hop is a candidate for OOC Tee injection
141118
if (DMLScript.USE_OOC
142119
&& hop.getDataType().isMatrix()
@@ -160,11 +137,17 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
160137
return;
161138
}
162139

140+
int consumerCount = sharedInput.getParent().size();
141+
if (LOG.isDebugEnabled()) {
142+
LOG.debug("Inject tee for hop " + sharedInput.getHopID() + " ("
143+
+ sharedInput.getName() + "), consumers=" + consumerCount);
144+
}
145+
163146
// Take a defensive copy of consumers before modifying the graph
164147
ArrayList<Hop> consumers = new ArrayList<>(sharedInput.getParent());
165148

166149
// Create the new TeeOp with the original hop as input
167-
DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
150+
DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
168151
sharedInput.getDataType(), sharedInput.getValueType(), Types.OpOpData.TEE, null,
169152
sharedInput.getDim1(), sharedInput.getDim2(), sharedInput.getNnz(), sharedInput.getBlocksize());
170153
HopRewriteUtils.addChildReference(teeOp, sharedInput);
@@ -177,6 +160,11 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
177160
// Record that we've handled this hop
178161
handledHop.put(sharedInput.getHopID(), teeOp);
179162
rewrittenHops.add(sharedInput.getHopID());
163+
164+
if (LOG.isDebugEnabled()) {
165+
LOG.debug("Created tee hop " + teeOp.getHopID() + " -> "
166+
+ teeOp.getName());
167+
}
180168
}
181169

182170
@SuppressWarnings("unused")
@@ -196,4 +184,108 @@ else if (HopRewriteUtils.isMatrixMultiply(parent)) {
196184
}
197185
return hasTransposeConsumer && hasMatrixMultiplyConsumer;
198186
}
187+
188+
@Override
189+
public boolean createsSplitDag() {
190+
return false;
191+
}
192+
193+
@Override
194+
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
195+
if (!DMLScript.USE_OOC)
196+
return List.of(sb);
197+
198+
rewriteSB(sb, state);
199+
200+
for (String tVar : teeTransientVars) {
201+
List<Hop> tHops = _transientHops.get(tVar);
202+
203+
if (tHops == null)
204+
continue;
205+
206+
for (Hop affectedHops : tHops) {
207+
applyTopDownTeeRewrite(affectedHops);
208+
}
209+
210+
tHops.clear();
211+
}
212+
213+
removeRedundantTeeChains(sb);
214+
215+
return List.of(sb);
216+
}
217+
218+
@Override
219+
public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
220+
if (!DMLScript.USE_OOC)
221+
return sbs;
222+
223+
for (StatementBlock sb : sbs)
224+
rewriteSB(sb, state);
225+
226+
for (String tVar : teeTransientVars) {
227+
List<Hop> tHops = _transientHops.get(tVar);
228+
229+
if (tHops == null)
230+
continue;
231+
232+
for (Hop affectedHops : tHops) {
233+
applyTopDownTeeRewrite(affectedHops);
234+
}
235+
}
236+
237+
for (StatementBlock sb : sbs)
238+
removeRedundantTeeChains(sb);
239+
240+
return sbs;
241+
}
242+
243+
private void rewriteSB(StatementBlock sb, ProgramRewriteStatus state) {
244+
rewriteCandidates.clear();
245+
246+
if (sb.getHops() != null) {
247+
for(Hop hop : sb.getHops()) {
248+
hop.resetVisitStatus();
249+
findRewriteCandidates(hop);
250+
}
251+
}
252+
253+
for (Hop candidate : rewriteCandidates) {
254+
applyTopDownTeeRewrite(candidate);
255+
}
256+
}
257+
258+
private void removeRedundantTeeChains(StatementBlock sb) {
259+
if (sb == null || sb.getHops() == null)
260+
return;
261+
262+
Hop.resetVisitStatus(sb.getHops());
263+
for (Hop hop : sb.getHops())
264+
removeRedundantTeeChains(hop);
265+
Hop.resetVisitStatus(sb.getHops());
266+
}
267+
268+
private void removeRedundantTeeChains(Hop hop) {
269+
if (hop.isVisited())
270+
return;
271+
272+
ArrayList<Hop> inputs = new ArrayList<>(hop.getInput());
273+
for (Hop in : inputs)
274+
removeRedundantTeeChains(in);
275+
276+
if (HopRewriteUtils.isData(hop, OpOpData.TEE) && hop.getInput().size() == 1) {
277+
Hop teeInput = hop.getInput().get(0);
278+
if (HopRewriteUtils.isData(teeInput, OpOpData.TEE)) {
279+
if (LOG.isDebugEnabled()) {
280+
LOG.debug("Remove redundant tee hop " + hop.getHopID()
281+
+ " (" + hop.getName() + ") -> " + teeInput.getHopID()
282+
+ " (" + teeInput.getName() + ")");
283+
}
284+
HopRewriteUtils.rewireAllParentChildReferences(hop, teeInput);
285+
HopRewriteUtils.removeAllChildReferences(hop);
286+
}
287+
}
288+
289+
hop.setVisited();
290+
}
199291
}

src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -471,12 +471,12 @@ public boolean hasBroadcastHandle() {
471471
return _bcHandle != null && _bcHandle.hasBackReference();
472472
}
473473

474-
public OOCStream<IndexedMatrixValue> getStreamHandle() {
474+
public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
475475
if( !hasStreamHandle() ) {
476476
final SubscribableTaskQueue<IndexedMatrixValue> _mStream = new SubscribableTaskQueue<>();
477-
_streamHandle = _mStream;
478477
DataCharacteristics dc = getDataCharacteristics();
479478
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
479+
_streamHandle = _mStream;
480480
LongStream.range(0, dc.getNumBlocks())
481481
.mapToObj(i -> UtilFunctions.createIndexedMatrixBlock(src, dc, i))
482482
.forEach( blk -> {
@@ -489,7 +489,14 @@ public OOCStream<IndexedMatrixValue> getStreamHandle() {
489489
_mStream.closeInput();
490490
}
491491

492-
return _streamHandle.getReadStream();
492+
OOCStream<IndexedMatrixValue> stream = _streamHandle.getReadStream();
493+
if (!stream.hasStreamCache())
494+
_streamHandle = null; // To ensure read once
495+
return stream;
496+
}
497+
498+
public OOCStreamable<IndexedMatrixValue> getStreamable() {
499+
return _streamHandle;
493500
}
494501

495502
/**
@@ -499,7 +506,7 @@ public OOCStream<IndexedMatrixValue> getStreamHandle() {
499506
* @return true if existing, false otherwise
500507
*/
501508
public boolean hasStreamHandle() {
502-
return _streamHandle != null && !_streamHandle.isProcessed();
509+
return _streamHandle != null;
503510
}
504511

505512
@SuppressWarnings({ "rawtypes", "unchecked" })

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public class LocalTaskQueue<T>
4545

4646
protected LinkedList<T> _data = null;
4747
protected boolean _closedInput = false;
48-
private DMLRuntimeException _failure = null;
48+
protected DMLRuntimeException _failure = null;
4949
private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName());
5050

5151
public LocalTaskQueue()

src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.apache.sysds.runtime.frame.data.FrameBlock;
4747
import org.apache.sysds.runtime.instructions.Instruction;
4848
import org.apache.sysds.runtime.instructions.InstructionUtils;
49+
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
4950
import org.apache.sysds.runtime.io.FileFormatProperties;
5051
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
5152
import org.apache.sysds.runtime.io.FileFormatPropertiesHDF5;
@@ -1026,6 +1027,9 @@ private void processCopyInstruction(ExecutionContext ec) {
10261027
if ( dd == null )
10271028
throw new DMLRuntimeException("Unexpected error: could not find a data object for variable name:" + getInput1().getName() + ", while processing instruction " +this.toString());
10281029

1030+
if (DMLScript.USE_OOC && dd instanceof MatrixObject)
1031+
TeeOOCInstruction.incrRef(((MatrixObject)dd).getStreamable(), 1);
1032+
10291033
// remove existing variable bound to target name
10301034
Data input2_data = ec.removeVariable(getInput2().getName());
10311035

@@ -1117,6 +1121,8 @@ private void processSetFileNameInstruction(ExecutionContext ec){
11171121
public static void processRmvarInstruction( ExecutionContext ec, String varname ) {
11181122
// remove variable from symbol table
11191123
Data dat = ec.removeVariable(varname);
1124+
if (DMLScript.USE_OOC && dat instanceof MatrixObject)
1125+
TeeOOCInstruction.incrRef(((MatrixObject) dat).getStreamable(), -1);
11201126
//cleanup matrix data on fs/hdfs (if necessary)
11211127
if( dat != null )
11221128
ec.cleanupDataObject(dat);

0 commit comments

Comments
 (0)