Skip to content

Commit 8d2effd

Browse files
committed
the discord was abuzz with disjoint sets so I had to try it
1 parent 2523824 commit 8d2effd

File tree

3 files changed

+133
-58
lines changed

3 files changed

+133
-58
lines changed

src/day08/solution.gleam

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import gleam/dict
21
import gleam/int
32
import gleam/io
43
import gleam/list
54
import gleam/result
6-
import gleam/set
75
import gleam/string
86
import internal/aoc_utils
7+
import internal/disjoint_set
98

109
pub fn main() {
1110
let filename = "inputs/day08.txt"
@@ -36,12 +35,11 @@ pub fn solve_p1(lines: List(String), connections: Int) -> Result(String, String)
3635
|> list.take(connections)
3736
|> list.map(fn(t) { t.0 })
3837

39-
wire(dict.new(), closest_boxes)
40-
|> dict.to_list
41-
|> list.map(fn(dtup) { dtup.1 })
42-
|> set.from_list
43-
|> set.to_list
44-
|> list.map(set.size)
38+
let circuits = wire(disjoint_set.from_list(boxes), closest_boxes)
39+
40+
disjoint_set.setlist(circuits)
41+
|> list.map(disjoint_set.size(circuits, _))
42+
|> result.values
4543
|> list.sort(int.compare)
4644
|> list.reverse
4745
|> list.take(3)
@@ -63,7 +61,11 @@ pub fn solve_p2(lines: List(String)) -> Result(String, String) {
6361
|> list.sort(fn(v1, v2) { int.compare(v1.1, v2.1) })
6462
|> list.map(fn(t) { t.0 })
6563

66-
build_circuit_sized(list.length(boxes), dict.new(), closest_boxes)
64+
build_circuit_sized(
65+
list.length(boxes),
66+
disjoint_set.from_list(boxes),
67+
closest_boxes,
68+
)
6769
|> fn(v) { { v.0 }.x * { v.1 }.x }
6870
|> int.to_string
6971
|> Ok
@@ -90,21 +92,12 @@ fn parse_line(line: String) -> Junction {
9092
// When I combine two boxes I have to update the circuit of all involed
9193
// junctions.
9294
fn wire(
93-
circuit_table: dict.Dict(Junction, set.Set(Junction)),
95+
circuit_table: disjoint_set.DisjointSet(Junction),
9496
connections: List(#(Junction, Junction)),
95-
) -> dict.Dict(Junction, set.Set(Junction)) {
97+
) -> disjoint_set.DisjointSet(Junction) {
9698
list.fold(connections, circuit_table, fn(acc, conn) {
97-
// Get the circuit for a box, if it's not part of a circuit yet
98-
// just provide a set containing only itself
99-
let s1 = dict.get(acc, conn.0) |> result.unwrap(set.from_list([conn.0]))
100-
let s2 = dict.get(acc, conn.1) |> result.unwrap(set.from_list([conn.1]))
101-
102-
// new combined circuit
103-
let combined = set.union(s1, s2)
104-
105-
set.fold(combined, acc, fn(acc_inner, jb) {
106-
dict.insert(acc_inner, jb, combined)
107-
})
99+
let assert Ok(new_table) = disjoint_set.union(acc, conn.0, conn.1)
100+
new_table
108101
})
109102
}
110103

@@ -113,33 +106,17 @@ fn wire(
113106
// connection that achieved that size.
114107
fn build_circuit_sized(
115108
size: Int,
116-
circuit_table: dict.Dict(Junction, set.Set(Junction)),
109+
circuit_table: disjoint_set.DisjointSet(Junction),
117110
connections: List(#(Junction, Junction)),
118111
) -> #(Junction, Junction) {
119112
case connections {
120113
[] -> panic as "unable to build requested circuit size"
121114
[first, ..rest] -> {
122-
// Get the circuit for a box, if it's not part of a circuit yet
123-
// just provide a set containing only itself
124-
let s1 =
125-
dict.get(circuit_table, first.0)
126-
|> result.unwrap(set.from_list([first.0]))
127-
let s2 =
128-
dict.get(circuit_table, first.1)
129-
|> result.unwrap(set.from_list([first.1]))
130-
131-
// new combined circuit
132-
let combined = set.union(s1, s2)
133-
134-
case set.size(combined) {
135-
s if s == size -> first
136-
_ -> {
137-
let new_table =
138-
set.fold(combined, circuit_table, fn(acc, jb) {
139-
dict.insert(acc, jb, combined)
140-
})
141-
build_circuit_sized(size, new_table, rest)
142-
}
115+
let assert Ok(new_table) =
116+
disjoint_set.union(circuit_table, first.0, first.1)
117+
case disjoint_set.size(new_table, first.0) {
118+
Ok(s) if s == size -> first
119+
_ -> build_circuit_sized(size, new_table, rest)
143120
}
144121
}
145122
}

src/internal/disjoint_set.gleam

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import gleam/dict
2+
import gleam/list
3+
4+
pub opaque type DisjointSet(a) {
5+
DisjointSet(parents: dict.Dict(a, a), sizes: dict.Dict(a, Int))
6+
}
7+
8+
pub fn from_list(items: List(a)) -> DisjointSet(a) {
9+
let #(parents, sizes) =
10+
list.fold(items, #(dict.new(), dict.new()), fn(acc, item) {
11+
#(dict.insert(acc.0, item, item), dict.insert(acc.1, item, 1))
12+
})
13+
DisjointSet(parents:, sizes:)
14+
}
15+
16+
pub fn find(dj: DisjointSet(a), item: a) -> Result(#(DisjointSet(a), a), Nil) {
17+
case dict.get(dj.parents, item) {
18+
Ok(parent) -> {
19+
case parent == item {
20+
True -> Ok(#(dj, item))
21+
False -> {
22+
// Path compression to make each item point to a single representative set parent
23+
case find(dj, parent) {
24+
Error(_) -> Error(Nil)
25+
Ok(#(dj, rep_parent)) -> {
26+
let updated_dj =
27+
DisjointSet(
28+
parents: dict.insert(dj.parents, item, rep_parent),
29+
sizes: dj.sizes,
30+
)
31+
Ok(#(updated_dj, rep_parent))
32+
}
33+
}
34+
}
35+
}
36+
}
37+
Error(_) -> Error(Nil)
38+
}
39+
}
40+
41+
pub fn union(dj: DisjointSet(a), x: a, y: a) -> Result(DisjointSet(a), Nil) {
42+
case find(dj, x) {
43+
Error(Nil) -> Error(Nil)
44+
Ok(#(dj, root_x)) -> {
45+
case find(dj, y) {
46+
Error(Nil) -> Error(Nil)
47+
Ok(#(dj, root_y)) if root_x != root_y -> {
48+
case dict.get(dj.sizes, root_x), dict.get(dj.sizes, root_y) {
49+
Ok(size_x), Ok(size_y) if size_x < size_y -> {
50+
Ok(DisjointSet(
51+
dict.insert(dj.parents, root_x, root_y),
52+
dict.insert(dj.sizes, root_y, size_x + size_y),
53+
))
54+
}
55+
Ok(size_x), Ok(size_y) -> {
56+
Ok(DisjointSet(
57+
dict.insert(dj.parents, root_y, root_x),
58+
dict.insert(dj.sizes, root_x, size_x + size_y),
59+
))
60+
}
61+
_, _ -> Error(Nil)
62+
}
63+
}
64+
_ -> Ok(dj)
65+
}
66+
}
67+
}
68+
}
69+
70+
pub fn size(dj: DisjointSet(a), item: a) -> Result(Int, Nil) {
71+
case find(dj, item) {
72+
Error(Nil) -> Error(Nil)
73+
Ok(#(dj, rep_parent)) -> dict.get(dj.sizes, rep_parent)
74+
}
75+
}
76+
77+
pub fn to_list(dj: DisjointSet(a), item: a) -> Result(List(a), Nil) {
78+
case find(dj, item) {
79+
Error(Nil) -> Error(Nil)
80+
Ok(#(dj, rep_parent)) -> {
81+
dict.keys(dj.parents)
82+
|> list.filter(fn(i) {
83+
case find(dj, i) {
84+
Ok(#(_, p)) if p == rep_parent -> True
85+
_ -> False
86+
}
87+
})
88+
|> Ok
89+
}
90+
}
91+
}
92+
93+
pub fn setlist(dj: DisjointSet(a)) -> List(a) {
94+
dict.keys(dj.parents)
95+
|> list.filter(fn(item) {
96+
case find(dj, item) {
97+
Ok(#(_, p)) if p == item -> True
98+
_ -> False
99+
}
100+
})
101+
}

test/aoc2025_test.gleam

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import gleam/time/duration
22
import gleeunit
3-
import gleeunit/should
43
import internal/aoc_utils
54
import internal/point
65

@@ -9,30 +8,28 @@ pub fn main() {
98
}
109

1110
pub fn solution_or_error_test() {
12-
aoc_utils.solution_or_error(Ok("This is good"))
13-
|> should.equal("This is good")
11+
assert aoc_utils.solution_or_error(Ok("This is good")) == "This is good"
1412

15-
aoc_utils.solution_or_error(Error("This is bad"))
16-
|> should.equal("ERROR: This is bad")
13+
assert aoc_utils.solution_or_error(Error("This is bad"))
14+
== "ERROR: This is bad"
1715
}
1816

1917
pub fn chunk_up_test() {
20-
["aaa", "bbb", "ccc", "", "ddd", "eee", "", "", "fff"]
21-
|> aoc_utils.chunk_around_empty_strings()
22-
|> should.equal([["aaa", "bbb", "ccc"], ["ddd", "eee"], ["fff"]])
18+
assert {
19+
["aaa", "bbb", "ccc", "", "ddd", "eee", "", "", "fff"]
20+
|> aoc_utils.chunk_around_empty_strings()
21+
}
22+
== [["aaa", "bbb", "ccc"], ["ddd", "eee"], ["fff"]]
2323
}
2424

2525
pub fn point_addition_test() {
26-
point.add(#(2, 1), #(-1, 1))
27-
|> should.equal(#(1, 2))
26+
assert point.add(#(2, 1), #(-1, 1)) == #(1, 2)
2827
}
2928

3029
pub fn point_multiplication_test() {
31-
point.mul(#(2, 5), 5)
32-
|> should.equal(#(10, 25))
30+
assert point.mul(#(2, 5), 5) == #(10, 25)
3331

34-
point.mul(#(2, 5), 0)
35-
|> should.equal(#(0, 0))
32+
assert point.mul(#(2, 5), 0) == #(0, 0)
3633
}
3734

3835
pub fn duration_to_string_test() {

0 commit comments

Comments
 (0)