|
2 | 2 | # space complexity: O(n) |
3 | 3 | from typing import List |
4 | 4 |
|
5 | | -class TreeNode: |
6 | | - def __init__(self, start, end, val=0, left=None, right=None): |
7 | | - self.val = val |
8 | | - self.start = start |
9 | | - self.end = end |
10 | | - self.left = left |
11 | | - self.right = right |
12 | 5 |
|
13 | 6 | class SegmentTree: |
14 | | - def __init__(self, n): |
15 | | - self.root = self.build(0, n - 1) |
16 | | - |
17 | | - def build(self, l, r): |
18 | | - if l == r: |
19 | | - return TreeNode(l, r, 0) |
20 | | - leftTree = self.build(l, (l + r) // 2) |
21 | | - rightTree = self.build((l + r) // 2 + 1, r) |
22 | | - return TreeNode(l, r, 0, leftTree, rightTree) |
23 | | - |
24 | | - def update(self, root, index, value): |
25 | | - if root.start == root.end == index: |
26 | | - root.val += value |
27 | | - return root.val |
28 | | - if root.start > index or root.end < index: |
29 | | - return root.val |
30 | | - root.val = self.update(root.left, index, value) + self.update(root.right, index, value) |
31 | | - return root.val |
32 | | - |
33 | | - def query(self, root, l, r) -> int: |
34 | | - if root.start > r or root.end < l: |
| 7 | + def __init__(self, size: int): |
| 8 | + self.n = size |
| 9 | + self.segTree = [0 for _ in range(4 * size)] |
| 10 | + |
| 11 | + def update(self, idx: int, val: int, nodeIdx=0, start=0, end=None): |
| 12 | + if end is None: |
| 13 | + end = self.n - 1 |
| 14 | + if start == end: |
| 15 | + self.segTree[nodeIdx] += val |
| 16 | + return |
| 17 | + |
| 18 | + mid = (start + end) // 2 |
| 19 | + leftIdx = 2 * nodeIdx + 1 |
| 20 | + rightIdx = 2 * nodeIdx + 2 |
| 21 | + |
| 22 | + if idx <= mid: |
| 23 | + self.update(idx, val, leftIdx, start, mid) |
| 24 | + else: |
| 25 | + self.update(idx, val, rightIdx, mid + 1, end) |
| 26 | + |
| 27 | + self.segTree[nodeIdx] = self.segTree[leftIdx] + self.segTree[rightIdx] |
| 28 | + |
| 29 | + def query(self, left: int, right: int, nodeIdx=0, start=0, end=None) -> int: |
| 30 | + if end is None: |
| 31 | + end = self.n - 1 |
| 32 | + if right < start or left > end: |
35 | 33 | return 0 |
36 | | - if l <= root.start and root.end <= r: |
37 | | - return root.val |
38 | | - return self.query(root.left, l, r) + self.query(root.right, l, r) |
| 34 | + if left <= start and end <= right: |
| 35 | + return self.segTree[nodeIdx] |
| 36 | + |
| 37 | + mid = (start + end) // 2 |
| 38 | + leftIdx = 2 * nodeIdx + 1 |
| 39 | + rightIdx = 2 * nodeIdx + 2 |
| 40 | + return self.query(left, right, leftIdx, start, mid) + self.query(left, right, rightIdx, mid + 1, end) |
| 41 | + |
39 | 42 |
|
40 | 43 | class Solution: |
41 | 44 | def countSmaller(self, nums: List[int]) -> List[int]: |
42 | 45 | if not nums: |
43 | 46 | return [] |
| 47 | + |
44 | 48 | sortedNums = sorted(set(nums)) |
45 | 49 | rankMap = {val: idx for idx, val in enumerate(sortedNums)} |
| 50 | + |
46 | 51 | tree = SegmentTree(len(sortedNums)) |
47 | 52 | result = [] |
48 | | - for n in reversed(nums): |
49 | | - idx = rankMap[n] |
50 | | - result.append(tree.query(tree.root, 0, idx - 1)) |
51 | | - tree.update(tree.root, idx, 1) |
| 53 | + |
| 54 | + for num in reversed(nums): |
| 55 | + idx = rankMap[num] |
| 56 | + result.append(tree.query(0, idx - 1)) |
| 57 | + tree.update(idx, 1) |
| 58 | + |
52 | 59 | return result[::-1] |
53 | 60 |
|
| 61 | + |
54 | 62 | nums = [5, 2, 6, 1] |
55 | | -print(Solution().countSmaller(nums)) |
| 63 | +print(Solution().countSmaller(nums)) |
56 | 64 | nums = [-1] |
57 | | -print(Solution().countSmaller(nums)) |
| 65 | +print(Solution().countSmaller(nums)) |
58 | 66 | nums = [-1, -1] |
59 | | -print(Solution().countSmaller(nums)) |
| 67 | +print(Solution().countSmaller(nums)) |
0 commit comments