Skip to content

Commit 8b8434c

Browse files
authored
feat(graph): Add Edmonds's algorithm for minimum spanning arborescence (#6771)
* feat(graph): Add Edmonds's algorithm for minimum spanning arborescence * test: Add test cases to achieve 100% coverage
1 parent f30d101 commit 8b8434c

File tree

2 files changed

+373
-0
lines changed

2 files changed

+373
-0
lines changed
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
package com.thealgorithms.graph;
2+
3+
import java.util.ArrayList;
4+
import java.util.Arrays;
5+
import java.util.List;
6+
7+
/**
8+
* An implementation of Edmonds's algorithm (also known as the Chu–Liu/Edmonds algorithm)
9+
* for finding a Minimum Spanning Arborescence (MSA).
10+
*
11+
* <p>An MSA is a directed graph equivalent of a Minimum Spanning Tree. It is a tree rooted
12+
* at a specific vertex 'r' that reaches all other vertices, such that the sum of the
13+
* weights of its edges is minimized.
14+
*
15+
* <p>The algorithm works recursively:
16+
* <ol>
17+
* <li>For each vertex other than the root, select the incoming edge with the minimum weight.</li>
18+
* <li>If the selected edges form a spanning arborescence, it is the MSA.</li>
19+
* <li>If cycles are formed, contract each cycle into a new "supernode".</li>
20+
* <li>Modify the weights of edges entering the new supernode.</li>
21+
* <li>Recursively call the algorithm on the contracted graph.</li>
22+
* <li>The final cost is the sum of the initial edge selections and the result of the recursive call.</li>
23+
* </ol>
24+
*
25+
* <p>Time Complexity: O(E * V) where E is the number of edges and V is the number of vertices.
26+
*
27+
* <p>References:
28+
* <ul>
29+
* <li><a href="https://en.wikipedia.org/wiki/Edmonds%27_algorithm">Wikipedia: Edmonds's algorithm</a></li>
30+
* </ul>
31+
*/
32+
public final class Edmonds {
33+
34+
private Edmonds() {
35+
}
36+
37+
/**
38+
* Represents a directed weighted edge in the graph.
39+
*/
40+
public static class Edge {
41+
final int from;
42+
final int to;
43+
final long weight;
44+
45+
/**
46+
* Constructs a directed edge.
47+
*
48+
* @param from source vertex
49+
* @param to destination vertex
50+
* @param weight edge weight
51+
*/
52+
public Edge(int from, int to, long weight) {
53+
this.from = from;
54+
this.to = to;
55+
this.weight = weight;
56+
}
57+
}
58+
59+
/**
60+
* Computes the total weight of the Minimum Spanning Arborescence of a directed,
61+
* weighted graph from a given root.
62+
*
63+
* @param numVertices the number of vertices, labeled {@code 0..numVertices-1}
64+
* @param edges list of directed edges in the graph
65+
* @param root the root vertex
66+
* @return the total weight of the MSA. Returns -1 if not all vertices are reachable
67+
* from the root or if a valid arborescence cannot be formed.
68+
* @throws IllegalArgumentException if {@code numVertices <= 0} or {@code root} is out of range.
69+
*/
70+
public static long findMinimumSpanningArborescence(int numVertices, List<Edge> edges, int root) {
71+
if (root < 0 || root >= numVertices) {
72+
throw new IllegalArgumentException("Invalid number of vertices or root");
73+
}
74+
if (numVertices == 1) {
75+
return 0;
76+
}
77+
78+
return findMSARecursive(numVertices, edges, root);
79+
}
80+
81+
/**
82+
* Recursive helper method for finding MSA.
83+
*/
84+
private static long findMSARecursive(int n, List<Edge> edges, int root) {
85+
long[] minWeightEdge = new long[n];
86+
int[] predecessor = new int[n];
87+
Arrays.fill(minWeightEdge, Long.MAX_VALUE);
88+
Arrays.fill(predecessor, -1);
89+
90+
for (Edge edge : edges) {
91+
if (edge.to != root && edge.weight < minWeightEdge[edge.to]) {
92+
minWeightEdge[edge.to] = edge.weight;
93+
predecessor[edge.to] = edge.from;
94+
}
95+
}
96+
// Check if all non-root nodes are reachable
97+
for (int i = 0; i < n; i++) {
98+
if (i != root && minWeightEdge[i] == Long.MAX_VALUE) {
99+
return -1; // No spanning arborescence exists
100+
}
101+
}
102+
int[] cycleId = new int[n];
103+
Arrays.fill(cycleId, -1);
104+
boolean[] visited = new boolean[n];
105+
int cycleCount = 0;
106+
107+
for (int i = 0; i < n; i++) {
108+
if (visited[i]) {
109+
continue;
110+
}
111+
112+
List<Integer> path = new ArrayList<>();
113+
int curr = i;
114+
115+
// Follow predecessor chain
116+
while (curr != -1 && !visited[curr]) {
117+
visited[curr] = true;
118+
path.add(curr);
119+
curr = predecessor[curr];
120+
}
121+
122+
// If we hit a visited node, check if it forms a cycle
123+
if (curr != -1) {
124+
boolean inCycle = false;
125+
for (int node : path) {
126+
if (node == curr) {
127+
inCycle = true;
128+
}
129+
if (inCycle) {
130+
cycleId[node] = cycleCount;
131+
}
132+
}
133+
if (inCycle) {
134+
cycleCount++;
135+
}
136+
}
137+
}
138+
if (cycleCount == 0) {
139+
long totalWeight = 0;
140+
for (int i = 0; i < n; i++) {
141+
if (i != root) {
142+
totalWeight += minWeightEdge[i];
143+
}
144+
}
145+
return totalWeight;
146+
}
147+
long cycleWeightSum = 0;
148+
for (int i = 0; i < n; i++) {
149+
if (cycleId[i] >= 0) {
150+
cycleWeightSum += minWeightEdge[i];
151+
}
152+
}
153+
154+
// Map old nodes to new nodes (cycles become supernodes)
155+
int[] newNodeMap = new int[n];
156+
int[] cycleToNewNode = new int[cycleCount];
157+
int newN = 0;
158+
159+
// Assign new node IDs to cycles first
160+
for (int i = 0; i < cycleCount; i++) {
161+
cycleToNewNode[i] = newN++;
162+
}
163+
164+
// Assign new node IDs to non-cycle nodes
165+
for (int i = 0; i < n; i++) {
166+
if (cycleId[i] == -1) {
167+
newNodeMap[i] = newN++;
168+
} else {
169+
newNodeMap[i] = cycleToNewNode[cycleId[i]];
170+
}
171+
}
172+
173+
int newRoot = newNodeMap[root];
174+
175+
// Build contracted graph
176+
List<Edge> newEdges = new ArrayList<>();
177+
for (Edge edge : edges) {
178+
int uCycleId = cycleId[edge.from];
179+
int vCycleId = cycleId[edge.to];
180+
181+
// Skip edges internal to a cycle
182+
if (uCycleId >= 0 && uCycleId == vCycleId) {
183+
continue;
184+
}
185+
186+
int newU = newNodeMap[edge.from];
187+
int newV = newNodeMap[edge.to];
188+
189+
long newWeight = edge.weight;
190+
// Adjust weight for edges entering a cycle
191+
if (vCycleId >= 0) {
192+
newWeight -= minWeightEdge[edge.to];
193+
}
194+
195+
if (newU != newV) {
196+
newEdges.add(new Edge(newU, newV, newWeight));
197+
}
198+
}
199+
return cycleWeightSum + findMSARecursive(newN, newEdges, newRoot);
200+
}
201+
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package com.thealgorithms.graph;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
6+
import java.util.ArrayList;
7+
import java.util.List;
8+
import org.junit.jupiter.api.Test;
9+
10+
class EdmondsTest {
11+
12+
@Test
13+
void testSimpleGraphNoCycle() {
14+
int n = 4;
15+
int root = 0;
16+
List<Edmonds.Edge> edges = new ArrayList<>();
17+
edges.add(new Edmonds.Edge(0, 1, 10));
18+
edges.add(new Edmonds.Edge(0, 2, 1));
19+
edges.add(new Edmonds.Edge(2, 1, 2));
20+
edges.add(new Edmonds.Edge(2, 3, 5));
21+
22+
// Expected arborescence edges: (0,2), (2,1), (2,3)
23+
// Weights: 1 + 2 + 5 = 8
24+
long result = Edmonds.findMinimumSpanningArborescence(n, edges, root);
25+
assertEquals(8, result);
26+
}
27+
28+
@Test
29+
void testGraphWithOneCycle() {
30+
int n = 4;
31+
int root = 0;
32+
List<Edmonds.Edge> edges = new ArrayList<>();
33+
edges.add(new Edmonds.Edge(0, 1, 10));
34+
edges.add(new Edmonds.Edge(2, 1, 4));
35+
edges.add(new Edmonds.Edge(1, 2, 5));
36+
edges.add(new Edmonds.Edge(2, 3, 6));
37+
38+
// Min edges: (2,1, w=4), (1,2, w=5), (2,3, w=6)
39+
// Cycle: 1 -> 2 -> 1, cost = 4 + 5 = 9
40+
// Contract {1,2} to C.
41+
// New edge (0,C) with w = 10 - min_in(1) = 10 - 4 = 6
42+
// New edge (C,3) with w = 6
43+
// Contracted MSA cost = 6 + 6 = 12
44+
// Total cost = cycle_cost + contracted_msa_cost = 9 + 12 = 21
45+
long result = Edmonds.findMinimumSpanningArborescence(n, edges, root);
46+
assertEquals(21, result);
47+
}
48+
49+
@Test
50+
void testComplexGraphWithCycle() {
51+
int n = 6;
52+
int root = 0;
53+
List<Edmonds.Edge> edges = new ArrayList<>();
54+
edges.add(new Edmonds.Edge(0, 1, 10));
55+
edges.add(new Edmonds.Edge(0, 2, 20));
56+
edges.add(new Edmonds.Edge(1, 2, 5));
57+
edges.add(new Edmonds.Edge(2, 3, 10));
58+
edges.add(new Edmonds.Edge(3, 1, 3));
59+
edges.add(new Edmonds.Edge(1, 4, 7));
60+
edges.add(new Edmonds.Edge(3, 4, 2));
61+
edges.add(new Edmonds.Edge(4, 5, 5));
62+
63+
// Min edges: (3,1,3), (1,2,5), (2,3,10), (3,4,2), (4,5,5)
64+
// Cycle: 1->2->3->1, cost = 5+10+3=18
65+
// Contract {1,2,3} to C.
66+
// Edge (0,1,10) -> (0,C), w = 10-3=7
67+
// Edge (0,2,20) -> (0,C), w = 20-5=15. Min is 7.
68+
// Edge (1,4,7) -> (C,4,7)
69+
// Edge (3,4,2) -> (C,4,2). Min is 2.
70+
// Edge (4,5,5) -> (4,5,5)
71+
// Contracted MSA: (0,C,7), (C,4,2), (4,5,5). Cost = 7+2+5=14
72+
// Total cost = 18 + 14 = 32
73+
long result = Edmonds.findMinimumSpanningArborescence(n, edges, root);
74+
assertEquals(32, result);
75+
}
76+
77+
@Test
78+
void testUnreachableNode() {
79+
int n = 4;
80+
int root = 0;
81+
List<Edmonds.Edge> edges = new ArrayList<>();
82+
edges.add(new Edmonds.Edge(0, 1, 10));
83+
edges.add(new Edmonds.Edge(2, 3, 5)); // Node 2 and 3 are unreachable from root 0
84+
85+
long result = Edmonds.findMinimumSpanningArborescence(n, edges, root);
86+
assertEquals(-1, result);
87+
}
88+
89+
@Test
90+
void testNoEdgesToNonRootNodes() {
91+
int n = 3;
92+
int root = 0;
93+
List<Edmonds.Edge> edges = new ArrayList<>();
94+
edges.add(new Edmonds.Edge(0, 1, 10)); // Node 2 is unreachable
95+
96+
long result = Edmonds.findMinimumSpanningArborescence(n, edges, root);
97+
assertEquals(-1, result);
98+
}
99+
100+
@Test
101+
void testSingleNode() {
102+
int n = 1;
103+
int root = 0;
104+
List<Edmonds.Edge> edges = new ArrayList<>();
105+
106+
long result = Edmonds.findMinimumSpanningArborescence(n, edges, root);
107+
assertEquals(0, result);
108+
}
109+
110+
@Test
111+
void testInvalidInputThrowsException() {
112+
List<Edmonds.Edge> edges = new ArrayList<>();
113+
114+
assertThrows(IllegalArgumentException.class, () -> Edmonds.findMinimumSpanningArborescence(0, edges, 0));
115+
assertThrows(IllegalArgumentException.class, () -> Edmonds.findMinimumSpanningArborescence(5, edges, -1));
116+
assertThrows(IllegalArgumentException.class, () -> Edmonds.findMinimumSpanningArborescence(5, edges, 5));
117+
}
118+
119+
@Test
120+
void testCoverageForEdgeSelectionLogic() {
121+
int n = 3;
122+
int root = 0;
123+
List<Edmonds.Edge> edges = new ArrayList<>();
124+
125+
// This will cover the `edge.weight < minWeightEdge[edge.to]` being false.
126+
edges.add(new Edmonds.Edge(0, 1, 10));
127+
edges.add(new Edmonds.Edge(2, 1, 20));
128+
129+
// This will cover the `edge.to != root` being false.
130+
edges.add(new Edmonds.Edge(1, 0, 100));
131+
132+
// A regular edge to make the graph complete
133+
edges.add(new Edmonds.Edge(0, 2, 5));
134+
135+
// Expected MSA: (0,1, w=10) and (0,2, w=5). Total weight = 15.
136+
long result = Edmonds.findMinimumSpanningArborescence(n, edges, root);
137+
assertEquals(15, result);
138+
}
139+
140+
@Test
141+
void testCoverageForContractedSelfLoop() {
142+
int n = 4;
143+
int root = 0;
144+
List<Edmonds.Edge> edges = new ArrayList<>();
145+
146+
// Connect root to the cycle components
147+
edges.add(new Edmonds.Edge(0, 1, 20));
148+
149+
// Create a cycle 1 -> 2 -> 1
150+
edges.add(new Edmonds.Edge(1, 2, 5));
151+
edges.add(new Edmonds.Edge(2, 1, 5));
152+
153+
// This is the CRITICAL edge for coverage:
154+
// It connects two nodes (1 and 2) that are part of the SAME cycle.
155+
// After contracting cycle {1, 2} into a supernode C, this edge becomes (C, C),
156+
// which means newU == newV. This will trigger the `false` branch of the `if`.
157+
edges.add(new Edmonds.Edge(1, 1, 100)); // Also a self-loop on a cycle node.
158+
159+
// Add another edge to ensure node 3 is reachable
160+
edges.add(new Edmonds.Edge(1, 3, 10));
161+
162+
// Cycle {1,2} has cost 5+5=10.
163+
// Contract {1,2} to supernode C.
164+
// Edge (0,1,20) becomes (0,C, w=20-5=15).
165+
// Edge (1,3,10) becomes (C,3, w=10).
166+
// Edge (1,1,100) is discarded because newU == newV.
167+
// Cost of contracted graph = 15 + 10 = 25.
168+
// Total cost = cycle cost + contracted cost = 10 + 25 = 35.
169+
long result = Edmonds.findMinimumSpanningArborescence(n, edges, root);
170+
assertEquals(35, result);
171+
}
172+
}

0 commit comments

Comments
 (0)