2020package org .apache .sysds .hops .fedplanner ;
2121
2222import org .apache .sysds .hops .Hop ;
23- import org .apache .sysds .hops .OptimizerUtils ;
2423import org .apache .commons .lang3 .tuple .Pair ;
2524import org .apache .commons .lang3 .tuple .ImmutablePair ;
2625import org .apache .sysds .runtime .instructions .fed .FEDInstruction .FederatedOutput ;
2928import java .util .List ;
3029import java .util .ArrayList ;
3130import java .util .Map ;
32- import java .util .HashSet ;
33- import java .util .Set ;
3431
3532/**
3633 * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes.
@@ -71,12 +68,11 @@ public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List<Pair<Long, F
7168 * Retrieves the minimum cost child plan considering the parent's output type.
7269 * The cost is calculated using getParentViewCost to account for potential type mismatches.
7370 *
74- * @param childHopID ?
75- * @param childFedOutType ?
76- * @return ?
71+ * @param fedPlanPair ???
72+ * @return min cost fed plan
7773 */
78- public FedPlan getMinCostChildFedPlan ( long childHopID , FederatedOutput childFedOutType ) {
79- FedPlanVariants fedPlanVariantList = hopMemoTable .get (new ImmutablePair <>( childHopID , childFedOutType ) );
74+ public FedPlan getMinCostFedPlan ( Pair < Long , FederatedOutput > fedPlanPair ) {
75+ FedPlanVariants fedPlanVariantList = hopMemoTable .get (fedPlanPair );
8076 return fedPlanVariantList ._fedPlanVariants .stream ()
8177 .min (Comparator .comparingDouble (FedPlan ::getTotalCost ))
8278 .orElse (null );
@@ -86,6 +82,22 @@ public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType
8682 return hopMemoTable .get (new ImmutablePair <>(hopID , fedOutType ));
8783 }
8884
85+ public FedPlanVariants getFedPlanVariants (Pair <Long , FederatedOutput > fedPlanPair ) {
86+ return hopMemoTable .get (fedPlanPair );
87+ }
88+
89+ public FedPlan getFedPlanAfterPrune (long hopID , FederatedOutput fedOutType ) {
90+ // Todo: Consider whether to verify if pruning has been performed
91+ FedPlanVariants fedPlanVariantList = hopMemoTable .get (new ImmutablePair <>(hopID , fedOutType ));
92+ return fedPlanVariantList ._fedPlanVariants .get (0 );
93+ }
94+
95+ public FedPlan getFedPlanAfterPrune (Pair <Long , FederatedOutput > fedPlanPair ) {
96+ // Todo: Consider whether to verify if pruning has been performed
97+ FedPlanVariants fedPlanVariantList = hopMemoTable .get (fedPlanPair );
98+ return fedPlanVariantList ._fedPlanVariants .get (0 );
99+ }
100+
89101 /**
90102 * Checks if the memo table contains an entry for a given Hop and fedOutType.
91103 *
@@ -98,162 +110,77 @@ public boolean contains(long hopID, FederatedOutput fedOutType) {
98110 }
99111
100112 /**
101- * Prunes all entries in the memo table, retaining only the minimum-cost
102- * FedPlan for each entry.
103- */
104- public void pruneMemoTable () {
105- for (Map .Entry <Pair <Long , FederatedOutput >, FedPlanVariants > entry : hopMemoTable .entrySet ()) {
106- List <FedPlan > fedPlanList = entry .getValue ().getFedPlanVariants ();
107- if (fedPlanList .size () > 1 ) {
108- // Find the FedPlan with the minimum cost
109- FedPlan minCostPlan = fedPlanList .stream ()
110- .min (Comparator .comparingDouble (FedPlan ::getTotalCost ))
111- .orElse (null );
112-
113- // Retain only the minimum cost plan
114- fedPlanList .clear ();
115- fedPlanList .add (minCostPlan );
116- }
117- }
118- }
119-
120- /**
121- * Recursively prints a tree representation of the DAG starting from the given root FedPlan.
122- * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node.
113+ * Prunes the specified entry in the memo table, retaining only the minimum-cost
114+ * FedPlan for the given Hop ID and federated output type.
123115 *
124- * @param rootFedPlan The starting point FedPlan to print
116+ * @param hopID The ID of the Hop to prune
117+ * @param federatedOutput The federated output type associated with the Hop
125118 */
126- public void printFedPlanTree (FedPlan rootFedPlan ) {
127- Set <FedPlan > visited = new HashSet <>();
128- printFedPlanTreeRecursive (rootFedPlan , visited , 0 , true );
119+ public void pruneFedPlan (long hopID , FederatedOutput federatedOutput ) {
120+ hopMemoTable .get (new ImmutablePair <>(hopID , federatedOutput )).prune ();
129121 }
130122
131123 /**
132- * Helper method to recursively print the FedPlan tree.
133- *
134- * @param plan The current FedPlan to print
135- * @param visited Set to keep track of visited FedPlans (prevents cycles)
136- * @param depth The current depth level for indentation
137- * @param isLast Whether this node is the last child of its parent
124+ * Represents common properties and costs associated with a Hop.
125+ * This class holds a reference to the Hop and tracks its execution and network transfer costs.
138126 */
139- private void printFedPlanTreeRecursive (FedPlan plan , Set <FedPlan > visited , int depth , boolean isLast ) {
140- if (plan == null || visited .contains (plan )) {
141- return ;
142- }
143-
144- visited .add (plan );
145-
146- Hop hop = plan .getHopRef ();
147- StringBuilder sb = new StringBuilder ();
148-
149- // Add FedPlan information
150- sb .append (String .format ("(%d) " , plan .getHopRef ().getHopID ()))
151- .append (plan .getHopRef ().getOpString ())
152- .append (" [" )
153- .append (plan .getFedOutType ())
154- .append ("]" );
155-
156- StringBuilder childs = new StringBuilder ();
157- childs .append (" (" );
158- boolean childAdded = false ;
159- for ( Hop input : hop .getInput ()){
160- childs .append (childAdded ?"," :"" );
161- childs .append (input .getHopID ());
162- childAdded = true ;
163- }
164- childs .append (")" );
165- if ( childAdded )
166- sb .append (childs .toString ());
167-
168-
169- sb .append (String .format (" {Total: %.1f, Self: %.1f, Net: %.1f}" ,
170- plan .getTotalCost (),
171- plan .getSelfCost (),
172- plan .getNetTransferCost ()));
173-
174- // Add matrix characteristics
175- sb .append (" [" )
176- .append (hop .getDim1 ()).append (", " )
177- .append (hop .getDim2 ()).append (", " )
178- .append (hop .getBlocksize ()).append (", " )
179- .append (hop .getNnz ());
180-
181- if (hop .getUpdateType ().isInPlace ()) {
182- sb .append (", " ).append (hop .getUpdateType ().toString ().toLowerCase ());
183- }
184- sb .append ("]" );
185-
186- // Add memory estimates
187- sb .append (" [" )
188- .append (OptimizerUtils .toMB (hop .getInputMemEstimate ())).append (", " )
189- .append (OptimizerUtils .toMB (hop .getIntermediateMemEstimate ())).append (", " )
190- .append (OptimizerUtils .toMB (hop .getOutputMemEstimate ())).append (" -> " )
191- .append (OptimizerUtils .toMB (hop .getMemEstimate ())).append ("MB]" );
192-
193- // Add reblock and checkpoint requirements
194- if (hop .requiresReblock () && hop .requiresCheckpoint ()) {
195- sb .append (" [rblk, chkpt]" );
196- } else if (hop .requiresReblock ()) {
197- sb .append (" [rblk]" );
198- } else if (hop .requiresCheckpoint ()) {
199- sb .append (" [chkpt]" );
200- }
201-
202- // Add execution type
203- if (hop .getExecType () != null ) {
204- sb .append (", " ).append (hop .getExecType ());
205- }
206-
207- System .out .println (sb );
208-
209- // Process child nodes
210- List <Pair <Long , FederatedOutput >> childRefs = plan .getChildFedPlans ();
211- for (int i = 0 ; i < childRefs .size (); i ++) {
212- Pair <Long , FederatedOutput > childRef = childRefs .get (i );
213- FedPlanVariants childVariants = getFedPlanVariants (childRef .getLeft (), childRef .getRight ());
214- if (childVariants == null || childVariants .getFedPlanVariants ().isEmpty ())
215- continue ;
127+ public static class HopCommon {
128+ protected final Hop hopRef ; // Reference to the associated Hop
129+ protected double selfCost ; // Current execution cost (compute + memory access)
130+ protected double netTransferCost ; // Network transfer cost
216131
217- boolean isLastChild = ( i == childRefs . size () - 1 );
218- for ( FedPlan childPlan : childVariants . getFedPlanVariants ()) {
219- printFedPlanTreeRecursive ( childPlan , visited , depth + 1 , isLastChild ) ;
220- }
132+ protected HopCommon ( Hop hopRef ) {
133+ this . hopRef = hopRef ;
134+ this . selfCost = 0 ;
135+ this . netTransferCost = 0 ;
221136 }
222137 }
223138
224139 /**
225- * Represents a collection of federated execution plan variants for a specific Hop.
226- * Contains cost information and references to the associated plans.
140+ * Represents a collection of federated execution plan variants for a specific Hop and FederatedOutput.
141+ * This class contains cost information and references to the associated plans.
142+ * It uses HopCommon to store common properties and costs related to the Hop.
227143 */
228144 public static class FedPlanVariants {
229- protected final Hop hopRef ; // Reference to the associated Hop
230- protected double selfCost ; // Current execution cost (compute + memory access)
231- protected double netTransferCost ; // Network transfer cost
232- private final FederatedOutput fedOutType ; // Output type (FOUT/LOUT)
233- protected List <FedPlan > _fedPlanVariants ; // List of plan variants
145+ protected HopCommon hopCommon ; // Common properties and costs for the Hop
146+ private final FederatedOutput fedOutType ; // Output type (FOUT/LOUT)
147+ protected List <FedPlan > _fedPlanVariants ; // List of plan variants
234148
235149 public FedPlanVariants (Hop hopRef , FederatedOutput fedOutType ) {
236- this .hopRef = hopRef ;
150+ this .hopCommon = new HopCommon ( hopRef ) ;
237151 this .fedOutType = fedOutType ;
238- this .selfCost = 0 ;
239- this .netTransferCost = 0 ;
240152 this ._fedPlanVariants = new ArrayList <>();
241153 }
242154
243- public int size () {return _fedPlanVariants .size ();}
244155 public void addFedPlan (FedPlan fedPlan ) {_fedPlanVariants .add (fedPlan );}
245156 public List <FedPlan > getFedPlanVariants () {return _fedPlanVariants ;}
157+ public boolean isEmpty () {return _fedPlanVariants .isEmpty ();}
158+
159+ public void prune () {
160+ if (_fedPlanVariants .size () > 1 ) {
161+ // Find the FedPlan with the minimum cost
162+ FedPlan minCostPlan = _fedPlanVariants .stream ()
163+ .min (Comparator .comparingDouble (FedPlan ::getTotalCost ))
164+ .orElse (null );
165+
166+ // Retain only the minimum cost plan
167+ _fedPlanVariants .clear ();
168+ _fedPlanVariants .add (minCostPlan );
169+ }
170+ }
246171 }
247172
248173 /**
249174 * Represents a single federated execution plan with its associated costs and dependencies.
250- * Contains :
175+ * This class contains :
251176 * 1. selfCost: Cost of current hop (compute + input/output memory access)
252177 * 2. totalCost: Cumulative cost including this plan and all child plans
253178 * 3. netTransferCost: Network transfer cost for this plan to parent plan.
179+ *
180+ * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs.
254181 */
255182 public static class FedPlan {
256- private double totalCost ; // Total cost including child plans
183+ private double totalCost ; // Total cost including child plans
257184 private final FedPlanVariants fedPlanVariants ; // Reference to variant list
258185 private final List <Pair <Long , FederatedOutput >> childFedPlans ; // Child plan references
259186
@@ -264,25 +191,26 @@ public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans, FedPlanVariants
264191 }
265192
266193 public void setTotalCost (double totalCost ) {this .totalCost = totalCost ;}
267- public void setSelfCost (double selfCost ) {fedPlanVariants .selfCost = selfCost ;}
268- public void setNetTransferCost (double netTransferCost ) {fedPlanVariants .netTransferCost = netTransferCost ;}
269-
270- public Hop getHopRef () {return fedPlanVariants .hopRef ;}
194+ public void setSelfCost (double selfCost ) {fedPlanVariants .hopCommon .selfCost = selfCost ;}
195+ public void setNetTransferCost (double netTransferCost ) {fedPlanVariants .hopCommon .netTransferCost = netTransferCost ;}
196+
197+ public Hop getHopRef () {return fedPlanVariants .hopCommon .hopRef ;}
198+ public long getHopID () {return fedPlanVariants .hopCommon .hopRef .getHopID ();}
271199 public FederatedOutput getFedOutType () {return fedPlanVariants .fedOutType ;}
272200 public double getTotalCost () {return totalCost ;}
273- public double getSelfCost () {return fedPlanVariants .selfCost ;}
274- private double getNetTransferCost () {return fedPlanVariants .netTransferCost ;}
201+ public double getSelfCost () {return fedPlanVariants .hopCommon . selfCost ;}
202+ public double getNetTransferCost () {return fedPlanVariants . hopCommon .netTransferCost ;}
275203 public List <Pair <Long , FederatedOutput >> getChildFedPlans () {return childFedPlans ;}
276204
277205 /**
278206 * Calculates the conditional network transfer cost based on output type compatibility.
279207 * Returns 0 if output types match, otherwise returns the network transfer cost.
280- * @param parentFedOutType ?
281- * @return ?
208+ * @param parentFedOutType The federated output type of the parent plan.
209+ * @return The conditional network transfer cost.
282210 */
283211 public double getCondNetTransferCost (FederatedOutput parentFedOutType ) {
284212 if (parentFedOutType == getFedOutType ()) return 0 ;
285- return fedPlanVariants .netTransferCost ;
213+ return fedPlanVariants .hopCommon . netTransferCost ;
286214 }
287215 }
288216}
0 commit comments