Skip to content

Commit 0fcfc6d

Browse files
committed
Update CostEnumeratorTest, printFedPlanTreeRecursive
1 parent a5d4020 commit 0fcfc6d

File tree

3 files changed

+92
-30
lines changed

3 files changed

+92
-30
lines changed

src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
package org.apache.sysds.hops.fedplanner;
2121

2222
import org.apache.sysds.hops.Hop;
23+
import org.apache.sysds.hops.OptimizerUtils;
2324
import org.apache.commons.lang3.tuple.Pair;
2425
import org.apache.commons.lang3.tuple.ImmutablePair;
2526
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
26-
2727
import java.util.Comparator;
2828
import java.util.HashMap;
2929
import java.util.List;
@@ -139,19 +139,68 @@ private void printFedPlanTreeRecursive(FedPlan plan, Set<FedPlan> visited, int d
139139

140140
visited.add(plan);
141141

142-
// Create indentation and connectors for tree visualization
143-
String indent = " ".repeat(depth);
144-
String prefix = depth == 0 ? "└──" :
145-
isLast ? "└─" : "├─";
146-
147-
// Print plan information
148-
System.out.printf("%s%sHop %d [%s] (Total: %.3f, Self: %.3f, Net: %.3f)%n",
149-
indent, prefix,
150-
plan.getHopRef().getHopID(),
151-
plan.getFedOutType(),
152-
plan.getTotalCost(),
153-
plan.getSelfCost(),
154-
plan.getNetTransferCost());
142+
Hop hop = plan.getHopRef();
143+
StringBuilder sb = new StringBuilder();
144+
145+
// Add FedPlan information
146+
sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
147+
.append(plan.getHopRef().getOpString())
148+
.append(" [")
149+
.append(plan.getFedOutType())
150+
.append("]");
151+
152+
StringBuilder childs = new StringBuilder();
153+
childs.append(" (");
154+
boolean childAdded = false;
155+
for( Hop input : hop.getInput()){
156+
childs.append(childAdded?",":"");
157+
childs.append(input.getHopID());
158+
childAdded = true;
159+
}
160+
childs.append(")");
161+
if( childAdded )
162+
sb.append(childs.toString());
163+
164+
165+
sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
166+
plan.getTotalCost(),
167+
plan.getSelfCost(),
168+
plan.getNetTransferCost()));
169+
170+
// Add matrix characteristics
171+
sb.append(" [")
172+
.append(hop.getDim1()).append(", ")
173+
.append(hop.getDim2()).append(", ")
174+
.append(hop.getBlocksize()).append(", ")
175+
.append(hop.getNnz());
176+
177+
if (hop.getUpdateType().isInPlace()) {
178+
sb.append(", ").append(hop.getUpdateType().toString().toLowerCase());
179+
}
180+
sb.append("]");
181+
182+
// Add memory estimates
183+
sb.append(" [")
184+
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
185+
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
186+
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
187+
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
188+
189+
// Add reblock and checkpoint requirements
190+
if (hop.requiresReblock() && hop.requiresCheckpoint()) {
191+
sb.append(" [rblk, chkpt]");
192+
} else if (hop.requiresReblock()) {
193+
sb.append(" [rblk]");
194+
} else if (hop.requiresCheckpoint()) {
195+
sb.append(" [chkpt]");
196+
}
197+
198+
// Add execution type
199+
if (hop.getExecType() != null) {
200+
sb.append(", ").append(hop.getExecType());
201+
}
202+
203+
System.out.println(sb);
155204

156205
// Process child nodes
157206
List<Pair<Long, FederatedOutput>> childRefs = plan.getChildFedPlans();

src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,28 +39,15 @@
3939

4040
public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase
4141
{
42-
private static final String TEST_DIR = "component/parfor/";
42+
private static final String TEST_DIR = "functions/federated/privacy/";
4343
private static final String HOME = SCRIPT_DIR + TEST_DIR;
4444
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/";
4545

4646
@Override
4747
public void setUp() {}
4848

4949
@Test
50-
public void testDependencyAnalysis1() { runTest("parfor1.dml"); }
51-
52-
@Test
53-
public void testDependencyAnalysis3() { runTest("parfor3.dml"); }
54-
55-
@Test
56-
public void testDependencyAnalysis4() { runTest("parfor4.dml"); }
57-
58-
@Test
59-
public void testDependencyAnalysis6() { runTest("parfor6.dml"); }
60-
61-
@Test
62-
public void testDependencyAnalysis7() { runTest("parfor7.dml"); }
63-
50+
public void testDependencyAnalysis1() { runTest("cost.dml"); }
6451

6552
private void runTest( String scriptFilename ) {
6653
int index = scriptFilename.lastIndexOf(".dml");
@@ -83,7 +70,8 @@ private void runTest( String scriptFilename ) {
8370
dmlt.liveVariableAnalysis(prog);
8471
dmlt.validateParseTree(prog);
8572
dmlt.constructHops(prog);
86-
73+
dmlt.rewriteHopsDAG(prog);
74+
dmlt.constructLops(prog);
8775
/* TODO) In the current DAG, Hop's _outputMemEstimate is not initialized
8876
// This leads to incorrect fedplan generation, so test code needs to be modified
8977
// If needed, modify costEstimator to handle cases where _outputMemEstimate is not initialized
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
a = matrix(7,10,10);
23+
b = a + a^2;
24+
c = sqrt(b);
25+
print(sum(c));

0 commit comments

Comments
 (0)