@@ -4,12 +4,14 @@ defmodule AdventOfCode.Algorithms.DisjointSet do
44
55 More on this: https://en.wikipedia.org/wiki/Disjoint_sets
66 """
7+ alias AdventOfCode.Algorithms.DisjointSet
78 defstruct parents: % { } , ranks: % { }
89
9- @ type mapped_array ( ) :: % { required ( non_neg_integer ( ) ) => non_neg_integer ( ) }
10- @ type value ( ) :: non_neg_integer ( )
10+ @ type mapped_parents ( ) :: % { term ( ) => term ( ) }
11+ @ type mapped_array ( ) :: % { required ( non_neg_integer ( ) ) => term ( ) }
12+ @ type value ( ) :: term ( )
1113 @ type t ( ) :: % __MODULE__ {
12- parents: mapped_array ( ) ,
14+ parents: mapped_parents ( ) ,
1315 ranks: mapped_array ( )
1416 }
1517
@@ -28,7 +30,7 @@ defmodule AdventOfCode.Algorithms.DisjointSet do
2830 }
2931
3032 """
31- @ spec new ( non_neg_integer ( ) | List . t ( ) ) :: t ( )
33+ @ spec new ( non_neg_integer ( ) | [ term ( ) ] ) :: t ( )
3234 def new ( 0 ) , do: % __MODULE__ { }
3335
3436 def new ( size ) when is_integer ( size ) do
@@ -104,8 +106,46 @@ defmodule AdventOfCode.Algorithms.DisjointSet do
104106 end
105107
106108 @ doc """
107- Performs a union between two elements and returns the updated set. `:error` case is matched so that it fails
108- in a piped flow.
109+ Performs a union between two elements and returns a tuple with set and find status.
110+
111+ If either of the element being unionized were not a part of the set, then it returns `:error` as first element
112+ of the tuple otherwise `:ok`
113+
114+ ## Example
115+
116+ iex> set = DisjointSet.new(5)
117+ iex> set =
118+ ...> set
119+ ...> |> DisjointSet.union(0, 2)
120+ ...> |> DisjointSet.union(4, 2)
121+ ...> |> DisjointSet.union(3, 1)
122+ iex> set
123+ %DisjointSet{
124+ parents: %{0 => 0, 1 => 3, 2 => 0, 3 => 3, 4 => 0},
125+ ranks: %{0 => 2, 1 => 1, 2 => 1, 3 => 2, 4 => 1}
126+ }
127+ iex> DisjointSet.strict_union(set, 3, 1) == {:ok, set}
128+ true
129+
130+ iex> ds = DisjointSet.new(1)
131+ iex> DisjointSet.strict_union(ds, 100, 200) == {:error, ds}
132+ true
133+
134+ """
135+ @ spec strict_union ( t ( ) , value ( ) , value ( ) ) :: { :ok , t ( ) } | { :error , t ( ) }
136+ def strict_union ( % __MODULE__ { } = disjoint_set , a , b ) do
137+ with { root_a , disjoint_set_1 } <- find ( disjoint_set , a ) ,
138+ { root_b , disjoint_set_2 } <- find ( disjoint_set_1 , b ) do
139+ { :ok , union_by_rank ( disjoint_set_2 , root_a , root_b ) }
140+ else
141+ _ -> { :error , disjoint_set }
142+ end
143+ end
144+
145+ @ doc """
146+ Performs a union between two elements and returns the updated set. It returns the set, either original or
147+ updated depending on whether there was a match or not. See `DisjointSet.strict_union` if you want to know
148+ element membership.
109149
110150 ## Example
111151
@@ -130,16 +170,10 @@ defmodule AdventOfCode.Algorithms.DisjointSet do
130170 """
131171 @ spec union ( t ( ) , value ( ) , value ( ) ) :: t ( )
132172 def union ( % __MODULE__ { } = disjoint_set , a , b ) do
133- with { root_a , disjoint_set } <- find ( disjoint_set , a ) ,
134- { root_b , disjoint_set } <- find ( disjoint_set , b ) do
135- union_by_rank ( disjoint_set , root_a , root_b )
136- else
137- _ -> disjoint_set
138- end
173+ { _ , disjoint_set } = DisjointSet . strict_union ( disjoint_set , a , b )
174+ disjoint_set
139175 end
140176
141- def union ( :error , _ , _ ) , do: :error
142-
143177 @ doc """
144178 Returns the connected components of a set of data. `:error` case is matched so that it fails
145179 in a piped flow.
@@ -160,8 +194,9 @@ defmodule AdventOfCode.Algorithms.DisjointSet do
160194
161195 """
162196 @ spec components ( t ( ) | :error ) :: [ [ term ( ) ] ]
163- def components ( % __MODULE__ { parents: parents } ) do
197+ def components ( % __MODULE__ { parents: parents } = disjoint_set ) do
164198 parents
199+ |> Enum . map ( fn { k , _ } -> { k , find ( disjoint_set , k ) |> elem ( 0 ) } end )
165200 |> Enum . group_by ( & elem ( & 1 , 1 ) , fn { a , _ } -> a end )
166201 |> Map . values ( )
167202 |> Enum . map ( & Enum . into ( & 1 , % MapSet { } ) )
0 commit comments