|
| 1 | +package dataStructures.segmentTree; |
| 2 | + |
| 3 | +public class SegmentTree { |
| 4 | + private SegmentTreeNode root; |
| 5 | + private int[] array; |
| 6 | + |
| 7 | + private class SegmentTreeNode { |
| 8 | + private SegmentTreeNode leftChild; // left child |
| 9 | + private SegmentTreeNode rightChild; // right child |
| 10 | + private int start; // start idx of range captured |
| 11 | + private int end; // end idx of range captured |
| 12 | + private int sum; // sum of all elements between start and end index inclusive |
| 13 | + |
| 14 | + /** |
| 15 | + * Constructor |
| 16 | + * @param leftChild |
| 17 | + * @param rightChild |
| 18 | + * @param start |
| 19 | + * @param end |
| 20 | + * @param sum |
| 21 | + */ |
| 22 | + public SegmentTreeNode(SegmentTreeNode leftChild, SegmentTreeNode rightChild, int start, int end, int sum) { |
| 23 | + this.leftChild = leftChild; |
| 24 | + this.rightChild = rightChild |
| 25 | + this.start = start; |
| 26 | + this.end = end; |
| 27 | + this.sum = sum; |
| 28 | + } |
| 29 | + } |
| 30 | + |
| 31 | + public SegmentTree(int[] nums) { |
| 32 | + root = buildTree(nums, 0, nums.length - 1); |
| 33 | + array = nums; |
| 34 | + } |
| 35 | + |
| 36 | + private SegmentTreeNode buildTree(int[] nums, int start, int end) { |
| 37 | + if (start == end) { |
| 38 | + return new SegmentTreeNode(null, null, start, end, nums[start]); |
| 39 | + } |
| 40 | + int mid = start + (end-start) / 2; |
| 41 | + SegmentTreeNode left = buildTree(nums, start, mid); |
| 42 | + SegmentTreeNode right = buildTree(nums, mid + 1, end); |
| 43 | + return new SegmentTreeNode(left, right, start, end, left.sum + right.sum); |
| 44 | + } |
| 45 | + |
| 46 | + /** |
| 47 | + * Queries the sum of all values in the specified range. |
| 48 | + * @param leftEnd |
| 49 | + * @param rightEnd |
| 50 | + * @return the sum. |
| 51 | + */ |
| 52 | + public int query(int leftEnd, int rightEnd) { |
| 53 | + return query(root, leftEnd, rightEnd); |
| 54 | + } |
| 55 | + |
| 56 | + private int query(SegmentTreeNode node, int leftEnd, int rightEnd) { |
| 57 | + // this is the case when: |
| 58 | + // start end |
| 59 | + // range query: ^ ^ --> so simply capture the sum at this node! |
| 60 | + if (leftEnd <= node.start && node.end <= rightEnd) { |
| 61 | + return node.sum; |
| 62 | + } |
| 63 | + int rangeSum = 0; |
| 64 | + int mid = node.start + (node.end - node.start) / 2; |
| 65 | + // Consider the 3 possible kinds of range queries |
| 66 | + // start mid end |
| 67 | + // poss 1: ^ ^ |
| 68 | + // poss 2: ^ ^ |
| 69 | + // poss 3: ^ ^ |
| 70 | + if (leftEnd <= mid) { |
| 71 | + rangeSum += query(node.leftChild, leftEnd, Math.min(rightEnd, mid)); // poss1 / poss2 |
| 72 | + } |
| 73 | + if (mid + 1 <= rightEnd) { |
| 74 | + rangeSum += query(node.rightChild, Math.max(leftEnd, mid + 1), rightEnd); // poss2 / poss2 |
| 75 | + } |
| 76 | + return rangeSum; |
| 77 | + } |
| 78 | + |
| 79 | + /** |
| 80 | + * Updates the segment tree based on updates to the array at the specified index with the specified value. |
| 81 | + * @param idx |
| 82 | + * @param val |
| 83 | + */ |
| 84 | + public void update(int idx, int val) { |
| 85 | + if (idx > array.length) { |
| 86 | + return; |
| 87 | + } |
| 88 | + update(root, idx, val); |
| 89 | + } |
| 90 | + |
| 91 | + private void update(SegmentTreeNode node, int idx, int val) { |
| 92 | + if (node.start == node.end && node.start == idx) { |
| 93 | + node.sum = val; // previously, node held a single value; now updated |
| 94 | + } |
| 95 | + int mid = node.start + (node.end - node.start) / 2; |
| 96 | + if (idx <= mid) { |
| 97 | + update(node.leftChild, idx, val); |
| 98 | + } else { |
| 99 | + update(node.rightChild, idx, val); |
| 100 | + } |
| 101 | + node.sum = node.leftChild.sum + node.rightChild.sum; // propagate updates up |
| 102 | + } |
| 103 | +} |
0 commit comments