Skip to content

Commit c9bab33

Browse files
committed
feat(data structures, heaps): min cost to connect sticks
1 parent 8c05de3 commit c9bab33

File tree

9 files changed

+250
-2
lines changed

9 files changed

+250
-2
lines changed

datastructures/trees/heaps/binary/min_heap/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def __init__(self, heap: List[HeapNode] = None):
1212
self.idx_of_element = {}
1313
if heap is None:
1414
heap = []
15+
self.heap: List[HeapNode] = []
1516
self.heap = self.build_heap(heap)
1617
heapify(self.heap)
1718

@@ -97,7 +98,7 @@ def insert_data(self, data: Any):
9798
"""
9899
if data is None:
99100
raise TypeError("Data item can not be None")
100-
node = HeapNode(data, data.__name__)
101+
node = HeapNode(data)
101102
self.heap.append(node)
102103
self.__bubble_up(len(self.heap) - 1)
103104

datastructures/trees/heaps/node.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Optional
12
from datastructures.trees.binary.tree import BinaryTreeNode, T
23

34

@@ -7,10 +8,31 @@ class HeapNode(BinaryTreeNode):
78
a Binary Tree Node which exhibit similar properties.
89
"""
910

10-
def __init__(self, data: T, key: T):
11+
def __init__(self, data: T, key: Optional[T] = None):
1112
super().__init__(data)
1213
self.key = key
1314

1415
@property
1516
def name(self):
1617
return self.__class__.__name__
18+
19+
def __eq__(self, other: 'HeapNode') -> bool:
20+
return self.data == other.data
21+
22+
def __lt__(self, other: 'HeapNode') -> bool:
23+
return self.data < other.data
24+
25+
def __gt__(self, other: 'HeapNode') -> bool:
26+
return self.data > other.data
27+
28+
def __le__(self, other: 'HeapNode') -> bool:
29+
return self.data <= other.data
30+
31+
def __ge__(self, other: 'HeapNode') -> bool:
32+
return self.data >= other.data
33+
34+
def __ne__(self, other: 'HeapNode') -> bool:
35+
return self.data != other.data
36+
37+
def __hash__(self) -> int:
38+
return hash(self.data)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Minimum Cost to Connect Sticks
2+
3+
You are given a set of sticks with positive integer lengths represented as an array, sticks, where sticks[i] denotes the
4+
length of the ith stick.
5+
6+
You can connect any two sticks into one stick at a cost equal to the sum of their lengths. Once two sticks are combined,
7+
they form a new stick whose length is the sum of the two original sticks. This process continues until there is only one
8+
stick remaining.
9+
10+
Your task is to determine the minimum cost required to connect all the sticks into a single stick.
11+
12+
Constraints:
13+
14+
- 1 ≤ sticks.length ≤ 10^3
15+
- 1 ≤ sticks[i] ≤ 10^3
16+
17+
## Examples
18+
19+
![Example 1](./images/examples/min_cost_to_connect_sticks_1.png)
20+
![Example 2](./images/examples/min_cost_to_connect_sticks_2.png)
21+
![Example 3](./images/examples/min_cost_to_connect_sticks_3.png)
22+
![Example 4](./images/examples/min_cost_to_connect_sticks_4.png)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from typing import List
2+
from heapq import heapify, heappop, heappush
3+
from datastructures.trees.heaps.binary.min_heap import MinHeap, HeapNode
4+
5+
6+
def connect_sticks(sticks: List[int]) -> int:
7+
"""
8+
Calculates the minimum cost to connect all sticks.
9+
10+
The function takes a list of integers, representing the length of each stick.
11+
It returns an integer representing the minimum cost to connect all sticks.
12+
13+
The cost to connect two sticks is the sum of their lengths.
14+
15+
For example, given the list [3, 4, 5], the minimum cost to connect all sticks is 12 (3 + 4 + 5).
16+
17+
The time complexity of this function is O(nlogn), where n is the number of sticks.
18+
The space complexity of this function is O(1), because the heap is built in place from the input list
19+
"""
20+
if len(sticks) <= 1:
21+
return 0
22+
23+
# Create a min heap from the list of sticks
24+
heapify(sticks)
25+
26+
# Initialize the total cost to zero
27+
total_cost = 0
28+
29+
# While there are more than one sticks left
30+
while len(sticks) > 1:
31+
# Extract the two smallest sticks from the heap
32+
first = heappop(sticks)
33+
second = heappop(sticks)
34+
# Calculate the cost to connect the two sticks
35+
cost = first + second
36+
# Add the cost to the total cost
37+
total_cost += cost
38+
# Push the connected stick back into the heap
39+
heappush(sticks, cost)
40+
41+
# Return the total cost
42+
return total_cost
43+
44+
45+
def connect_sticks_2(sticks: List[int]) -> int:
46+
"""
47+
Calculates the minimum cost to connect all sticks.
48+
49+
The function takes a list of integers, representing the length of each stick.
50+
It returns an integer representing the minimum cost to connect all sticks.
51+
52+
The cost to connect two sticks is the sum of their lengths.
53+
54+
For example, given the list [3, 4, 5], the minimum cost to connect all sticks is 12 (3 + 4 + 5).
55+
56+
The time complexity of this function is O(nlogn), where n is the number of sticks.
57+
The space complexity of this function is O(1), because the heap is built in place from the input list
58+
"""
59+
if len(sticks) <= 1:
60+
return 0
61+
62+
nodes: List[HeapNode] = []
63+
for stick in sticks:
64+
node = HeapNode(stick)
65+
nodes.append(node)
66+
67+
# Create a min heap from the list of sticks
68+
min_heap = MinHeap(nodes)
69+
70+
# Initialize the total cost to zero
71+
total_cost = 0
72+
73+
# While there are more than one sticks left
74+
while len(min_heap) > 1:
75+
# Extract the two smallest sticks from the heap
76+
first = min_heap.remove_min()
77+
second = min_heap.remove_min()
78+
# Calculate the cost to connect the two sticks
79+
cost = first.data + second.data
80+
# Add the cost to the total cost
81+
total_cost += cost
82+
# Push the connected stick back into the heap
83+
min_heap.insert_data(cost)
84+
85+
# Return the total cost
86+
return total_cost
88.4 KB
Loading
110 KB
Loading
58.5 KB
Loading
90.4 KB
Loading
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import unittest
2+
from . import connect_sticks, connect_sticks_2
3+
4+
5+
class MinCostConnectSticksTestCase(unittest.TestCase):
6+
def test_1(self):
7+
sticks = [3,5,4]
8+
expected = 19
9+
actual = connect_sticks(sticks)
10+
self.assertEqual(expected, actual)
11+
12+
def test_2(self):
13+
sticks = [2,9,4,6]
14+
expected = 39
15+
actual = connect_sticks(sticks)
16+
self.assertEqual(expected, actual)
17+
18+
def test_3(self):
19+
sticks = [23]
20+
expected = 0
21+
actual = connect_sticks(sticks)
22+
self.assertEqual(expected, actual)
23+
24+
def test_4(self):
25+
sticks = [3,3,3]
26+
expected = 15
27+
actual = connect_sticks(sticks)
28+
self.assertEqual(expected, actual)
29+
30+
def test_5(self):
31+
sticks = [1, 10, 3, 3, 3]
32+
expected = 40
33+
actual = connect_sticks(sticks)
34+
self.assertEqual(expected, actual)
35+
36+
def test_6(self):
37+
sticks = [7,10,16]
38+
expected = 50
39+
actual = connect_sticks(sticks)
40+
self.assertEqual(expected, actual)
41+
42+
def test_7(self):
43+
sticks = [5,120,7,30,10]
44+
expected = 258
45+
actual = connect_sticks(sticks)
46+
self.assertEqual(expected, actual)
47+
48+
def test_8(self):
49+
sticks = [100,200,300,400,500]
50+
expected = 3300
51+
actual = connect_sticks(sticks)
52+
self.assertEqual(expected, actual)
53+
54+
def test_9(self):
55+
sticks = [20,20,20,20]
56+
expected = 160
57+
actual = connect_sticks(sticks)
58+
self.assertEqual(expected, actual)
59+
60+
class MinCostConnectSticks2TestCase(unittest.TestCase):
61+
def test_1(self):
62+
sticks = [3,5,4]
63+
expected = 19
64+
actual = connect_sticks_2(sticks)
65+
self.assertEqual(expected, actual)
66+
67+
def test_2(self):
68+
sticks = [2,9,4,6]
69+
expected = 39
70+
actual = connect_sticks_2(sticks)
71+
self.assertEqual(expected, actual)
72+
73+
def test_3(self):
74+
sticks = [23]
75+
expected = 0
76+
actual = connect_sticks_2(sticks)
77+
self.assertEqual(expected, actual)
78+
79+
def test_4(self):
80+
sticks = [3,3,3]
81+
expected = 15
82+
actual = connect_sticks_2(sticks)
83+
self.assertEqual(expected, actual)
84+
85+
def test_5(self):
86+
sticks = [1, 10, 3, 3, 3]
87+
expected = 40
88+
actual = connect_sticks_2(sticks)
89+
self.assertEqual(expected, actual)
90+
91+
def test_6(self):
92+
sticks = [7,10,16]
93+
expected = 50
94+
actual = connect_sticks_2(sticks)
95+
self.assertEqual(expected, actual)
96+
97+
def test_7(self):
98+
sticks = [5,120,7,30,10]
99+
expected = 258
100+
actual = connect_sticks_2(sticks)
101+
self.assertEqual(expected, actual)
102+
103+
def test_8(self):
104+
sticks = [100,200,300,400,500]
105+
expected = 3300
106+
actual = connect_sticks_2(sticks)
107+
self.assertEqual(expected, actual)
108+
109+
def test_9(self):
110+
sticks = [20,20,20,20]
111+
expected = 160
112+
actual = connect_sticks_2(sticks)
113+
self.assertEqual(expected, actual)
114+
115+
116+
if __name__ == '__main__':
117+
unittest.main()

0 commit comments

Comments
 (0)