@@ -70,13 +70,9 @@ public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List<Pair<Long, F
7070 /**
7171 * Retrieves the minimum cost child plan considering the parent's output type.
7272 * The cost is calculated using getParentViewCost to account for potential type mismatches.
73- *
74- * @param childHopID ?
75- * @param childFedOutType ?
76- * @return ?
7773 */
78- public FedPlan getMinCostFedPlan (long hopID , FederatedOutput fedOutType ) {
79- FedPlanVariants fedPlanVariantList = hopMemoTable .get (new ImmutablePair <>( hopID , fedOutType ) );
74+ public FedPlan getMinCostFedPlan (Pair < Long , FederatedOutput > fedPlanPair ) {
75+ FedPlanVariants fedPlanVariantList = hopMemoTable .get (fedPlanPair );
8076 return fedPlanVariantList ._fedPlanVariants .stream ()
8177 .min (Comparator .comparingDouble (FedPlan ::getTotalCost ))
8278 .orElse (null );
@@ -86,12 +82,22 @@ public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType
8682 return hopMemoTable .get (new ImmutablePair <>(hopID , fedOutType ));
8783 }
8884
85+ public FedPlanVariants getFedPlanVariants (Pair <Long , FederatedOutput > fedPlanPair ) {
86+ return hopMemoTable .get (fedPlanPair );
87+ }
88+
8989 public FedPlan getFedPlanAfterPrune (long hopID , FederatedOutput fedOutType ) {
9090 // Todo: Consider whether to verify if pruning has been performed
9191 FedPlanVariants fedPlanVariantList = hopMemoTable .get (new ImmutablePair <>(hopID , fedOutType ));
9292 return fedPlanVariantList ._fedPlanVariants .get (0 );
9393 }
9494
95+ public FedPlan getFedPlanAfterPrune (Pair <Long , FederatedOutput > fedPlanPair ) {
96+ // Todo: Consider whether to verify if pruning has been performed
97+ FedPlanVariants fedPlanVariantList = hopMemoTable .get (fedPlanPair );
98+ return fedPlanVariantList ._fedPlanVariants .get (0 );
99+ }
100+
95101 /**
96102 * Checks if the memo table contains an entry for a given Hop and fedOutType.
97103 *
@@ -104,128 +110,14 @@ public boolean contains(long hopID, FederatedOutput fedOutType) {
104110 }
105111
106112 /**
107- * Prunes all entries in the memo table, retaining only the minimum-cost
108- * FedPlan for each entry.
109- */
110- public void pruneMemoTable () {
111- for (Map .Entry <Pair <Long , FederatedOutput >, FedPlanVariants > entry : hopMemoTable .entrySet ()) {
112- List <FedPlan > fedPlanList = entry .getValue ().getFedPlanVariants ();
113- if (fedPlanList .size () > 1 ) {
114- // Find the FedPlan with the minimum cost
115- FedPlan minCostPlan = fedPlanList .stream ()
116- .min (Comparator .comparingDouble (FedPlan ::getTotalCost ))
117- .orElse (null );
118-
119- // Retain only the minimum cost plan
120- fedPlanList .clear ();
121- fedPlanList .add (minCostPlan );
122- }
123- }
124- }
125-
126- // Todo: Separate print functions from FederatedMemoTable
127- /**
128- * Recursively prints a tree representation of the DAG starting from the given root FedPlan.
129- * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node.
113+ * Prunes the specified entry in the memo table, retaining only the minimum-cost
114+ * FedPlan for the given Hop ID and federated output type.
130115 *
131- * @param rootFedPlan The starting point FedPlan to print
116+ * @param hopID The ID of the Hop to prune
117+ * @param federatedOutput The federated output type associated with the Hop
132118 */
133- public void printFedPlanTree (FedPlan rootFedPlan ) {
134- Set <FedPlan > visited = new HashSet <>();
135- printFedPlanTreeRecursive (rootFedPlan , visited , 0 , true );
136- }
137-
138- /**
139- * Helper method to recursively print the FedPlan tree.
140- *
141- * @param plan The current FedPlan to print
142- * @param visited Set to keep track of visited FedPlans (prevents cycles)
143- * @param depth The current depth level for indentation
144- * @param isLast Whether this node is the last child of its parent
145- */
146- private void printFedPlanTreeRecursive (FedPlan plan , Set <FedPlan > visited , int depth , boolean isLast ) {
147- if (plan == null || visited .contains (plan )) {
148- return ;
149- }
150-
151- visited .add (plan );
152-
153- Hop hop = plan .getHopRef ();
154- StringBuilder sb = new StringBuilder ();
155-
156- // Add FedPlan information
157- sb .append (String .format ("(%d) " , plan .getHopRef ().getHopID ()))
158- .append (plan .getHopRef ().getOpString ())
159- .append (" [" )
160- .append (plan .getFedOutType ())
161- .append ("]" );
162-
163- StringBuilder childs = new StringBuilder ();
164- childs .append (" (" );
165- boolean childAdded = false ;
166- for ( Hop input : hop .getInput ()){
167- childs .append (childAdded ?"," :"" );
168- childs .append (input .getHopID ());
169- childAdded = true ;
170- }
171- childs .append (")" );
172- if ( childAdded )
173- sb .append (childs .toString ());
174-
175-
176- sb .append (String .format (" {Total: %.1f, Self: %.1f, Net: %.1f}" ,
177- plan .getTotalCost (),
178- plan .getSelfCost (),
179- plan .getNetTransferCost ()));
180-
181- // Add matrix characteristics
182- sb .append (" [" )
183- .append (hop .getDim1 ()).append (", " )
184- .append (hop .getDim2 ()).append (", " )
185- .append (hop .getBlocksize ()).append (", " )
186- .append (hop .getNnz ());
187-
188- if (hop .getUpdateType ().isInPlace ()) {
189- sb .append (", " ).append (hop .getUpdateType ().toString ().toLowerCase ());
190- }
191- sb .append ("]" );
192-
193- // Add memory estimates
194- sb .append (" [" )
195- .append (OptimizerUtils .toMB (hop .getInputMemEstimate ())).append (", " )
196- .append (OptimizerUtils .toMB (hop .getIntermediateMemEstimate ())).append (", " )
197- .append (OptimizerUtils .toMB (hop .getOutputMemEstimate ())).append (" -> " )
198- .append (OptimizerUtils .toMB (hop .getMemEstimate ())).append ("MB]" );
199-
200- // Add reblock and checkpoint requirements
201- if (hop .requiresReblock () && hop .requiresCheckpoint ()) {
202- sb .append (" [rblk, chkpt]" );
203- } else if (hop .requiresReblock ()) {
204- sb .append (" [rblk]" );
205- } else if (hop .requiresCheckpoint ()) {
206- sb .append (" [chkpt]" );
207- }
208-
209- // Add execution type
210- if (hop .getExecType () != null ) {
211- sb .append (", " ).append (hop .getExecType ());
212- }
213-
214- System .out .println (sb );
215-
216- // Process child nodes
217- List <Pair <Long , FederatedOutput >> childRefs = plan .getChildFedPlans ();
218- for (int i = 0 ; i < childRefs .size (); i ++) {
219- Pair <Long , FederatedOutput > childRef = childRefs .get (i );
220- FedPlanVariants childVariants = getFedPlanVariants (childRef .getLeft (), childRef .getRight ());
221- if (childVariants == null || childVariants .getFedPlanVariants ().isEmpty ())
222- continue ;
223-
224- boolean isLastChild = (i == childRefs .size () - 1 );
225- for (FedPlan childPlan : childVariants .getFedPlanVariants ()) {
226- printFedPlanTreeRecursive (childPlan , visited , depth + 1 , isLastChild );
227- }
228- }
119+ public void pruneFedPlan (long hopID , FederatedOutput federatedOutput ) {
120+ hopMemoTable .get (new ImmutablePair <>(hopID , federatedOutput )).prune ();
229121 }
230122
231123 /**
@@ -262,6 +154,20 @@ public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
262154
263155 public void addFedPlan (FedPlan fedPlan ) {_fedPlanVariants .add (fedPlan );}
264156 public List <FedPlan > getFedPlanVariants () {return _fedPlanVariants ;}
157+ public boolean isEmpty () {return _fedPlanVariants .isEmpty ();}
158+
159+ public void prune () {
160+ if (_fedPlanVariants .size () > 1 ) {
161+ // Find the FedPlan with the minimum cost
162+ FedPlan minCostPlan = _fedPlanVariants .stream ()
163+ .min (Comparator .comparingDouble (FedPlan ::getTotalCost ))
164+ .orElse (null );
165+
166+ // Retain only the minimum cost plan
167+ _fedPlanVariants .clear ();
168+ _fedPlanVariants .add (minCostPlan );
169+ }
170+ }
265171 }
266172
267173 /**
0 commit comments