Skip to content

Commit 0bca047

Browse files
min-gukmboehm7
authored andcommitted
[SYSTEMDS-3790] New Federated Planner MemoTable
Closes #2141.
1 parent 3081ecf commit 0bca047

File tree

2 files changed

+346
-0
lines changed

2 files changed

+346
-0
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.hops.fedplanner;
21+
22+
import org.apache.sysds.hops.Hop;
23+
import org.apache.sysds.hops.fedplanner.FTypes.FType;
24+
import org.apache.commons.lang3.tuple.Pair;
25+
import org.apache.commons.lang3.tuple.ImmutablePair;
26+
27+
import java.util.Comparator;
28+
import java.util.HashMap;
29+
import java.util.List;
30+
import java.util.ArrayList;
31+
import java.util.Map;
32+
33+
/**
34+
* A Memoization Table for managing federated plans (`FedPlan`) based on
35+
* combinations of Hops and FTypes. Each combination is mapped to a list
36+
* of possible execution plans, allowing for pruning and optimization.
37+
*/
38+
public class MemoTable {
39+
40+
// Maps combinations of Hop ID and FType to lists of FedPlans
41+
private final Map<Pair<Long, FTypes.FType>, List<FedPlan>> hopMemoTable = new HashMap<>();
42+
43+
/**
44+
* Represents a federated execution plan with its cost and associated references.
45+
*/
46+
public static class FedPlan {
47+
@SuppressWarnings("unused")
48+
private final Hop hopRef; // The associated Hop object
49+
private final double cost; // Cost of this federated plan
50+
@SuppressWarnings("unused")
51+
private final List<Pair<Long, FType>> planRefs; // References to dependent plans
52+
53+
public FedPlan(Hop hopRef, double cost, List<Pair<Long, FType>> planRefs) {
54+
this.hopRef = hopRef;
55+
this.cost = cost;
56+
this.planRefs = planRefs;
57+
}
58+
59+
public double getCost() {
60+
return cost;
61+
}
62+
}
63+
64+
/**
65+
* Adds a single FedPlan to the memo table for a given Hop and FType.
66+
* If the entry already exists, the new FedPlan is appended to the list.
67+
*
68+
* @param hop The Hop object.
69+
* @param fType The associated FType.
70+
* @param fedPlan The FedPlan to add.
71+
*/
72+
public void addFedPlan(Hop hop, FType fType, FedPlan fedPlan) {
73+
if (contains(hop, fType)) {
74+
List<FedPlan> fedPlanList = get(hop, fType);
75+
fedPlanList.add(fedPlan);
76+
} else {
77+
List<FedPlan> fedPlanList = new ArrayList<>();
78+
fedPlanList.add(fedPlan);
79+
hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), fType), fedPlanList);
80+
}
81+
}
82+
83+
/**
84+
* Adds multiple FedPlans to the memo table for a given Hop and FType.
85+
* If the entry already exists, the new FedPlans are appended to the list.
86+
*
87+
* @param hop The Hop object.
88+
* @param fType The associated FType.
89+
* @param newFedPlanList The list of FedPlans to add.
90+
*/
91+
public void addFedPlanList(Hop hop, FType fType, List<FedPlan> fedPlanList) {
92+
if (contains(hop, fType)) {
93+
List<FedPlan> prevFedPlanList = get(hop, fType);
94+
prevFedPlanList.addAll(fedPlanList);
95+
} else {
96+
hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), fType), fedPlanList);
97+
}
98+
}
99+
100+
/**
101+
* Retrieves the list of FedPlans associated with a given Hop and FType.
102+
*
103+
* @param hop The Hop object.
104+
* @param fType The associated FType.
105+
* @return The list of FedPlans, or null if no entry exists.
106+
*/
107+
public List<FedPlan> get(Hop hop, FType fType) {
108+
return hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), fType));
109+
}
110+
111+
/**
112+
* Checks if the memo table contains an entry for a given Hop and FType.
113+
*
114+
* @param hop The Hop object.
115+
* @param fType The associated FType.
116+
* @return True if the entry exists, false otherwise.
117+
*/
118+
public boolean contains(Hop hop, FType fType) {
119+
return hopMemoTable.containsKey(new ImmutablePair<>(hop.getHopID(), fType));
120+
}
121+
122+
/**
123+
* Prunes the FedPlans associated with a specific Hop and FType,
124+
* keeping only the plan with the minimum cost.
125+
*
126+
* @param hop The Hop object.
127+
* @param fType The associated FType.
128+
*/
129+
public void prunePlan(Hop hop, FType fType) {
130+
prunePlan(hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), fType)));
131+
}
132+
133+
/**
134+
* Prunes all entries in the memo table, retaining only the minimum-cost
135+
* FedPlan for each entry.
136+
*/
137+
public void pruneAll() {
138+
for (Map.Entry<Pair<Long, FType>, List<FedPlan>> entry : hopMemoTable.entrySet()) {
139+
prunePlan(entry.getValue());
140+
}
141+
}
142+
143+
/**
144+
* Prunes the given list of FedPlans to retain only the plan with the minimum cost.
145+
*
146+
* @param fedPlanList The list of FedPlans to prune.
147+
*/
148+
private void prunePlan(List<FedPlan> fedPlanList) {
149+
if (fedPlanList.size() > 1) {
150+
// Find the FedPlan with the minimum cost
151+
FedPlan minCostPlan = fedPlanList.stream()
152+
.min(Comparator.comparingDouble(plan -> plan.cost))
153+
.orElse(null);
154+
155+
// Retain only the minimum cost plan
156+
fedPlanList.clear();
157+
fedPlanList.add(minCostPlan);
158+
}
159+
}
160+
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.federated;
21+
22+
import org.apache.sysds.hops.Hop;
23+
import org.apache.sysds.hops.fedplanner.FTypes;
24+
import org.apache.sysds.hops.fedplanner.MemoTable;
25+
import org.apache.sysds.hops.fedplanner.MemoTable.FedPlan;
26+
import org.apache.commons.lang3.tuple.Pair;
27+
import org.junit.Before;
28+
import org.junit.Test;
29+
import org.mockito.Mock;
30+
import org.mockito.MockitoAnnotations;
31+
32+
import java.util.ArrayList;
33+
import java.util.List;
34+
35+
import static org.junit.Assert.assertEquals;
36+
import static org.junit.Assert.assertFalse;
37+
import static org.junit.Assert.assertNotNull;
38+
import static org.junit.Assert.assertNull;
39+
import static org.junit.Assert.assertTrue;
40+
import static org.mockito.Mockito.when;
41+
42+
public class MemoTableTest {
43+
44+
private MemoTable memoTable;
45+
46+
@Mock
47+
private Hop mockHop1;
48+
49+
@Mock
50+
private Hop mockHop2;
51+
52+
private java.util.Random rand;
53+
54+
@Before
55+
public void setUp() {
56+
MockitoAnnotations.openMocks(this);
57+
memoTable = new MemoTable();
58+
59+
// Set up unique IDs for mock Hops
60+
when(mockHop1.getHopID()).thenReturn(1L);
61+
when(mockHop2.getHopID()).thenReturn(2L);
62+
63+
// Initialize random generator with fixed seed for reproducible tests
64+
rand = new java.util.Random(42);
65+
}
66+
67+
@Test
68+
public void testAddAndGetSingleFedPlan() {
69+
// Initialize test data
70+
List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
71+
FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
72+
73+
// Verify initial state
74+
List<FedPlan> result = memoTable.get(mockHop1, FTypes.FType.FULL);
75+
assertNull("Initial FedPlan list should be null before adding any plans", result);
76+
77+
// Add single FedPlan
78+
memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
79+
80+
// Verify after addition
81+
result = memoTable.get(mockHop1, FTypes.FType.FULL);
82+
assertNotNull("FedPlan list should exist after adding a plan", result);
83+
assertEquals("FedPlan list should contain exactly one plan", 1, result.size());
84+
assertEquals("FedPlan cost should be exactly 10.0", 10.0, result.get(0).getCost(), 0.001);
85+
}
86+
87+
@Test
88+
public void testAddMultipleDuplicatedFedPlans() {
89+
// Initialize test data with duplicate costs
90+
List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
91+
List<FedPlan> fedPlans = new ArrayList<>();
92+
fedPlans.add(new FedPlan(mockHop1, 10.0, planRefs)); // Unique cost
93+
fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // First duplicate
94+
fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // Second duplicate
95+
96+
// Add multiple plans including duplicates
97+
memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans);
98+
99+
// Verify handling of duplicate plans
100+
List<FedPlan> result = memoTable.get(mockHop1, FTypes.FType.FULL);
101+
assertNotNull("FedPlan list should exist after adding multiple plans", result);
102+
assertEquals("FedPlan list should maintain all plans including duplicates", 3, result.size());
103+
}
104+
105+
@Test
106+
public void testContains() {
107+
// Initialize test data
108+
List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
109+
FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
110+
111+
// Verify initial state
112+
assertFalse("MemoTable should not contain any entries initially",
113+
memoTable.contains(mockHop1, FTypes.FType.FULL));
114+
115+
// Add plan and verify presence
116+
memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
117+
118+
assertTrue("MemoTable should contain entry after adding FedPlan",
119+
memoTable.contains(mockHop1, FTypes.FType.FULL));
120+
assertFalse("MemoTable should not contain entries for different Hop",
121+
memoTable.contains(mockHop2, FTypes.FType.FULL));
122+
}
123+
124+
@Test
125+
public void testPrunePlanPruneAll() {
126+
// Initialize base test data
127+
List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
128+
// Create separate FedPlan lists for independent testing of each Hop
129+
List<FedPlan> fedPlans1 = new ArrayList<>(); // Plans for mockHop1
130+
List<FedPlan> fedPlans2 = new ArrayList<>(); // Plans for mockHop2
131+
132+
// Generate random cost FedPlans for both Hops
133+
double minCost = Double.MAX_VALUE;
134+
int size = 100;
135+
for(int i = 0; i < size; i++) {
136+
double cost = rand.nextDouble() * 1000; // Random cost between 0 and 1000
137+
fedPlans1.add(new FedPlan(mockHop1, cost, planRefs));
138+
fedPlans2.add(new FedPlan(mockHop2, cost, planRefs));
139+
minCost = Math.min(minCost, cost);
140+
}
141+
142+
// Add FedPlan lists to MemoTable
143+
memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans1);
144+
memoTable.addFedPlanList(mockHop2, FTypes.FType.FULL, fedPlans2);
145+
146+
// Test selective pruning on mockHop1
147+
memoTable.prunePlan(mockHop1, FTypes.FType.FULL);
148+
149+
// Get results for verification
150+
List<FedPlan> result1 = memoTable.get(mockHop1, FTypes.FType.FULL);
151+
List<FedPlan> result2 = memoTable.get(mockHop2, FTypes.FType.FULL);
152+
153+
// Verify selective pruning results
154+
assertNotNull("Pruned mockHop1 should maintain a FedPlan list", result1);
155+
assertEquals("Pruned mockHop1 should contain exactly one minimum cost plan", 1, result1.size());
156+
assertEquals("Pruned mockHop1's plan should have the minimum cost", minCost, result1.get(0).getCost(), 0.001);
157+
158+
// Verify unpruned Hop state
159+
assertNotNull("Unpruned mockHop2 should maintain a FedPlan list", result2);
160+
assertEquals("Unpruned mockHop2 should maintain all original plans", size, result2.size());
161+
162+
// Add additional plans to both Hops
163+
for(int i = 0; i < size; i++) {
164+
double cost = rand.nextDouble() * 1000;
165+
memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, new FedPlan(mockHop1, cost, planRefs));
166+
memoTable.addFedPlan(mockHop2, FTypes.FType.FULL, new FedPlan(mockHop2, cost, planRefs));
167+
minCost = Math.min(minCost, cost);
168+
}
169+
170+
// Test global pruning
171+
memoTable.pruneAll();
172+
173+
// Verify global pruning results
174+
assertNotNull("mockHop1 should maintain a FedPlan list after global pruning", result1);
175+
assertEquals("mockHop1 should contain exactly one minimum cost plan after global pruning",
176+
1, result1.size());
177+
assertEquals("mockHop1's plan should have the global minimum cost",
178+
minCost, result1.get(0).getCost(), 0.001);
179+
180+
assertNotNull("mockHop2 should maintain a FedPlan list after global pruning", result2);
181+
assertEquals("mockHop2 should contain exactly one minimum cost plan after global pruning",
182+
1, result2.size());
183+
assertEquals("mockHop2's plan should have the global minimum cost",
184+
minCost, result2.get(0).getCost(), 0.001);
185+
}
186+
}

0 commit comments

Comments
 (0)