diff --git a/MinimumSpanningTree b/MinimumSpanningTree new file mode 100644 index 000000000..7750ddf5a --- /dev/null +++ b/MinimumSpanningTree @@ -0,0 +1,95 @@ +package main + +import ( + "fmt" + "sort" +) + +type Edge struct { + u, v int + w int +} + +type ByWeight []Edge + +func (a ByWeight) Len() int { return len(a) } +func (a ByWeight) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a ByWeight) Less(i, j int) bool { return a[i].w < a[j].w } + +type UnionFind struct { + parent []int + rank []int +} + +func NewUnionFind(n int) *UnionFind { + p := make([]int, n) + r := make([]int, n) + for i := 0; i < n; i++ { + p[i] = i + r[i] = 0 + } + return &UnionFind{parent: p, rank: r} +} + +func (dsu *UnionFind) Find(x int) int { + if dsu.parent[x] != x { + dsu.parent[x] = dsu.Find(dsu.parent[x]) + } + return dsu.parent[x] +} + +func (dsu *UnionFind) Union(a, b int) bool { + ra := dsu.Find(a) + rb := dsu.Find(b) + if ra == rb { + return false + } + if dsu.rank[ra] < dsu.rank[rb] { + dsu.parent[ra] = rb + } else if dsu.rank[ra] > dsu.rank[rb] { + dsu.parent[rb] = ra + } else { + dsu.parent[rb] = ra + dsu.rank[ra]++ + } + return true +} + +func KruskalMST(n int, edges []Edge) ([]Edge, int) { + sort.Sort(ByWeight(edges)) + dsu := NewUnionFind(n) + + var mst []Edge + totalWeight := 0 + for _, e := range edges { + if dsu.Union(e.u, e.v) { + mst = append(mst, e) + totalWeight += e.w + // early stop: if we already have n-1 edges + if len(mst) == n-1 { + break + } + } + } + return mst, totalWeight +} + +func main() { + + n := 4 + edges := []Edge{ + {u: 0, v: 1, w: 1}, + {u: 0, v: 2, w: 4}, + {u: 1, v: 2, w: 2}, + {u: 1, v: 3, w: 3}, + {u: 2, v: 3, w: 5}, + } + + mst, total := KruskalMST(n, edges) + + fmt.Println("MST edges (u - v : weight):") + for _, e := range mst { + fmt.Printf("%d - %d : %d\n", e.u, e.v, e.w) + } + fmt.Printf("Total weight: %d\n", total) +}