Skip to content

Commit 569427f

Browse files
committed
add tests and clean code
1 parent a8b3c41 commit 569427f

File tree

1 file changed

+78
-61
lines changed

1 file changed

+78
-61
lines changed

graphs/travelling_salesman_problem.py

Lines changed: 78 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,54 @@ class TSPEdge(Generic[T]):
2020
weight: float
2121

2222
def __str__(self) -> str:
23+
"""
24+
Examples:
25+
>>> tsp_edge = TSPEdge.from_3_tuple(1, 2, 0.5)
26+
>>> str(tsp_edge)
27+
'(frozenset({1, 2}), 0.5)'
28+
"""
2329
return f"({self.vertices}, {self.weight})"
2430

25-
def __post_init__(self):
31+
def __post_init__(self) -> None:
2632
# Ensures that there is no loop in a vertex
2733
if len(self.vertices) != 2:
2834
raise ValueError("frozenset must have exactly 2 elements")
2935

3036
@classmethod
31-
def from_3_tuple(cls, x, y, w) -> "TSPEdge":
37+
def from_3_tuple(cls, vertex_1: T, vertex_2: T, weight: float) -> "TSPEdge":
3238
"""
3339
Construct TSPEdge from a 3-tuple (x, y, w).
3440
x & y are vertices and w is the weight.
41+
42+
Examples:
43+
>>> tsp_edge = TSPEdge.from_3_tuple(1, 2, 0.5)
44+
>>> tsp_edge.vertices
45+
frozenset({1, 2})
46+
>>> tsp_edge.weight
47+
0.5
3548
"""
36-
return cls(frozenset([x, y]), w)
49+
return cls(frozenset([vertex_1, vertex_2]), weight)
3750

3851
def __eq__(self, other: object) -> bool:
52+
"""
53+
Examples:
54+
>>> tsp_edge_1 = TSPEdge.from_3_tuple(1, 2, 0.5)
55+
>>> tsp_edge_2 = TSPEdge.from_3_tuple(2, 1, 0.7)
56+
>>> tsp_edge_1 == tsp_edge_2
57+
True
58+
"""
3959
if not isinstance(other, TSPEdge):
4060
return NotImplemented
4161
return self.vertices == other.vertices
4262

4363
def __add__(self, other: "TSPEdge") -> float:
64+
"""
65+
Examples:
66+
>>> tsp_edge_1 = TSPEdge.from_3_tuple(1, 2, 1.0)
67+
>>> tsp_edge_2 = TSPEdge.from_3_tuple(2, 1, 2.5)
68+
>>> tsp_edge_1 + tsp_edge_2
69+
3.5
70+
"""
4471
return self.weight + other.weight
4572

4673

@@ -187,7 +214,7 @@ def adjacent_tuples(path: list[T]) -> zip:
187214
Returns:
188215
zip: A zip object containing tuples of adjacent vertices.
189216
190-
Examples
217+
Examples:
191218
>>> list(adjacent_tuples([1, 2, 3, 4, 5]))
192219
[(1, 2), (2, 3), (3, 4), (4, 5)]
193220
@@ -209,6 +236,15 @@ def path_weight(path: list[T], tsp_graph: TSPGraph) -> float:
209236
210237
Returns:
211238
float: The total weight of the path.
239+
240+
Examples:
241+
>>> graph = TSPGraph.from_3_tuples((1, 2, 2), (2, 3, 4), (3, 4, 2), (4, 5, 1))
242+
>>> path_weight([1, 2, 3], graph)
243+
6
244+
>>> path_weight([1, 2, 3, 4], graph)
245+
8
246+
>>> path_weight([1, 2, 3, 4, 5], graph)
247+
9
212248
"""
213249
return sum(tsp_graph.get_edge_weight(x, y) for x, y in adjacent_tuples(path))
214250

@@ -228,6 +264,14 @@ def generate_paths(start: T, end: T, tsp_graph: TSPGraph) -> Generator[list[T]]:
228264
229265
Raises:
230266
AssertionError: If start or end is not in the graph, or if they are the same.
267+
268+
Examples:
269+
>>> graph = TSPGraph.from_3_tuples((1, 2, 2), (2, 3, 4), (3, 1, 2))
270+
>>> graph_generator = generate_paths(1, 3, graph)
271+
>>> next(graph_generator)
272+
[1, 2, 3]
273+
>>> next(graph_generator)
274+
[1, 3]
231275
"""
232276

233277
assert start in tsp_graph.vertices
@@ -257,7 +301,9 @@ def dfs(
257301
yield from dfs(start, end, set(), [])
258302

259303

260-
def nearest_neighborhood(tsp_graph: TSPGraph, v, visited_=None) -> list[T] | None:
304+
def nearest_neighborhood(
305+
tsp_graph: TSPGraph, current_vertex: T, visited_: list[T] | None = None
306+
) -> list[T] | None:
261307
"""
262308
Approximates a solution to the Traveling Salesman Problem
263309
using the Nearest Neighbor heuristic.
@@ -269,9 +315,29 @@ def nearest_neighborhood(tsp_graph: TSPGraph, v, visited_=None) -> list[T] | Non
269315
270316
Returns:
271317
list[T] | None: A complete Hamiltonian cycle if possible, otherwise None.
318+
319+
Examples:
320+
>>> edges = [
321+
... ("A", "B", 7), ("A", "D", 1), ("A", "E", 1),
322+
... ("B", "C", 3), ("B", "E", 8), ("C", "E", 2),
323+
... ("C", "D", 6), ("D", "E", 7)
324+
... ]
325+
>>> graph = TSPGraph.from_3_tuples(*edges)
326+
>>> import random
327+
>>> init_v = random.choice(list(graph.vertices))
328+
>>> result = nearest_neighborhood(graph, init_v)
329+
>>> assert result in [
330+
... ['A', 'D', 'C', 'E', 'B', 'A'],
331+
... ['E', 'A', 'D', 'C', 'B', 'E'],
332+
... None
333+
... ]
334+
>>> path_1 = ['A', 'D', 'C', 'E', 'B', 'A']
335+
>>> path_2 = ['E', 'A', 'D', 'C', 'B', 'E']
336+
>>> assert path_weight(path_1, graph) == 24 if result == path_1 else 19 or None
337+
>>> assert path_weight(path_2, graph) == 19 if result == path_2 else 24 or None
272338
"""
273339
# Initialize visited list on first call
274-
visited = visited_ or [v]
340+
visited = visited_ or [current_vertex]
275341

276342
# Base case: if all vertices are visited
277343
if len(visited) == len(tsp_graph.vertices):
@@ -283,72 +349,23 @@ def nearest_neighborhood(tsp_graph: TSPGraph, v, visited_=None) -> list[T] | Non
283349

284350
# Get unvisited neighbors
285351
filtered_neighbors = [
286-
tup for tup in tsp_graph.get_vertex_neighbor_weights(v) if tup[0] not in visited
352+
tup
353+
for tup in tsp_graph.get_vertex_neighbor_weights(current_vertex)
354+
if tup[0] not in visited
287355
]
288356

289357
# If there are unvisited neighbors, continue to the nearest one
290358
if filtered_neighbors:
291359
next_v = min(filtered_neighbors, key=lambda tup: tup[1])[0]
292-
return nearest_neighborhood(tsp_graph, v=next_v, visited_=[*visited, next_v])
360+
return nearest_neighborhood(
361+
tsp_graph, current_vertex=next_v, visited_=[*visited, next_v]
362+
)
293363
else:
294364
# No more neighbors, return None (cannot form a complete tour)
295365
return None
296366

297367

298-
def sample_1():
299-
# Reference: https://graphicmaths.com/computer-science/graph-theory/travelling-salesman-problem/
300-
301-
edges = [
302-
("A", "B", 7),
303-
("A", "D", 1),
304-
("A", "E", 1),
305-
("B", "C", 3),
306-
("B", "E", 8),
307-
("C", "E", 2),
308-
("C", "D", 6),
309-
("D", "E", 7),
310-
]
311-
312-
# Create the graph
313-
graph = TSPGraph.from_3_tuples(*edges)
314-
315-
import random
316-
317-
init_v = random.choice(list(graph.vertices))
318-
optim_path = nearest_neighborhood(graph, init_v)
319-
# optim_path = nearest_neighborhood(graph, 'A')
320-
print(f"Optimal Cycle: {optim_path}")
321-
if optim_path:
322-
print(f"Optimal Weight: {path_weight(optim_path, graph)}")
323-
324-
325-
def sample_2():
326-
# Example 8x8 weight matrix (symmetric, no self-loops)
327-
weights = [
328-
[0, 1, 2, 3, 4, 5, 6, 7],
329-
[1, 0, 8, 9, 10, 11, 12, 13],
330-
[2, 8, 0, 14, 15, 16, 17, 18],
331-
[3, 9, 14, 0, 19, 20, 21, 22],
332-
[4, 10, 15, 19, 0, 23, 24, 25],
333-
[5, 11, 16, 20, 23, 0, 26, 27],
334-
[6, 12, 17, 21, 24, 26, 0, 28],
335-
[7, 13, 18, 22, 25, 27, 28, 0],
336-
]
337-
338-
graph = TSPGraph.from_weights(weights)
339-
340-
import random
341-
342-
init_v = random.choice(list(graph.vertices))
343-
optim_path = nearest_neighborhood(graph, init_v)
344-
print(f"Optimal Cycle: {optim_path}")
345-
if optim_path:
346-
print(f"Optimal Weight: {path_weight(optim_path, graph)}")
347-
348-
349368
if __name__ == "__main__":
350369
import doctest
351370

352371
doctest.testmod()
353-
sample_1()
354-
sample_2()

0 commit comments

Comments
 (0)