Skip to content

Commit 56b7fa0

Browse files
committed
Add Kruskal and Prim implementation
1 parent f09df8f commit 56b7fa0

File tree

4 files changed

+438
-0
lines changed

4 files changed

+438
-0
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package algorithms.minimumSpanningTree.kruskal;
2+
3+
import java.util.ArrayList;
4+
import java.util.List;
5+
6+
import dataStructures.disjointSet.weightedUnion.DisjointSet;
7+
8+
/**
9+
* Implementation of Kruskal's Algorithm to find MSTs
10+
* Idea:
11+
* Sort all edges by weight in non-decreasing order. Consider the edges in this order. If an edge does not form a cycle
12+
* with the edges already in the MST, add it to the MST. Repeat until all nodes are in the MST.
13+
* Actual implementation:
14+
* An Edge class is implemented for easier sorting of edges by weight and for identifying the source and destination.
15+
* A Node class is implemented for easier tracking of nodes in the graph for the disjoint set.
16+
* A DisjointSet class is used to track the nodes in the graph and to determine if adding an edge will form a cycle.
17+
*/
18+
public class Kruskal {
19+
public static int[][] getKruskalMST(Node[] nodes, int[][] adjacencyMatrix) {
20+
int numOfNodes = nodes.length;
21+
List<Edge> edges = new ArrayList<>();
22+
23+
// Convert adjacency matrix to list of edges
24+
for (int i = 0; i < numOfNodes; i++) {
25+
for (int j = i + 1; j < numOfNodes; j++) {
26+
if (adjacencyMatrix[i][j] != Integer.MAX_VALUE) {
27+
edges.add(new Edge(nodes[i], nodes[j], adjacencyMatrix[i][j]));
28+
}
29+
}
30+
}
31+
32+
// Sort edges by weight
33+
edges.sort(Edge::compareTo);
34+
35+
// Initialize Disjoint Set for vertex tracking
36+
DisjointSet<Node> ds = new DisjointSet<>(nodes);
37+
38+
// MST adjacency matrix to be returned
39+
int[][] mstMatrix = new int[numOfNodes][numOfNodes];
40+
41+
// Initialize the MST matrix to represent no edges with Integer.MAX_VALUE and 0 for self loops
42+
for (int i = 0; i < nodes.length; i++) {
43+
for (int j = 0; j < nodes.length; j++) {
44+
mstMatrix[i][j] = (i == j) ? 0 : Integer.MAX_VALUE;
45+
}
46+
}
47+
48+
// Process edges to build MST
49+
for (Edge edge : edges) {
50+
Node source = edge.getSource();
51+
Node destination = edge.getDestination();
52+
if (!ds.find(source, destination)) {
53+
mstMatrix[source.getIndex()][destination.getIndex()] = edge.getWeight();
54+
mstMatrix[destination.getIndex()][source.getIndex()] = edge.getWeight();
55+
ds.union(source, destination);
56+
}
57+
}
58+
59+
return mstMatrix;
60+
}
61+
62+
/**
63+
* Node class to represent a node in the graph
64+
* Note: In our Node class, we do not allow the currMinWeight to be updated after initialization to prevent any
65+
* reference issues in the PriorityQueue.
66+
*/
67+
static class Node {
68+
private final int index; // Index of this node in the adjacency matrix
69+
private final String identifier;
70+
71+
/**
72+
* Constructor for Node
73+
* @param identifier
74+
* @param index
75+
*/
76+
public Node(String identifier, int index) {
77+
this.identifier = identifier;
78+
this.index = index;
79+
}
80+
81+
/**
82+
* Getter for identifier
83+
* @return identifier
84+
*/
85+
public String getIdentifier() {
86+
return identifier;
87+
}
88+
89+
public int getIndex() {
90+
return index;
91+
}
92+
93+
@Override
94+
public String toString() {
95+
return "Node{" + "identifier='" + identifier + '\'' + ", index=" + index + '}';
96+
}
97+
}
98+
99+
/**
100+
* Edge class to represent an edge in the graph
101+
*/
102+
static class Edge implements Comparable<Edge> {
103+
private final Node source;
104+
private final Node destination;
105+
private final int weight;
106+
107+
/**
108+
* Constructor for Edge
109+
*/
110+
public Edge(Node source, Node destination, int weight) {
111+
this.source = source;
112+
this.destination = destination;
113+
this.weight = weight;
114+
}
115+
116+
public int getWeight() {
117+
return weight;
118+
}
119+
120+
public Node getSource() {
121+
return source;
122+
}
123+
124+
public Node getDestination() {
125+
return destination;
126+
}
127+
128+
@Override
129+
public int compareTo(Edge other) {
130+
return Integer.compare(this.weight, other.weight);
131+
}
132+
}
133+
}
134+
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package algorithms.minimumSpanningTree.prim;
2+
3+
import java.util.Arrays;
4+
import java.util.PriorityQueue;
5+
6+
/**
7+
* Implementation of Prim's Algorithm to find MSTs
8+
* Idea:
9+
* Starting from any source (this will be the first node to be in the MST), pick the lightest outgoing edge, and
10+
* include the node at the other end as part of a set of nodes S. Now repeatedly do the above by picking the lightest
11+
* outgoing edge adjacent to any node in the MST (ensure the other end of the node is not already in the MST).
12+
* Repeat until S contains all nodes in the graph. S is the MST.
13+
* Actual implementation:
14+
* No Edge class was implemented. Instead, the weights of the edges are stored in a 2D array adjacency matrix. An
15+
* adjacency list may be used instead
16+
* A Node class is implemented to encapsulate the current minimum weight to reach the node.
17+
*/
18+
public class Prim {
19+
public static int[][] getPrimsMST(Node[] nodes, int[][] adjacencyMatrix) {
20+
// Recall that PriorityQueue is a min heap by default
21+
PriorityQueue<Node> pq = new PriorityQueue<>((a, b) -> a.getCurrMinWeight() - b.getCurrMinWeight());
22+
int[][] mstMatrix = new int[nodes.length][nodes.length]; // MST adjacency matrix
23+
24+
int[] parent = new int[nodes.length]; // To track the parent node of each node in the MST
25+
Arrays.fill(parent, -1); // Initialize parent array with -1, indicating no parent
26+
27+
boolean[] visited = new boolean[nodes.length]; // To track visited nodes
28+
Arrays.fill(visited, false); // Initialize visited array with false, indicating not visited
29+
30+
// Initialize the MST matrix to represent no edges with Integer.MAX_VALUE and 0 for self loops
31+
for (int i = 0; i < nodes.length; i++) {
32+
for (int j = 0; j < nodes.length; j++) {
33+
mstMatrix[i][j] = (i == j) ? 0 : Integer.MAX_VALUE;
34+
}
35+
}
36+
37+
// Add all nodes to the priority queue, with each node's curr min weight already set to Integer.MAX_VALUE
38+
pq.addAll(Arrays.asList(nodes));
39+
40+
while (!pq.isEmpty()) {
41+
Node current = pq.poll();
42+
43+
int currentIndex = current.getIndex();
44+
45+
if (visited[currentIndex]) { // Skip if node is already visited
46+
continue;
47+
}
48+
49+
visited[currentIndex] = true;
50+
51+
for (int i = 0; i < nodes.length; i++) {
52+
if (adjacencyMatrix[currentIndex][i] != Integer.MAX_VALUE && !visited[nodes[i].getIndex()]) {
53+
int weight = adjacencyMatrix[currentIndex][i];
54+
55+
if (weight < nodes[i].getCurrMinWeight()) {
56+
Node newNode = new Node(nodes[i].getIdentifier(), nodes[i].getIndex(), weight);
57+
parent[i] = currentIndex; // Set current node as parent of adjacent node
58+
pq.add(newNode);
59+
}
60+
}
61+
}
62+
}
63+
64+
// Build MST matrix based on parent array
65+
for (int i = 1; i < nodes.length; i++) {
66+
int p = parent[i];
67+
if (p != -1) {
68+
int weight = adjacencyMatrix[p][i];
69+
mstMatrix[p][i] = weight;
70+
mstMatrix[i][p] = weight; // For undirected graphs
71+
}
72+
}
73+
74+
return mstMatrix;
75+
}
76+
77+
/**
78+
* Node class to represent a node in the graph
79+
* Note: In our Node class, we do not allow the currMinWeight to be updated after initialization to prevent any
80+
* reference issues in the PriorityQueue.
81+
*/
82+
static class Node {
83+
private final int currMinWeight; // Current minimum weight to get to this node
84+
private int index; // Index of this node in the adjacency matrix
85+
private final String identifier;
86+
87+
/**
88+
* Constructor for Node
89+
* @param identifier
90+
* @param index
91+
* @param currMinWeight
92+
*/
93+
public Node(String identifier, int index, int currMinWeight) {
94+
this.identifier = identifier;
95+
this.index = index;
96+
this.currMinWeight = currMinWeight;
97+
}
98+
99+
/**
100+
* Constructor for Node with default currMinWeight
101+
* @param identifier
102+
* @param index
103+
*/
104+
public Node(String identifier, int index) {
105+
this.identifier = identifier;
106+
this.index = index;
107+
this.currMinWeight = Integer.MAX_VALUE;
108+
}
109+
110+
/**
111+
* Getter and setter for currMinWeight
112+
*/
113+
public int getCurrMinWeight() {
114+
return currMinWeight;
115+
}
116+
117+
/**
118+
* Getter for identifier
119+
* @return identifier
120+
*/
121+
public String getIdentifier() {
122+
return identifier;
123+
}
124+
125+
public int getIndex() {
126+
return index;
127+
}
128+
129+
@Override
130+
public String toString() {
131+
return "Node{" + "identifier='" + identifier + '\'' + ", index=" + index + '}';
132+
}
133+
}
134+
}
135+
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package algorithms.minimumSpanningTree.kruskal;
2+
3+
import static org.junit.Assert.assertArrayEquals;
4+
5+
import org.junit.Test;
6+
7+
public class KruskalTest {
8+
@Test
9+
public void test_simpleGraph() {
10+
// Graph setup (Adjacency Matrix)
11+
// B
12+
// / \
13+
// 1 1
14+
// / \
15+
// A - 1 - C
16+
int[][] adjacencyMatrix = {
17+
{0, 1, 1}, // A: A-B, A-C
18+
{1, 0, 1}, // B: B-A, B-C
19+
{1, 1, 0} // C: C-A, C-B
20+
};
21+
22+
Kruskal.Node[] nodes = {
23+
new Kruskal.Node("A", 0),
24+
new Kruskal.Node("B", 1),
25+
new Kruskal.Node("C", 2)
26+
};
27+
28+
// Run Kruskal's algorithm
29+
int[][] actualMST = Kruskal.getKruskalMST(nodes, adjacencyMatrix);
30+
31+
// Expected MST
32+
// A -1- B -1- C
33+
int[][] expectedMST = {
34+
{0, 1, 1}, // A: A-B, A-C
35+
{1, 0, Integer.MAX_VALUE}, // B: B-A
36+
{1, Integer.MAX_VALUE, 0} // C: C-A
37+
};
38+
39+
// Assertion
40+
assertArrayEquals(expectedMST, actualMST);
41+
}
42+
43+
@Test
44+
public void test_complexGraph() {
45+
// Graph setup
46+
// A
47+
// / | \
48+
// 1 4 3
49+
/// | \
50+
//B --3-- D
51+
// \ | /
52+
// 2 4 1
53+
// \|/
54+
// C
55+
int[][] adjacencyMatrix = {
56+
{0, 1, 4, 3}, // A: A-B, A-C, A-D
57+
{1, 0, 2, 3}, // B: B-A, B-C, B-D
58+
{4, 2, 0, 1}, // C: C-A, C-B, C-D
59+
{3, 3, 1, 0} // D: D-A, D-B, D-C
60+
};
61+
62+
Kruskal.Node[] nodes = {
63+
new Kruskal.Node("A", 0),
64+
new Kruskal.Node("B", 1),
65+
new Kruskal.Node("C", 2),
66+
new Kruskal.Node("D", 3)
67+
};
68+
69+
// Run Prim's algorithm
70+
int[][] actualMST = Kruskal.getKruskalMST(nodes, adjacencyMatrix);
71+
72+
// Expected MST
73+
// Based on the graph, assuming the MST is correctly computed
74+
int[][] expectedMST = {
75+
{0, 1, Integer.MAX_VALUE, Integer.MAX_VALUE}, // A: A-B
76+
{1, 0, 2, Integer.MAX_VALUE}, // B: B-A, B-C
77+
{Integer.MAX_VALUE, 2, 0, 1}, // C: C-B, C-D
78+
{Integer.MAX_VALUE, Integer.MAX_VALUE, 1, 0} // D: D-C
79+
};
80+
81+
// Assertion
82+
assertArrayEquals(expectedMST, actualMST);
83+
}
84+
}

0 commit comments

Comments
 (0)