11import gleam/dict
22import gleam/list
3+ import gleam/pair
34
45pub opaque type DisjointSet ( a) {
56 DisjointSet ( parents : dict . Dict ( a, a) , sizes : dict . Dict ( a, Int ) )
67}
78
9+ /// Create a disjoint set from a list of items. Initially all items are
10+ /// separate sets.
811pub fn from_list ( items : List ( a) ) -> DisjointSet ( a) {
912 let # ( parents , sizes ) =
1013 list . fold ( items , # ( dict . new ( ) , dict . new ( ) ) , fn ( acc , item ) {
@@ -13,6 +16,9 @@ pub fn from_list(items: List(a)) -> DisjointSet(a) {
1316 DisjointSet ( parents : , sizes : )
1417}
1518
19+ /// Find the set to which an item belogs. This will return an updated
20+ /// copy of the disjoint set with some path optimization and the root
21+ /// element defining the set.
1622pub fn find ( dj : DisjointSet ( a) , item : a) -> Result ( # ( DisjointSet ( a) , a) , Nil ) {
1723 case dict . get ( dj . parents , item ) {
1824 Ok ( parent ) -> {
@@ -22,13 +28,13 @@ pub fn find(dj: DisjointSet(a), item: a) -> Result(#(DisjointSet(a), a), Nil) {
2228 // Path compression to make each item point to a single representative set parent
2329 case find ( dj , parent ) {
2430 Error ( _ ) -> Error ( Nil )
25- Ok ( # ( dj , rep_parent ) ) -> {
31+ Ok ( # ( dj , root ) ) -> {
2632 let updated_dj =
2733 DisjointSet (
28- parents : dict . insert ( dj . parents , item , rep_parent ) ,
34+ parents : dict . insert ( dj . parents , item , root ) ,
2935 sizes : dj . sizes ,
3036 )
31- Ok ( # ( updated_dj , rep_parent ) )
37+ Ok ( # ( updated_dj , root ) )
3238 }
3339 }
3440 }
@@ -38,6 +44,10 @@ pub fn find(dj: DisjointSet(a), item: a) -> Result(#(DisjointSet(a), a), Nil) {
3844 }
3945}
4046
47+ /// Create a union set from the set containing the element x
48+ /// and the set containing the element y. If they are already
49+ /// in the same set nothing happens. This returns an updated
50+ /// disjoint set.
4151pub fn union ( dj : DisjointSet ( a) , x : a, y : a) -> Result ( DisjointSet ( a) , Nil ) {
4252 case find ( dj , x ) {
4353 Error ( Nil ) -> Error ( Nil )
@@ -67,13 +77,15 @@ pub fn union(dj: DisjointSet(a), x: a, y: a) -> Result(DisjointSet(a), Nil) {
6777 }
6878}
6979
80+ /// Find the size of the set which contains item.
7081pub fn size ( dj : DisjointSet ( a) , item : a) -> Result ( Int , Nil ) {
7182 case find ( dj , item ) {
7283 Error ( Nil ) -> Error ( Nil )
7384 Ok ( # ( dj , rep_parent ) ) -> dict . get ( dj . sizes , rep_parent )
7485 }
7586}
7687
88+ // I am thinking about just providing all sets
7789pub fn to_list ( dj : DisjointSet ( a) , item : a) -> Result ( List ( a) , Nil ) {
7890 case find ( dj , item ) {
7991 Error ( Nil ) -> Error ( Nil )
@@ -91,11 +103,7 @@ pub fn to_list(dj: DisjointSet(a), item: a) -> Result(List(a), Nil) {
91103}
92104
93105pub fn setlist ( dj : DisjointSet ( a) ) -> List ( a) {
94- dict . keys ( dj . parents )
95- |> list . filter ( fn ( item ) {
96- case find ( dj , item ) {
97- Ok ( # ( _ , p ) ) if p == item -> True
98- _ -> False
99- }
100- } )
106+ dict . to_list ( dj . parents )
107+ |> list . filter ( fn ( pair ) { pair . 0 == pair . 1 } )
108+ |> list . map ( pair . first )
101109}
0 commit comments