Skip to content

Commit b0ac9a5

Browse files
committed
Time: 2589 ms (14.27%), Space: 38.4 MB (45.56%) - LeetHub
1 parent 6c3bfab commit b0ac9a5

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# time complexity: O(nlogn)
2+
# space complexity: O(n)
3+
from typing import List
4+
5+
6+
class SegmentTree:
7+
def __init__(self, size: int):
8+
self.n = size
9+
self.segTree = [0 for _ in range(4 * size)]
10+
11+
def update(self, idx: int, val: int, nodeIdx=0, start=0, end=None):
12+
if end is None:
13+
end = self.n - 1
14+
if start == end:
15+
self.segTree[nodeIdx] += val
16+
return
17+
18+
mid = (start + end) // 2
19+
leftIdx = 2 * nodeIdx + 1
20+
rightIdx = 2 * nodeIdx + 2
21+
22+
if idx <= mid:
23+
self.update(idx, val, leftIdx, start, mid)
24+
else:
25+
self.update(idx, val, rightIdx, mid + 1, end)
26+
27+
self.segTree[nodeIdx] = self.segTree[leftIdx] + self.segTree[rightIdx]
28+
29+
def query(self, left: int, right: int, nodeIdx=0, start=0, end=None) -> int:
30+
if end is None:
31+
end = self.n - 1
32+
if right < start or left > end:
33+
return 0
34+
if left <= start and end <= right:
35+
return self.segTree[nodeIdx]
36+
37+
mid = (start + end) // 2
38+
leftIdx = 2 * nodeIdx + 1
39+
rightIdx = 2 * nodeIdx + 2
40+
return self.query(left, right, leftIdx, start, mid) + self.query(left, right, rightIdx, mid + 1, end)
41+
42+
43+
class Solution:
44+
def countSmaller(self, nums: List[int]) -> List[int]:
45+
if not nums:
46+
return []
47+
48+
sortedNums = sorted(set(nums))
49+
rankMap = {val: idx for idx, val in enumerate(sortedNums)}
50+
51+
tree = SegmentTree(len(sortedNums))
52+
result = []
53+
54+
for num in reversed(nums):
55+
idx = rankMap[num]
56+
result.append(tree.query(0, idx - 1))
57+
tree.update(idx, 1)
58+
59+
return result[::-1]
60+
61+
62+
nums = [5, 2, 6, 1]
63+
print(Solution().countSmaller(nums))
64+
nums = [-1]
65+
print(Solution().countSmaller(nums))
66+
nums = [-1, -1]
67+
print(Solution().countSmaller(nums))

0 commit comments

Comments
 (0)