Skip to content

Commit 36f9ac2

Browse files
committed
refactor: aoc 2024 replace networkx with rustworkx
1 parent 7bfb4a2 commit 36f9ac2

File tree

1 file changed

+58
-14
lines changed

1 file changed

+58
-14
lines changed

_2024/solutions/day23.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
possible fully-connected group representing the actual LAN party location.
1313
1414
The module contains a Solution class that inherits from SolutionBase and
15-
implements methods using NetworkX to find cliques in the connection graph.
15+
implements methods using rustworkx with custom clique detection.
1616
"""
1717

1818
from itertools import combinations
1919

20-
import networkx as nx
20+
import rustworkx as rx
2121

2222
from aoc.models.base import SolutionBase
2323

@@ -29,15 +29,15 @@ class Solution(SolutionBase):
2929
- Part 1: Count trios of interconnected computers with Chief Historian hint
3030
- Part 2: Find the largest fully-connected group (maximum clique)
3131
32-
The solution uses NetworkX graph library to model the network and efficiently
33-
find all cliques, which represent groups of computers that can all directly
34-
communicate with each other.
32+
The solution uses rustworkx graph library to model the network with manual
33+
clique finding implementation, which represent groups of computers that can
34+
all directly communicate with each other.
3535
"""
3636

37-
def construct_graph(self, data: list[str]) -> nx.Graph:
37+
def construct_graph(self, data: list[str]) -> rx.PyGraph:
3838
"""Construct an undirected graph from connection data.
3939
40-
Creates a NetworkX graph where nodes represent computers and edges
40+
Creates a rustworkx graph where nodes represent computers and edges
4141
represent direct network connections between them. Each connection
4242
is bidirectional.
4343
@@ -46,16 +46,60 @@ def construct_graph(self, data: list[str]) -> nx.Graph:
4646
4747
Returns
4848
-------
49-
NetworkX Graph object representing the computer network
49+
rustworkx PyGraph object representing the computer network
5050
"""
51-
graph: nx.Graph = nx.Graph()
52-
edges = []
51+
graph = rx.PyGraph()
52+
node_map = {}
53+
5354
for line in data:
5455
source, target = line.split("-")
55-
edges.append((source, target))
56-
graph.add_edges_from(edges)
56+
57+
if source not in node_map:
58+
node_map[source] = graph.add_node(source)
59+
if target not in node_map:
60+
node_map[target] = graph.add_node(target)
61+
62+
graph.add_edge(node_map[source], node_map[target], None)
63+
5764
return graph
5865

66+
def _bron_kerbosch_recursive(
67+
self, graph: rx.PyGraph, r: set[int], p: set[int], x: set[int], cliques: list[list[str]]
68+
) -> None:
69+
"""Recursive helper for Bron-Kerbosch algorithm.
70+
71+
Args:
72+
graph: rustworkx PyGraph to search
73+
r: Current clique being built
74+
p: Candidate nodes to extend clique
75+
x: Already processed nodes
76+
cliques: List to accumulate found cliques
77+
"""
78+
if not p and not x:
79+
cliques.append([graph[node] for node in r])
80+
return
81+
82+
for v in list(p):
83+
neighbors = set(graph.neighbors(v))
84+
self._bron_kerbosch_recursive(graph, r | {v}, p & neighbors, x & neighbors, cliques)
85+
p.remove(v)
86+
x.add(v)
87+
88+
def find_cliques(self, graph: rx.PyGraph) -> list[list[str]]:
89+
"""Find all maximal cliques using Bron-Kerbosch algorithm.
90+
91+
Args:
92+
graph: rustworkx PyGraph to search
93+
94+
Returns
95+
-------
96+
List of maximal cliques, where each clique is a list of node labels
97+
"""
98+
cliques: list[list[str]] = []
99+
all_nodes = set(graph.node_indices())
100+
self._bron_kerbosch_recursive(graph, set(), all_nodes, set(), cliques)
101+
return cliques
102+
59103
def part1(self, data: list[str]) -> int:
60104
"""Count sets of three interconnected computers including the Chief Historian.
61105
@@ -73,7 +117,7 @@ def part1(self, data: list[str]) -> int:
73117
graph = self.construct_graph(data)
74118
teacher_cliques = [
75119
clique
76-
for clique in nx.find_cliques(graph)
120+
for clique in self.find_cliques(graph)
77121
if len(clique) >= 3 and any(node.startswith("t") for node in clique)
78122
]
79123

@@ -101,5 +145,5 @@ def part2(self, data: list[str]) -> str:
101145
Password string of alphabetically sorted computer IDs joined by commas
102146
"""
103147
graph = self.construct_graph(data)
104-
largest_clique = max(nx.find_cliques(graph), key=len)
148+
largest_clique = max(self.find_cliques(graph), key=len)
105149
return ",".join(sorted(largest_clique))

0 commit comments

Comments
 (0)