2525import org .apache .sysds .hops .DataOp ;
2626import org .apache .sysds .hops .Hop ;
2727import org .apache .sysds .hops .ReorgOp ;
28+ import org .apache .sysds .parser .StatementBlock ;
2829
2930import java .util .ArrayList ;
3031import java .util .HashMap ;
4950 * 2. <b>Apply Rewrites (Modification):</b> Iterate over the collected candidates, insert a
5051 * {@code TeeOp} for each one, and safely rewire the graph.
5152 */
52- public class RewriteInjectOOCTee extends HopRewriteRule {
53+ public class RewriteInjectOOCTee extends StatementBlockRewriteRule {
5354
5455 public static boolean APPLY_ONLY_XtX_PATTERN = false ;
56+
57+ private static final Map <String , Integer > _transientVars = new HashMap <>();
58+ private static final Map <String , List <Hop >> _transientHops = new HashMap <>();
59+ private static final Set <String > teeTransientVars = new HashSet <>();
5560
5661 private static final Set <Long > rewrittenHops = new HashSet <>();
5762 private static final Map <Long , Hop > handledHop = new HashMap <>();
5863
5964 // Maintain a list of candidates to rewrite in the second pass
6065 private final List <Hop > rewriteCandidates = new ArrayList <>();
61-
62- /**
63- * Handle a generic (last-level) hop DAG with multiple roots.
64- *
65- * @param roots high-level operator roots
66- * @param state program rewrite status
67- * @return list of high-level operators
68- */
69- @ Override
70- public ArrayList <Hop > rewriteHopDAGs (ArrayList <Hop > roots , ProgramRewriteStatus state ) {
71- if (roots == null ) {
72- return null ;
73- }
74-
75- // Clear candidates for this pass
76- rewriteCandidates .clear ();
77-
78- // PASS 1: Identify candidates without modifying the graph
79- for (Hop root : roots ) {
80- root .resetVisitStatus ();
81- findRewriteCandidates (root );
82- }
83-
84- // PASS 2: Apply rewrites to identified candidates
85- for (Hop candidate : rewriteCandidates ) {
86- applyTopDownTeeRewrite (candidate );
87- }
88-
89- return roots ;
90- }
91-
92- /**
93- * Handle a predicate hop DAG with exactly one root.
94- *
95- * @param root high-level operator root
96- * @param state program rewrite status
97- * @return high-level operator
98- */
99- @ Override
100- public Hop rewriteHopDAG (Hop root , ProgramRewriteStatus state ) {
101- if (root == null ) {
102- return null ;
103- }
104-
105- // Clear candidates for this pass
106- rewriteCandidates .clear ();
107-
108- // PASS 1: Identify candidates without modifying the graph
109- root .resetVisitStatus ();
110- findRewriteCandidates (root );
111-
112- // PASS 2: Apply rewrites to identified candidates
113- for (Hop candidate : rewriteCandidates ) {
114- applyTopDownTeeRewrite (candidate );
115- }
116-
117- return root ;
118- }
66+ private boolean forceTee = false ;
11967
12068 /**
12169 * First pass: Find candidates for rewrite without modifying the graph.
@@ -137,6 +85,35 @@ private void findRewriteCandidates(Hop hop) {
13785 findRewriteCandidates (input );
13886 }
13987
88+ boolean isRewriteCandidate = DMLScript .USE_OOC
89+ && hop .getDataType ().isMatrix ()
90+ && !HopRewriteUtils .isData (hop , OpOpData .TEE )
91+ && hop .getParent ().size () > 1
92+ && (!APPLY_ONLY_XtX_PATTERN || isSelfTranposePattern (hop ));
93+
94+ if (HopRewriteUtils .isData (hop , OpOpData .TRANSIENTREAD ) && hop .getDataType ().isMatrix ()) {
95+ _transientVars .compute (hop .getName (), (key , ctr ) -> {
96+ int incr = (isRewriteCandidate || forceTee ) ? 2 : 1 ;
97+
98+ int ret = ctr == null ? 0 : ctr ;
99+ ret += incr ;
100+
101+ if (ret > 1 )
102+ teeTransientVars .add (hop .getName ());
103+
104+ return ret ;
105+ });
106+
107+ _transientHops .compute (hop .getName (), (key , hops ) -> {
108+ if (hops == null )
109+ return new ArrayList <>(List .of (hop ));
110+ hops .add (hop );
111+ return hops ;
112+ });
113+
114+ return ; // We do not tee transient reads but rather inject before TWrite or PRead as caching stream
115+ }
116+
140117 // Check if this hop is a candidate for OOC Tee injection
141118 if (DMLScript .USE_OOC
142119 && hop .getDataType ().isMatrix ()
@@ -160,11 +137,17 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
160137 return ;
161138 }
162139
140+ int consumerCount = sharedInput .getParent ().size ();
141+ if (LOG .isDebugEnabled ()) {
142+ LOG .debug ("Inject tee for hop " + sharedInput .getHopID () + " ("
143+ + sharedInput .getName () + "), consumers=" + consumerCount );
144+ }
145+
163146 // Take a defensive copy of consumers before modifying the graph
164147 ArrayList <Hop > consumers = new ArrayList <>(sharedInput .getParent ());
165148
166149 // Create the new TeeOp with the original hop as input
167- DataOp teeOp = new DataOp ("tee_out_" + sharedInput .getName (),
150+ DataOp teeOp = new DataOp ("tee_out_" + sharedInput .getName (),
168151 sharedInput .getDataType (), sharedInput .getValueType (), Types .OpOpData .TEE , null ,
169152 sharedInput .getDim1 (), sharedInput .getDim2 (), sharedInput .getNnz (), sharedInput .getBlocksize ());
170153 HopRewriteUtils .addChildReference (teeOp , sharedInput );
@@ -177,6 +160,11 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
177160 // Record that we've handled this hop
178161 handledHop .put (sharedInput .getHopID (), teeOp );
179162 rewrittenHops .add (sharedInput .getHopID ());
163+
164+ if (LOG .isDebugEnabled ()) {
165+ LOG .debug ("Created tee hop " + teeOp .getHopID () + " -> "
166+ + teeOp .getName ());
167+ }
180168 }
181169
182170 @ SuppressWarnings ("unused" )
@@ -196,4 +184,108 @@ else if (HopRewriteUtils.isMatrixMultiply(parent)) {
196184 }
197185 return hasTransposeConsumer && hasMatrixMultiplyConsumer ;
198186 }
187+
188+ @ Override
189+ public boolean createsSplitDag () {
190+ return false ;
191+ }
192+
193+ @ Override
194+ public List <StatementBlock > rewriteStatementBlock (StatementBlock sb , ProgramRewriteStatus state ) {
195+ if (!DMLScript .USE_OOC )
196+ return List .of (sb );
197+
198+ rewriteSB (sb , state );
199+
200+ for (String tVar : teeTransientVars ) {
201+ List <Hop > tHops = _transientHops .get (tVar );
202+
203+ if (tHops == null )
204+ continue ;
205+
206+ for (Hop affectedHops : tHops ) {
207+ applyTopDownTeeRewrite (affectedHops );
208+ }
209+
210+ tHops .clear ();
211+ }
212+
213+ removeRedundantTeeChains (sb );
214+
215+ return List .of (sb );
216+ }
217+
218+ @ Override
219+ public List <StatementBlock > rewriteStatementBlocks (List <StatementBlock > sbs , ProgramRewriteStatus state ) {
220+ if (!DMLScript .USE_OOC )
221+ return sbs ;
222+
223+ for (StatementBlock sb : sbs )
224+ rewriteSB (sb , state );
225+
226+ for (String tVar : teeTransientVars ) {
227+ List <Hop > tHops = _transientHops .get (tVar );
228+
229+ if (tHops == null )
230+ continue ;
231+
232+ for (Hop affectedHops : tHops ) {
233+ applyTopDownTeeRewrite (affectedHops );
234+ }
235+ }
236+
237+ for (StatementBlock sb : sbs )
238+ removeRedundantTeeChains (sb );
239+
240+ return sbs ;
241+ }
242+
243+ private void rewriteSB (StatementBlock sb , ProgramRewriteStatus state ) {
244+ rewriteCandidates .clear ();
245+
246+ if (sb .getHops () != null ) {
247+ for (Hop hop : sb .getHops ()) {
248+ hop .resetVisitStatus ();
249+ findRewriteCandidates (hop );
250+ }
251+ }
252+
253+ for (Hop candidate : rewriteCandidates ) {
254+ applyTopDownTeeRewrite (candidate );
255+ }
256+ }
257+
258+ private void removeRedundantTeeChains (StatementBlock sb ) {
259+ if (sb == null || sb .getHops () == null )
260+ return ;
261+
262+ Hop .resetVisitStatus (sb .getHops ());
263+ for (Hop hop : sb .getHops ())
264+ removeRedundantTeeChains (hop );
265+ Hop .resetVisitStatus (sb .getHops ());
266+ }
267+
268+ private void removeRedundantTeeChains (Hop hop ) {
269+ if (hop .isVisited ())
270+ return ;
271+
272+ ArrayList <Hop > inputs = new ArrayList <>(hop .getInput ());
273+ for (Hop in : inputs )
274+ removeRedundantTeeChains (in );
275+
276+ if (HopRewriteUtils .isData (hop , OpOpData .TEE ) && hop .getInput ().size () == 1 ) {
277+ Hop teeInput = hop .getInput ().get (0 );
278+ if (HopRewriteUtils .isData (teeInput , OpOpData .TEE )) {
279+ if (LOG .isDebugEnabled ()) {
280+ LOG .debug ("Remove redundant tee hop " + hop .getHopID ()
281+ + " (" + hop .getName () + ") -> " + teeInput .getHopID ()
282+ + " (" + teeInput .getName () + ")" );
283+ }
284+ HopRewriteUtils .rewireAllParentChildReferences (hop , teeInput );
285+ HopRewriteUtils .removeAllChildReferences (hop );
286+ }
287+ }
288+
289+ hop .setVisited ();
290+ }
199291}
0 commit comments