2828import org .apache .commons .lang3 .tuple .Pair ;
2929import org .apache .commons .lang3 .tuple .ImmutablePair ;
3030import org .apache .sysds .runtime .instructions .fed .FEDInstruction .FederatedOutput ;
31+ import org .apache .sysds .common .Types .ExecType ;
3132
3233/**
3334 * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes.
@@ -46,13 +47,19 @@ public FedPlanVariants getFedPlanVariants(Pair<Long, FederatedOutput> fedPlanPai
4647 return hopMemoTable .get (fedPlanPair );
4748 }
4849
49- public FedPlan getFedPlanAfterPrune (long hopID , FederatedOutput fedOutType ) {
50- FedPlanVariants fedPlanVariantList = hopMemoTable .get (new ImmutablePair <>(hopID , fedOutType ));
50+ public FedPlan getFedPlanAfterPrune (long hopID , FederatedOutput federatedOutput ) {
51+ FedPlanVariants fedPlanVariantList = hopMemoTable .get (new ImmutablePair <>(hopID , federatedOutput ));
52+ if (fedPlanVariantList == null || fedPlanVariantList .isEmpty ()) {
53+ return null ;
54+ }
5155 return fedPlanVariantList ._fedPlanVariants .get (0 );
5256 }
5357
5458 public FedPlan getFedPlanAfterPrune (Pair <Long , FederatedOutput > fedPlanPair ) {
5559 FedPlanVariants fedPlanVariantList = hopMemoTable .get (fedPlanPair );
60+ if (fedPlanVariantList == null || fedPlanVariantList .isEmpty ()) {
61+ return null ;
62+ }
5663 return fedPlanVariantList ._fedPlanVariants .get (0 );
5764 }
5865
@@ -61,13 +68,17 @@ public boolean contains(long hopID, FederatedOutput fedOutType) {
6168 }
6269
6370 /**
64- * Represents a single federated execution plan with its associated costs and dependencies.
71+ * Represents a single federated execution plan with its associated costs and
72+ * dependencies.
6573 * This class contains:
66- * 1. selfCost: Cost of the current hop (computation + input/output memory access).
67- * 2. cumulativeCost: Total cost including this plan's selfCost and all child plans' cumulativeCost.
74+ * 1. selfCost: Cost of the current hop (computation + input/output memory
75+ * access).
76+ * 2. cumulativeCost: Total cost including this plan's selfCost and all child
77+ * plans' cumulativeCost.
6878 * 3. forwardingCost: Network transfer cost for this plan to the parent plan.
6979 *
70- * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs.
80+ * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage
81+ * common properties and costs.
7182 */
7283 public static class FedPlan {
7384 private double cumulativeCost ; // Total cost = sum of selfCost + cumulativeCost of child plans
@@ -84,10 +95,31 @@ public FedPlan(double cumulativeCost, FedPlanVariants fedPlanVariants, List<Pair
8495 public long getHopID () {return fedPlanVariants .hopCommon .getHopRef ().getHopID ();}
8596 public FederatedOutput getFedOutType () {return fedPlanVariants .getFedOutType ();}
8697 public double getCumulativeCost () {return cumulativeCost ;}
98+ public double getCumulativeCostPerParents () {
99+ double cumulativeCostPerParents = cumulativeCost ;
100+ int numOfParents = fedPlanVariants .hopCommon .getNumOfParents ();
101+ if (numOfParents >= 2 ){
102+ cumulativeCostPerParents /= numOfParents ;
103+ }
104+ return cumulativeCostPerParents ;
105+ }
87106 public double getSelfCost () {return fedPlanVariants .hopCommon .getSelfCost ();}
88107 public double getForwardingCost () {return fedPlanVariants .hopCommon .getForwardingCost ();}
89- public double getWeight () {return fedPlanVariants .hopCommon .getWeight ();}
108+ public double getForwardingCostPerParents () {
109+ double forwardingCostPerParents = fedPlanVariants .hopCommon .getForwardingCost ();
110+ int numOfParents = fedPlanVariants .hopCommon .getNumOfParents ();
111+ if (numOfParents >= 2 ){
112+ forwardingCostPerParents /= numOfParents ;
113+ }
114+ return forwardingCostPerParents ;
115+ }
116+ public double getComputeWeight () {return fedPlanVariants .hopCommon .getComputeWeight ();}
117+ public double getNetworkWeight () {return fedPlanVariants .hopCommon .getNetworkWeight ();}
118+ public double getChildForwardingWeight (List <Pair <Long , Double >> childLoopContext ) {return fedPlanVariants .hopCommon .getChildForwardingWeight (childLoopContext );}
119+ public List <Pair <Long , Double >> getLoopContext () {return fedPlanVariants .hopCommon .getLoopContext ();}
90120 public List <Pair <Long , FederatedOutput >> getChildFedPlans () {return childFedPlans ;}
121+ public void setFederatedOutput (FederatedOutput fedOutType ) {fedPlanVariants .hopCommon .hopRef .setFederatedOutput (fedOutType );}
122+ public void setForcedExecType (ExecType execType ) {fedPlanVariants .hopCommon .hopRef .setForcedExecType (execType );}
91123 }
92124
93125 /**
@@ -111,8 +143,8 @@ public FedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType) {
111143 public List <FedPlan > getFedPlanVariants () {return _fedPlanVariants ;}
112144 public FederatedOutput getFedOutType () {return fedOutType ;}
113145
114- public void pruneFedPlans () {
115- if (_fedPlanVariants .size () > 1 ) {
146+ public boolean pruneFedPlans () {
147+ if (! _fedPlanVariants .isEmpty () ) {
116148 // Find the FedPlan with the minimum cumulative cost
117149 FedPlan minCostPlan = _fedPlanVariants .stream ()
118150 .min (Comparator .comparingDouble (FedPlan ::getCumulativeCost ))
@@ -121,33 +153,63 @@ public void pruneFedPlans() {
121153 // Retain only the minimum cost plan
122154 _fedPlanVariants .clear ();
123155 _fedPlanVariants .add (minCostPlan );
156+ return true ;
124157 }
158+ return false ;
125159 }
126160 }
127161
128162 /**
129163 * Represents common properties and costs associated with a Hop.
130164 * This class holds a reference to the Hop and tracks its execution and network forwarding (transfer) costs.
165+ * It also maintains the loop context information to properly calculate forwarding costs within loops.
131166 */
132167 public static class HopCommon {
133168 protected final Hop hopRef ; // Reference to the associated Hop
134169 protected double selfCost ; // Cost of the hop's computation and memory access
135170 protected double forwardingCost ; // Cost of forwarding the hop's output to its parent
136- protected double weight ; // Weight used to calculate cost based on hop execution frequency
171+ protected int numOfParents ;
172+ protected double computeWeight ; // Weight used to calculate cost based on hop execution frequency
173+ protected double networkWeight ; // Weight used to calculate cost based on hop execution frequency
174+ protected List <Pair <Long , Double >> loopContext ; // Loop context in which this hop exists
137175
138- public HopCommon (Hop hopRef , double weight ) {
176+ public HopCommon (Hop hopRef , double computeWeight , double networkWeight , int numOfParents , List < Pair < Long , Double >> loopContext ) {
139177 this .hopRef = hopRef ;
140178 this .selfCost = 0 ;
141179 this .forwardingCost = 0 ;
142- this .weight = weight ;
180+ this .numOfParents = numOfParents ;
181+ this .computeWeight = computeWeight ;
182+ this .networkWeight = networkWeight ;
183+ this .loopContext = loopContext != null ? new ArrayList <>(loopContext ) : new ArrayList <>();
143184 }
144185
145186 public Hop getHopRef () {return hopRef ;}
146187 public double getSelfCost () {return selfCost ;}
147188 public double getForwardingCost () {return forwardingCost ;}
148- public double getWeight () {return weight ;}
189+ public double getComputeWeight () {return computeWeight ;}
190+ public double getNetworkWeight () {return networkWeight ;}
191+ public int getNumOfParents () {return numOfParents ;}
192+ public List <Pair <Long , Double >> getLoopContext () {return loopContext ;}
149193
150194 protected void setSelfCost (double selfCost ) {this .selfCost = selfCost ;}
151195 protected void setForwardingCost (double forwardingCost ) {this .forwardingCost = forwardingCost ;}
196+ protected void setNumOfParentHops (int numOfParentHops ) {this .numOfParents = numOfParentHops ;}
197+
198+ public double getChildForwardingWeight (List <Pair <Long , Double >> childLoopContext ) {
199+ if (loopContext .isEmpty ()) {
200+ return networkWeight ;
201+ }
202+
203+ double forwardingWeight = this .networkWeight ;
204+
205+ for (int i = 0 ; i < loopContext .size (); i ++) {
206+ if (i >= childLoopContext .size () || loopContext .get (i ).getLeft () != childLoopContext .get (i ).getLeft ()) {
207+ forwardingWeight /=loopContext .get (i ).getRight ();
208+ }
209+ }
210+
211+ // Check if the innermost loops are the same
212+ return forwardingWeight ;
213+ }
152214 }
153215}
0 commit comments